mirror of
https://github.com/QwenLM/Qwen.git
synced 2026-05-21 00:45:48 +08:00
Fix single-GPU QLoRA and add profiling
This commit is contained in:
@@ -15,6 +15,7 @@ import transformers
|
||||
from transformers import Trainer, GPTQConfig, deepspeed
|
||||
from transformers.trainer_pt_utils import LabelSmoother
|
||||
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
||||
from accelerate.utils import DistributedType
|
||||
|
||||
|
||||
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
|
||||
@@ -264,6 +265,10 @@ def train():
|
||||
lora_args,
|
||||
) = parser.parse_args_into_dataclasses()
|
||||
|
||||
# Single-GPU QLoRA support: when a DeepSpeed config is supplied and q_lora is enabled, force the distributed type to DEEPSPEED.
|
||||
if getattr(training_args, 'deepspeed', None) and getattr(lora_args, 'q_lora', False):
|
||||
training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED
|
||||
|
||||
compute_dtype = (
|
||||
torch.float16
|
||||
if training_args.fp16
|
||||
|
||||
Reference in New Issue
Block a user