Add Docker image for CUDA-12.1.

This commit is contained in:
苏阳
2024-01-08 14:22:05 +08:00
parent 4aab1d490b
commit 23a01b0696
10 changed files with 146 additions and 5 deletions

View File

@@ -272,7 +272,7 @@ def train():
local_rank = training_args.local_rank
device_map = "auto"
device_map = None
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if lora_args.q_lora:
@@ -282,6 +282,19 @@ def train():
"FSDP or ZeRO3 are incompatible with QLoRA."
)
is_chat_model = 'chat' in model_args.model_name_or_path.lower()
if (
training_args.use_lora
and not lora_args.q_lora
and deepspeed.is_deepspeed_zero3_enabled()
and not is_chat_model
):
raise RuntimeError("ZeRO3 is incompatible with LoRA when finetuning on base model.")
model_load_kwargs = {}
if deepspeed.is_deepspeed_zero3_enabled():
model_load_kwargs['low_cpu_mem_usage'] = False
# Set RoPE scaling factor
config = transformers.AutoConfig.from_pretrained(
model_args.model_name_or_path,
@@ -302,6 +315,7 @@ def train():
)
if training_args.use_lora and lora_args.q_lora
else None,
**model_load_kwargs,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
@@ -314,7 +328,7 @@ def train():
tokenizer.pad_token_id = tokenizer.eod_id
if training_args.use_lora:
if lora_args.q_lora or 'chat' in model_args.model_name_or_path.lower():
if lora_args.q_lora or is_chat_model:
modules_to_save = None
else:
modules_to_save = ["wte", "lm_head"]