mirror of
https://github.com/QwenLM/Qwen.git
synced 2026-05-20 16:35:47 +08:00
first commit
This commit is contained in:
110
eval/evaluate_gsm8k.py
Normal file
110
eval/evaluate_gsm8k.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import random
|
||||
import tqdm
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import torch
|
||||
import numpy as np
|
||||
import jsonlines
|
||||
import argparse
|
||||
import jsonlines
|
||||
import datasets
|
||||
from datasets import load_from_disk,load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.generation import GenerationConfig
|
||||
|
||||
|
||||
# Pattern for the final-answer marker "#### <number>"; group 1 captures the
# number text (an optional leading minus followed by digits, dots and commas).
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
|
||||
# Sentinel returned by the answer extractors when no number can be parsed.
INVALID_ANS = "[invalid]"
|
||||
|
||||
def doc_to_text(doc):
    """Build the generation prompt for a single dataset example.

    The prompt is the module-level ``fewshot_prompt`` followed by the
    example's ``"question"`` field and a step-by-step cue line.
    """
    pieces = (
        fewshot_prompt,
        "\nQuestion: ",
        doc["question"],
        "\nLet's think step by step\n",
    )
    return "".join(pieces)
|
||||
|
||||
def decode(tokens_list, tokenizer, raw_text_len):
    """Decode generated token sequences into trimmed completion strings.

    For each sequence in ``tokens_list`` the first ``raw_text_len`` tokens
    are dropped, the remainder is decoded with ``tokenizer.tokenizer.decode``
    and the text is truncated at the first occurrence of each stop marker,
    applied one after another.

    Returns:
        A list with one completion string per input sequence.
    """
    # Order matters: each cut keeps only the text before the marker and the
    # next marker is then searched in what is left.
    stop_markers = ('<|endoftext|>', '\n\n\n', "\n\n", "Question:")

    completions = []
    for sequence in tokens_list:
        token_ids = sequence.cpu().numpy().tolist()
        text = tokenizer.tokenizer.decode(token_ids[raw_text_len:])
        for marker in stop_markers:
            text = text.partition(marker)[0]
        completions.append(text)
    return completions
|
||||
|
||||
def generate_sample(model, tokenizer, input_txt):
    """Generate one completion for ``input_txt`` and return it as text.

    The prompt is encoded with ``tokenizer.tokenizer.encode``, moved to
    ``model.device`` as a single-row batch and passed to ``model.generate``.
    The result is post-processed by ``decode`` and the first completion is
    returned. Both the prompt and the completion are printed.
    """
    prompt_ids = tokenizer.tokenizer.encode(input_txt)
    # Prompt length in tokens, so decode() can strip the prompt from the output.
    prompt_len = len(prompt_ids)
    batch = torch.tensor([prompt_ids]).to(model.device)

    print(f"Input text: {input_txt}\n")
    generated = model.generate(batch)

    completions = decode(generated, tokenizer, prompt_len)
    output_text = completions[0]
    print(f"\nOutput text: {output_text}\n")
    return output_text
|
||||
|
||||
|
||||
def extract_answer_hf(completion):
    """Parse the number that follows the ``####`` marker in ``completion``.

    The text is searched with ``ANS_RE``; commas are removed from the
    captured number before it is converted.

    Returns:
        An ``int`` when the captured text is an integer literal, a ``float``
        when it has a fractional part, or ``INVALID_ANS`` when there is no
        marker or the captured text is not a valid number (e.g. ``"1.2.3"``).
    """
    match = ANS_RE.search(completion)
    if match is None:
        return INVALID_ANS

    number_text = match.group(1).strip().replace(",", "")
    # int()/float() instead of eval(): eval() raises SyntaxError on numerals
    # such as "007" or "1.2.3" and should not be run on dataset text at all.
    for parse in (int, float):
        try:
            return parse(number_text)
        except ValueError:
            continue
    return INVALID_ANS
