Fix single-GPU QLoRA and add profiling

This commit is contained in:
JustinLin610
2023-10-07 10:37:42 +08:00
parent 76a52cb2a8
commit b5fad3d561
9 changed files with 219 additions and 89 deletions

View File

@@ -15,6 +15,7 @@ import transformers
from transformers import Trainer, GPTQConfig, deepspeed
from transformers.trainer_pt_utils import LabelSmoother
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from accelerate.utils import DistributedType
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
@@ -264,6 +265,10 @@ def train():
lora_args,
) = parser.parse_args_into_dataclasses()
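# Single-GPU QLoRA support: when a DeepSpeed config is supplied and q_lora is
# enabled, explicitly mark the distributed type as DEEPSPEED.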
if getattr(training_args, 'deepspeed', None) and getattr(lora_args, 'q_lora', False):
training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED
compute_dtype = (
torch.float16
if training_args.fp16