add 72B and 1.8B Qwen models, add Ascend 910 and Hygon DCU support, add docker support

This commit is contained in:
yangapku
2023-11-30 15:29:13 +08:00
parent 981c89b2a9
commit e8e15962d8
52 changed files with 6139 additions and 1435 deletions

View File

@@ -2,15 +2,17 @@ import json
import re
from pathlib import Path
import argparse
import requests
import math
import numpy as np
import tqdm
from datasets import load_from_disk, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
'''
"""
python eval/evaluate_chat_gsm8k.py [--use-fewshot]
'''
"""
# Sentinel returned by the answer extractors in this file when no number
# can be recovered from a piece of text.
INVALID_ANS = "[invalid]"
# NOTE(review): presumably the device the model is placed on; its use is not
# visible in this excerpt -- confirm against the loading code.
DEVICE = "cuda:0"
@@ -32,20 +34,6 @@ def doc_to_text(doc, use_fewshot):
context = doc["question"]
return context
def decode(tokens_list, tokenizer, raw_text_len):
    """Decode generated token sequences and trim each one at the first stop marker.

    Args:
        tokens_list: iterable of token containers; each must support
            ``.cpu().numpy().tolist()``.
        tokenizer: wrapper object whose ``.tokenizer.decode`` turns a list of
            token ids into text.
        raw_text_len: number of leading tokens to drop from every sequence
            before decoding.

    Returns:
        A list with one trimmed string per input sequence.
    """
    # Applied in this exact order; each cut keeps only the text before the marker.
    stop_markers = ("<|endoftext|>", "\n\n\n", "\n\n", "Question:")

    decoded = []
    for token_tensor in tokens_list:
        token_ids = token_tensor.cpu().numpy().tolist()
        text = tokenizer.tokenizer.decode(token_ids[raw_text_len:])
        for marker in stop_markers:
            text = text.split(marker)[0]
        decoded.append(text)
    return decoded
def generate_sample(model, tokenizer, question):
response, _ = model.chat(
tokenizer,
@@ -58,40 +46,35 @@ def generate_sample(model, tokenizer, question):
print("=============")
return response
def extract_answer_hf(completion):
    """Return the last number written in ``completion`` as an ``int`` or ``float``.

    The completion is stripped of leading/trailing ``.`` characters and its
    newlines are replaced by a literal backslash-n before matching.  A
    candidate must be preceded by whitespace or one of ``$ % # {`` and be
    followed by whitespace, one of ``. , }``, or the end of the text.
    Thousands separators and a leading ``+`` are removed before conversion.

    Every part of the pattern is optional, so it can also yield empty or
    sign-only matches (for example at a ``,`` that follows a space).  Those
    are skipped and the last candidate that really is a number is used;
    previously such a match was handed to ``eval`` and raised ``SyntaxError``.

    Args:
        completion: text of the model answer.

    Returns:
        The parsed number, or ``INVALID_ANS`` when no number is found.
    """
    _PAT_LAST_DIGIT = re.compile(
        r"(?<=(\s|[\$%#{]))([+-])?(?=(\S))(0|([1-9](\d*|\d{0,2}(,\d{3})*)))?(\.\d*[1-9])?(?=(\s|[.,}]|$))"
    )

    def _parse_number(text):
        """Convert a cleaned candidate to int/float; None if it is not numeric."""
        try:
            # int/float instead of eval: the text comes from model output.
            return float(text) if "." in text else int(text)
        except ValueError:
            return None

    job_gen = completion.strip(".").replace("\n", "\\n")
    candidates = list(_PAT_LAST_DIGIT.finditer(job_gen))
    # Walk backwards so the last well-formed number wins.
    for match in reversed(candidates):
        value = _parse_number(match.group().replace(",", "").replace("+", ""))
        if value is not None:
            return value

    print(f"No digits found in {job_gen!r}")
    return INVALID_ANS
def extract_answer(completion):
    """Return the last run of decimal digits in ``completion`` as an ``int``.

    Only unsigned digit runs are considered (pattern ``\\d+``); signs,
    decimal points and separators split a number into separate runs.
    Zero-padded runs such as ``"05"`` are parsed as 5 (with ``eval`` they
    raised ``SyntaxError`` and were reported as invalid).

    NOTE(review): a later definition in this file reuses the name
    ``extract_answer``; confirm which one callers are meant to get.

    Args:
        completion: text to scan.

    Returns:
        The integer value of the last digit run, or ``INVALID_ANS`` when the
        text contains no digits or is not a string.
    """
    try:
        digit_runs = re.findall(r"\d+", completion)
    except TypeError:
        # Non-string input: keep the original best-effort behaviour.
        return INVALID_ANS
    if not digit_runs:
        return INVALID_ANS
    try:
        return int(digit_runs[-1])
    except ValueError:
        # e.g. a digit run longer than the interpreter's int-conversion limit.
        return INVALID_ANS
def extract_answer(s):
    """Return the text of the last number found in ``s``, or ``None``.

    A number may carry a sign, thousands separators and a fractional part.
    Commas and ``+`` characters are removed from the returned text and
    surrounding whitespace is stripped; the value is left as a string.

    Args:
        s: text to scan.

    Returns:
        The cleaned text of the last match, or ``None`` (after printing a
        notice) when nothing matches.
    """
    number_pattern = re.compile(
        r"([+-])?(?=([0-9]|\.[0-9]))(0|([1-9](\d{0,2}(,\d{3})*)|\d*))?(\.\d*)?(?=\D|$)"
    )

    # Exhaust the iterator, keeping only the final match.
    final_match = None
    for final_match in number_pattern.finditer(s):
        pass

    if final_match is None:
        print(f"No digits found in {s!r}", flush=True)
        return None

    raw_text = final_match.group()
    return raw_text.replace(",", "").replace("+", "").strip()
def is_correct(completion, answer):
    """Report whether the number in ``completion`` matches the one in ``answer``.

    Both texts are passed through ``extract_answer``; the two results are
    then compared numerically with an absolute tolerance of 1e-4, so
    ``"18"`` and ``"18.0"`` count as equal.

    Args:
        completion: text produced by the model.
        answer: reference answer text from the dataset.

    Returns:
        ``True`` when both numbers are within the tolerance, ``False`` when
        they differ, when no number is found in ``completion``, or when the
        extracted values cannot be compared as numbers.

    Raises:
        AssertionError: if no number can be extracted from ``answer``.
    """
    gold = extract_answer(answer)
    # Accept either "nothing found" marker (None or INVALID_ANS).  An explicit
    # raise is used so the check is not stripped under ``python -O``.
    if gold is None or gold == INVALID_ANS:
        raise AssertionError("No ground truth answer found in the document.")

    def number_equal(answer, pred):
        """Compare two extracted values as floats with a small tolerance."""
        if pred is None:
            return False
        try:
            # float() instead of eval(): pred originates from model output.
            return math.isclose(float(answer), float(pred), rel_tol=0, abs_tol=1e-4)
        except (TypeError, ValueError, OverflowError):
            print(
                f"cannot compare two numbers: answer={answer}, pred={pred}", flush=True
            )
            return False

    return number_equal(gold, extract_answer(completion))
if __name__ == "__main__":
@@ -138,7 +121,6 @@ if __name__ == "__main__":
acc_res = []
for doc in tqdm.tqdm(test):
context = doc_to_text(doc, args.use_fewshot)
print(context)
completion = generate_sample(model, tokenizer, context)
answer = doc["answer"]
acc = is_correct(completion, answer)