|
||||
|
||||
# Number as it may appear in a model completion: an optional minus sign (only
# when it is not glued to a preceding word character or ")", so "10-3" is not
# read as -3), digits with optional thousands separators, and an optional
# fractional part. All groups are non-capturing so findall() yields full matches.
_PREDICTED_NUMBER_RE = re.compile(
    r"(?:(?<![\w)])-)?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?"
)


def extract_answer(completion):
    """Return the last number mentioned in ``completion``.

    Thousands separators, a fractional part and a leading minus sign are
    treated as part of the number, matching how the reference answer is
    parsed (so ``"1,000"`` is 1000 and ``"18.50"`` is 18.5 rather than the
    trailing digit run).

    Returns:
        An ``int``, or a ``float`` when the number has a fractional part, or
        ``INVALID_ANS`` when ``completion`` is not a string or contains no
        number.
    """
    # Keep the original best-effort contract for non-text input.
    if not isinstance(completion, str):
        return INVALID_ANS

    numbers = _PREDICTED_NUMBER_RE.findall(completion)
    if not numbers:
        return INVALID_ANS

    number_text = numbers[-1].replace(",", "")
    # Explicit conversion instead of eval(): also accepts leading zeros.
    if "." in number_text:
        return float(number_text)
    return int(number_text)
|
||||
|
||||
def is_correct(completion, answer):
    """Return whether the number in ``completion`` equals the reference answer.

    ``answer`` is parsed with ``extract_answer_hf`` and ``completion`` with
    ``extract_answer``; the two parsed values are compared with ``==``.

    Raises:
        AssertionError: if no reference answer can be parsed from ``answer``.
    """
    gold = extract_answer_hf(answer)
    # Raised explicitly rather than through an `assert` statement so the check
    # is not stripped when Python runs with -O; the exception type is unchanged.
    if gold == INVALID_ANS:
        raise AssertionError("No ground truth answer found in the document.")
    return extract_answer(completion) == gold
|
||||
|
||||
if __name__ == '__main__':
    # Script entry point: load the few-shot prompt, dataset, tokenizer and
    # model, generate a completion for every test example, write one JSON
    # line per example and print the mean accuracy.

    parser = argparse.ArgumentParser(description='Test HF checkpoint.')
    parser.add_argument("-c", "--checkpoint-path", type=str, help="Checkpoint path", default="Qwen/Qwen-7B")
    parser.add_argument("-f","--sample-input-file", type=str, default=None)
    parser.add_argument("-o","--sample-output-file", type=str, default="gsm8k_res.jsonl")

    args = parser.parse_args()

    # doc_to_text() reads `fewshot_prompt` as a module-level name, so it is
    # assigned here at module scope. The file is closed as soon as it is read.
    with open("gsm8k_prompt.txt", encoding="utf-8") as prompt_file:
        fewshot_prompt = prompt_file.read()

    if args.sample_input_file is not None:
        dataset = load_from_disk(args.sample_input_file)
    else:
        config = datasets.DownloadConfig(resume_download=True, max_retries=100)
        dataset = load_dataset("gsm8k", 'main', download_config=config)

    test = dataset["test"]

    print('Loading tokenizer ...')
    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)

    print('Loading model ...')
    model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True).eval()
    model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
    # Disable sampling during generation.
    model.generation_config.do_sample = False

    acc_res = []
    # The file object is managed here because closing a jsonlines.Writer does
    # not close a file object that was opened by the caller; the `with` blocks
    # also release both if generation fails part-way through.
    with open(args.sample_output_file, 'w', encoding='utf-8') as raw_output:
        with jsonlines.Writer(raw_output) as f_output:
            for doc in test:
                context = doc_to_text(doc)
                completion = generate_sample(model, tokenizer, context)
                answer = doc["answer"]
                acc = is_correct(completion, answer)
                # Store the completion and its correctness next to the example.
                doc["completion"] = completion
                doc["acc"] = acc
                f_output.write(doc)
                acc_res.append(acc)

    print("Acc: ",np.mean(acc_res))
|
||||
Reference in New Issue
Block a user