update evaluate scripts

This commit is contained in:
yangapku
2023-10-30 19:13:14 +08:00
parent 83368388aa
commit c00209f932
3 changed files with 147 additions and 93 deletions

View File

@@ -20,13 +20,22 @@ python evaluate_ceval.py -d data/ceval/
def load_models_tokenizer(args):
tokenizer = AutoTokenizer.from_pretrained(
args.checkpoint_path, trust_remote_code=True
args.checkpoint_path,
pad_token='<|extra_0|>',
eos_token='<|endoftext|>',
padding_side='left',
trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
args.checkpoint_path, device_map="auto", trust_remote_code=True
args.checkpoint_path,
pad_token_id=tokenizer.pad_token_id,
device_map="auto",
trust_remote_code=True
).eval()
model.generation_config = GenerationConfig.from_pretrained(
args.checkpoint_path, trust_remote_code=True
args.checkpoint_path,
pad_token_id=tokenizer.pad_token_id,
trust_remote_code=True
)
return model, tokenizer
@@ -56,11 +65,12 @@ def generate_few_shot_prompt(k, subject, dev_df):
def get_logits(tokenizer, model, inputs: List[str]):
input_ids = tokenizer(inputs, padding=False)["input_ids"]
input_ids = tokenizer(inputs, padding='longest')["input_ids"]
input_ids = torch.tensor(input_ids, device=model.device)
tokens = {"input_ids": input_ids}
attention_mask = input_ids.ne(tokenizer.pad_token_id)
outputs = model(input_ids)["logits"]
outputs = model(input_ids, attention_mask=attention_mask)["logits"]
logits = outputs[:, -1, :]
log_probs = torch.nn.functional.softmax(logits, dim=-1)
return log_probs, {"tokens": tokens}
@@ -76,6 +86,7 @@ def eval_subject(
dev_df=None,
few_shot=False,
save_result_dir=None,
batch_size=1,
**kwargs,
):
result = []
@@ -88,39 +99,39 @@ def eval_subject(
if args.debug:
print(f"few_shot_prompt: {few_shot_prompt}")
for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
question = format_example(row, include_answer=False)
full_prompt = few_shot_prompt + question
choices_ids = torch.tensor(
tokenizer("A")["input_ids"] + tokenizer("B")["input_ids"] +
tokenizer("C")["input_ids"] + tokenizer("D")["input_ids"]
).unsqueeze(0).to(model.device)
output, input_info = get_logits(tokenizer, model, [full_prompt])
assert output.shape[0] == 1
logits = output.flatten()
idx_list = list(range(0, len(test_df), batch_size))
for i in tqdm(idx_list):
full_prompt_list = []
answer_list = []
for row in test_df.iloc[i:i+batch_size].to_dict(orient='records'):
question = format_example(row, include_answer=False)
full_prompt = few_shot_prompt + question
full_prompt_list.append(full_prompt)
if 'answer' in row:
answer_list.append(row['answer'])
softval = torch.nn.functional.softmax(
torch.tensor(
[
logits[tokenizer("A")["input_ids"]],
logits[tokenizer("B")["input_ids"]],
logits[tokenizer("C")["input_ids"]],
logits[tokenizer("D")["input_ids"]],
]
),
dim=0,
)
logits, input_info = get_logits(tokenizer, model, full_prompt_list)
softval = logits.gather(1, choices_ids.expand(logits.size(0), -1)).softmax(1)
if softval.dtype in {torch.bfloat16, torch.float16}:
softval = softval.to(dtype=torch.float32)
probs = softval.detach().cpu().numpy()
for i, choice in enumerate(choices):
all_probs[f"prob_{choice}"].append(probs[i])
pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)]
for i in range(len(probs)):
for j, choice in enumerate(choices):
all_probs[f"prob_{choice}"].append(probs[i][j])
pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs[i])]
if "answer" in row:
correct = 1 if pred == row["answer"] else 0
score.append(correct)
if args.debug:
print(f'{question} pred: {pred} ref: {row["answer"]}')
result.append(pred)
if answer_list != []:
correct = 1 if pred == answer_list[i] else 0
score.append(correct)
if args.debug:
print(f'{question} pred: {pred} ref: {answer_list[i]}')
result.append(pred)
if score:
correct_ratio = 100 * sum(score) / len(score)
@@ -395,6 +406,7 @@ def main(args):
k=5,
few_shot=True,
save_result_dir=f"outs/ceval_eval_result",
batch_size=args.batch_size
)
dev_result[subject_name] = score
cal_ceval(dev_result)
@@ -425,6 +437,12 @@ if __name__ == "__main__":
group.add_argument(
"--debug", action="store_true", default=False, help="Print infos."
)
group.add_argument(
"--batch-size",
type=int,
default=1,
help="batch size",
)
args = parser.parse_args()
set_seed(args.seed)