print peft trainable params

This commit is contained in:
songt
2023-10-20 17:38:47 +08:00
committed by GitHub
parent 3e63f107fa
commit e46d65084a

View File

@@ -341,6 +341,9 @@ def train():
model = get_peft_model(model, lora_config) model = get_peft_model(model, lora_config)
# Print peft trainable params
model.print_trainable_parameters()
if training_args.gradient_checkpointing: if training_args.gradient_checkpointing:
model.enable_input_require_grads() model.enable_input_require_grads()