mirror of
https://github.com/QwenLM/Qwen.git
synced 2026-05-20 16:35:47 +08:00
fix format problems in evaluation code; update ceval extraction rules
This commit is contained in:
@@ -1,14 +1,10 @@
|
||||
import random
|
||||
import tqdm
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import jsonlines
|
||||
import argparse
|
||||
import jsonlines
|
||||
from pathlib import Path
|
||||
|
||||
import re
|
||||
import textwrap
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import tqdm
|
||||
import jsonlines
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.generation import GenerationConfig
|
||||
|
||||
@@ -24,25 +20,31 @@ evaluate_functional_correctness HumanEval_res.jsonl
|
||||
DEVICE = "cuda:0"
|
||||
|
||||
def extract_code(text, entry_point):
|
||||
|
||||
# 正则表达式匹配代码块
|
||||
code_block_pattern = re.compile(rf"```(?:[Pp]ython\n)?.*?def\s+{entry_point}.*?:\n(.*?)\n```", re.DOTALL)
|
||||
code_block_pattern = re.compile(
|
||||
rf"```(?:[Pp]ython\n)?.*?def\s+{entry_point}.*?:\n(.*?)\n```", re.DOTALL
|
||||
)
|
||||
code_block = code_block_pattern.search(text)
|
||||
if code_block is None:
|
||||
code_block_pattern = re.compile(rf"def\s+{entry_point}.*?:\n(.*?)(?:\n(?!\n*(?: |\t))|$)", re.DOTALL)
|
||||
code_block_pattern = re.compile(
|
||||
rf"def\s+{entry_point}.*?:\n(.*?)(?:\n(?!\n*(?: |\t))|$)", re.DOTALL
|
||||
)
|
||||
code_block = code_block_pattern.search(text)
|
||||
if code_block is None:
|
||||
code_block_pattern = re.compile(rf"def.*?:\n(.*?)(?:\n(?!\n*(?: |\t))|$)", re.DOTALL)
|
||||
code_block_pattern = re.compile(
|
||||
r"def.*?:\n(.*?)(?:\n(?!\n*(?: |\t))|$)", re.DOTALL
|
||||
)
|
||||
code_block = code_block_pattern.search(text)
|
||||
|
||||
if code_block is not None:
|
||||
return code_block.group(1)
|
||||
else:
|
||||
# if no code block is found, assume the LM is simply filling the code
|
||||
return textwrap.indent(text, ' ' * 4)
|
||||
|
||||
# if no code block is found, assume the LM is simply filling the code
|
||||
return textwrap.indent(text, " " * 4)
|
||||
|
||||
|
||||
def generate_sample(model, tokenizer, question, entry_point):
|
||||
response, history = model.chat(
|
||||
response, _ = model.chat(
|
||||
tokenizer,
|
||||
question,
|
||||
history=None,
|
||||
@@ -52,31 +54,56 @@ def generate_sample(model, tokenizer, question, entry_point):
|
||||
answer = extract_code(response, entry_point)
|
||||
return answer, response
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser(description='Test HF checkpoint.')
|
||||
parser.add_argument("-c", "--checkpoint-path", type=Path, help='Checkpoint path', default="Qwen/Qwen-7B-Chat")
|
||||
parser.add_argument("-f","--sample-input-file", type=str, default=None, help="data path to HumanEval.jsonl")
|
||||
parser.add_argument("-o","--sample-output-file", type=str, default="HumanEval_res.jsonl")
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Test HF checkpoint.")
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--checkpoint-path",
|
||||
type=Path,
|
||||
help="Checkpoint path",
|
||||
default="Qwen/Qwen-7B-Chat",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--sample-input-file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="data path to HumanEval.jsonl",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o", "--sample-output-file", type=str, default="HumanEval_res.jsonl"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
print('Loading tokenizer ...')
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)
|
||||
print("Loading tokenizer ...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.checkpoint_path, trust_remote_code=True
|
||||
)
|
||||
|
||||
print('Loading model ...')
|
||||
model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True, bf16=True, use_flash_attn=True).eval()
|
||||
model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
|
||||
model.generation_config.do_sample = False # use greedy decoding
|
||||
print("Loading model ...")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.checkpoint_path,
|
||||
device_map="auto",
|
||||
trust_remote_code=True,
|
||||
bf16=True,
|
||||
use_flash_attn=True,
|
||||
).eval()
|
||||
model.generation_config = GenerationConfig.from_pretrained(
|
||||
args.checkpoint_path, trust_remote_code=True
|
||||
)
|
||||
model.generation_config.do_sample = False # use greedy decoding
|
||||
|
||||
f_output = jsonlines.Writer(open(args.sample_output_file, 'w', encoding='utf-8'))
|
||||
f_output = jsonlines.Writer(open(args.sample_output_file, "w", encoding="utf-8"))
|
||||
|
||||
f = jsonlines.open(args.sample_input_file)
|
||||
with f_output as output:
|
||||
for jobj in tqdm.tqdm(f, desc='task_idx'):
|
||||
prompt = "Help me fill the following code.\n" + jobj['prompt']
|
||||
task_id = jobj['task_id']
|
||||
answer, response = generate_sample(model, tokenizer, prompt, jobj['entry_point'])
|
||||
gen_jobjs = {'task_id': task_id, "completion": answer, 'response': response}
|
||||
for jobj in tqdm.tqdm(f, desc="task_idx"):
|
||||
prompt = "Help me fill the following code.\n" + jobj["prompt"]
|
||||
task_id = jobj["task_id"]
|
||||
answer, response = generate_sample(
|
||||
model, tokenizer, prompt, jobj["entry_point"]
|
||||
)
|
||||
gen_jobjs = {"task_id": task_id, "completion": answer, "response": response}
|
||||
output.write(gen_jobjs)
|
||||
f_output.close()
|
||||
|
||||
Reference in New Issue
Block a user