mirror of
https://github.com/QwenLM/Qwen.git
synced 2026-05-20 16:35:47 +08:00
fix bug for ceval
This commit is contained in:
@@ -80,7 +80,7 @@ def eval_subject(
|
|||||||
score = []
|
score = []
|
||||||
|
|
||||||
few_shot_prompt = generate_few_shot_prompt(
|
few_shot_prompt = generate_few_shot_prompt(
|
||||||
k, subject_name, dev_df) if few_shot else []
|
k, subject_name, dev_df) if few_shot else ''
|
||||||
all_probs = {'prob_A': [], 'prob_B': [], 'prob_C': [], 'prob_D': []}
|
all_probs = {'prob_A': [], 'prob_B': [], 'prob_C': [], 'prob_D': []}
|
||||||
if args.debug: print(f"few_shot_prompt: {few_shot_prompt}")
|
if args.debug: print(f"few_shot_prompt: {few_shot_prompt}")
|
||||||
|
|
||||||
@@ -95,10 +95,10 @@ def eval_subject(
|
|||||||
softval = torch.nn.functional.softmax(
|
softval = torch.nn.functional.softmax(
|
||||||
torch.tensor(
|
torch.tensor(
|
||||||
[
|
[
|
||||||
logits[tokenizer("A")['input_ids']],
|
logits[tokenizer("A")['input_ids'][-1]],
|
||||||
logits[tokenizer("B")['input_ids']],
|
logits[tokenizer("B")['input_ids'][-1]],
|
||||||
logits[tokenizer("C")['input_ids']],
|
logits[tokenizer("C")['input_ids'][-1]],
|
||||||
logits[tokenizer("D")['input_ids']],
|
logits[tokenizer("D")['input_ids'][-1]],
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
dim=0,
|
dim=0,
|
||||||
|
|||||||
Reference in New Issue
Block a user