mirror of
https://github.com/QwenLM/Qwen.git
synced 2026-05-20 16:35:47 +08:00
add 72B and 1.8B Qwen models, add Ascend 910 and Hygon DCU support, add docker support
This commit is contained in:
@@ -2,15 +2,17 @@ import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
import requests
|
||||
import math
|
||||
import numpy as np
|
||||
import tqdm
|
||||
from datasets import load_from_disk, load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.generation import GenerationConfig
|
||||
|
||||
'''
|
||||
"""
|
||||
python eval/evaluate_chat_gsm8k.py [--use-fewshot]
|
||||
'''
|
||||
"""
|
||||
|
||||
INVALID_ANS = "[invalid]"
|
||||
DEVICE = "cuda:0"
|
||||
@@ -32,20 +34,6 @@ def doc_to_text(doc, use_fewshot):
|
||||
context = doc["question"]
|
||||
return context
|
||||
|
||||
|
||||
def decode(tokens_list, tokenizer, raw_text_len):
|
||||
sents = []
|
||||
for tokens in tokens_list:
|
||||
tokens = tokens.cpu().numpy().tolist()
|
||||
sent = tokenizer.tokenizer.decode(tokens[raw_text_len:])
|
||||
sent = sent.split("<|endoftext|>")[0]
|
||||
sent = sent.split("\n\n\n")[0]
|
||||
sent = sent.split("\n\n")[0]
|
||||
sent = sent.split("Question:")[0]
|
||||
sents.append(sent)
|
||||
return sents
|
||||
|
||||
|
||||
def generate_sample(model, tokenizer, question):
|
||||
response, _ = model.chat(
|
||||
tokenizer,
|
||||
@@ -58,40 +46,35 @@ def generate_sample(model, tokenizer, question):
|
||||
print("=============")
|
||||
return response
|
||||
|
||||
|
||||
def extract_answer_hf(completion):
|
||||
def _get_last_digit(s):
|
||||
_PAT_LAST_DIGIT = re.compile(
|
||||
r"(?<=(\s|[\$%#{]))([+-])?(?=(\S))(0|([1-9](\d*|\d{0,2}(,\d{3})*)))?(\.\d*[1-9])?(?=(\s|[.,}]|$))"
|
||||
)
|
||||
match = list(_PAT_LAST_DIGIT.finditer(s))
|
||||
if match:
|
||||
last_digit = match[-1].group().replace(",", "").replace("+", "")
|
||||
# print(f"The last digit in {s} is {last_digit}")
|
||||
else:
|
||||
last_digit = None
|
||||
print(f"No digits found in {s!r}")
|
||||
return last_digit
|
||||
|
||||
job_gen = completion.strip(".").replace("\n", "\\n")
|
||||
last_digit = _get_last_digit(job_gen)
|
||||
if last_digit is not None:
|
||||
return eval(last_digit)
|
||||
return INVALID_ANS
|
||||
|
||||
|
||||
def extract_answer(completion):
|
||||
try:
|
||||
last_number = re.findall(r"\d+", completion)[-1]
|
||||
return eval(last_number)
|
||||
except:
|
||||
return INVALID_ANS
|
||||
|
||||
def extract_answer(s):
|
||||
_PAT_LAST_DIGIT = re.compile(
|
||||
r"([+-])?(?=([0-9]|\.[0-9]))(0|([1-9](\d{0,2}(,\d{3})*)|\d*))?(\.\d*)?(?=\D|$)"
|
||||
)
|
||||
match = list(_PAT_LAST_DIGIT.finditer(s))
|
||||
if match:
|
||||
last_digit = match[-1].group().replace(",", "").replace("+", "").strip()
|
||||
# print(f"The last digit in {s} is {last_digit}")
|
||||
else:
|
||||
last_digit = None
|
||||
print(f"No digits found in {s!r}", flush=True)
|
||||
return last_digit
|
||||
|
||||
def is_correct(completion, answer):
|
||||
gold = extract_answer(answer)
|
||||
assert gold != INVALID_ANS, "No ground truth answer found in the document."
|
||||
return extract_answer(completion) == gold
|
||||
assert gold is not None, "No ground truth answer found in the document."
|
||||
|
||||
def number_equal(answer, pred):
|
||||
if pred is None:
|
||||
return False
|
||||
try:
|
||||
return math.isclose(eval(answer), eval(pred), rel_tol=0, abs_tol=1e-4)
|
||||
except:
|
||||
print(
|
||||
f"cannot compare two numbers: answer={answer}, pred={pred}", flush=True
|
||||
)
|
||||
return False
|
||||
|
||||
return number_equal(gold, extract_answer(completion))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -138,7 +121,6 @@ if __name__ == "__main__":
|
||||
acc_res = []
|
||||
for doc in tqdm.tqdm(test):
|
||||
context = doc_to_text(doc, args.use_fewshot)
|
||||
print(context)
|
||||
completion = generate_sample(model, tokenizer, context)
|
||||
answer = doc["answer"]
|
||||
acc = is_correct(completion, answer)
|
||||
|
||||
Reference in New Issue
Block a user