mirror of
https://github.com/QwenLM/Qwen.git
synced 2026-05-20 16:35:47 +08:00
Specify repetition penalty explicitly: set `repetition_penalty = 1.0` (disabled) alongside greedy decoding
This commit is contained in:
@@ -31,6 +31,7 @@ def load_models_tokenizer(args):
|
||||
args.checkpoint_path, trust_remote_code=True
|
||||
)
|
||||
model.generation_config.do_sample = False # use greedy decoding
|
||||
model.generation_config.repetition_penalty = 1.0 # disable repetition penalty
|
||||
return model, tokenizer
|
||||
|
||||
def process_before_extraction(gen, question, choice_dict):
|
||||
|
||||
@@ -129,6 +129,7 @@ if __name__ == "__main__":
|
||||
args.checkpoint_path, trust_remote_code=True
|
||||
)
|
||||
model.generation_config.do_sample = False # use greedy decoding
|
||||
model.generation_config.repetition_penalty = 1.0 # disable repetition penalty
|
||||
|
||||
test = dataset["test"]
|
||||
|
||||
|
||||
@@ -93,6 +93,7 @@ if __name__ == "__main__":
|
||||
args.checkpoint_path, trust_remote_code=True
|
||||
)
|
||||
model.generation_config.do_sample = False # use greedy decoding
|
||||
model.generation_config.repetition_penalty = 1.0 # disable repetition penalty
|
||||
|
||||
f_output = jsonlines.Writer(open(args.sample_output_file, "w", encoding="utf-8"))
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ def load_models_tokenizer(args):
|
||||
args.checkpoint_path, trust_remote_code=True
|
||||
)
|
||||
model.generation_config.do_sample = False # use greedy decoding
|
||||
model.generation_config.repetition_penalty = 1.0 # disable repetition penalty
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user