mirror of
https://github.com/QwenLM/Qwen.git
synced 2026-05-20 16:35:47 +08:00
add example: auto_comments
This commit is contained in:
189
examples/auto_comments.py
Normal file
189
examples/auto_comments.py
Normal file
@@ -0,0 +1,189 @@
|
||||
# 运行方式:python auto_comments.py --path 'path of file or folder'
|
||||
# 脚本功能:使用QWen-7B-Chat为提供的代码文件自动生成注释。(详见auto_comments.md)
|
||||
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.generation import GenerationConfig
|
||||
|
||||
MaxLine = 50 # 限制单次处理最大代码行数
|
||||
SplitKey = ["\ndef "] # 自定义的切分代码标识
|
||||
CodeFileType = ["py"] # 目前仅测试过对python文件生成注释
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--path', type=str, default='Qwen-7B/eval/evaluate_ceval.py')
|
||||
parser.add_argument('--regenerate', action='store_true', default=False) #如果已经生成过注释,默认不会重新生成
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
class QWenChat():
|
||||
def __init__(self):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
|
||||
|
||||
# use bf16
|
||||
# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, bf16=True).eval()
|
||||
# use fp16
|
||||
# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, fp16=True).eval()
|
||||
# use cpu only
|
||||
# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="cpu", trust_remote_code=True).eval()
|
||||
# use auto mode, automatically select precision based on the device.
|
||||
self.model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True).eval()
|
||||
|
||||
# Specify hyperparameters for generation
|
||||
self.model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
|
||||
self.history = None
|
||||
|
||||
def chat(self, query, system = ""):
|
||||
|
||||
# use history
|
||||
# response, history = self.model.chat(self.tokenizer, query, history=self.history)
|
||||
|
||||
# 默认不使用history
|
||||
response, history = self.model.chat(self.tokenizer, query, history=None)
|
||||
self.history = history
|
||||
|
||||
return response
|
||||
# 生成注释
|
||||
def gen_code_comments(context, model = None, **kwargs):
|
||||
prompt = "\n为以上代码生成细致的中文注释,注意使用合适的语法。要求必须在每个函数开头生成一段统一的函数功能注释。\n除了注释,请保证原始代码内容不变。不要返回除了注释和代码以外的其余信息,不要生成额外代码。\n"
|
||||
return model.chat(context + prompt)
|
||||
|
||||
def read_file(path):
|
||||
f = open(path, "r",encoding='utf-8')
|
||||
lines = f.readlines()
|
||||
return "".join(lines)
|
||||
|
||||
def write_file(path, context):
|
||||
with open(path,'w') as f:
|
||||
f.write(context)
|
||||
|
||||
# 如果代码文件过长,可以简单按照最大行数切分代码
|
||||
def split_context_by_maxline(text):
|
||||
lines = text.split("\n")
|
||||
lines_len = len(lines)
|
||||
res = []
|
||||
for i in range(MaxLine, lines_len, MaxLine):
|
||||
res.append("\n".join(lines[i-MaxLine:i]))
|
||||
|
||||
if i < lines_len:
|
||||
res.append("\n".join(lines[i:]))
|
||||
return res
|
||||
|
||||
# 如果代码文件过长,可以简单按照函数切分代码
|
||||
def split_context_by_splitkey(text):
|
||||
blocks = text.split(SplitKey[0])
|
||||
return [blocks[0]] + [SplitKey[0]+x for x in blocks[1:]]
|
||||
|
||||
# merge原始代码和生成的注释,目的是保证原始代码不被更改。这部分可以使用各种不同的策略处理。
|
||||
def merge_code_and_comments(original_file, comments_path):
|
||||
res = []
|
||||
ori_f = open(original_file, "r",encoding='utf-8')
|
||||
ori_lines = ori_f.readlines()
|
||||
|
||||
com_f = open(comments_path, "r",encoding='utf-8')
|
||||
com_lines = com_f.readlines()
|
||||
len_com_lines = len(com_lines)
|
||||
p = 0
|
||||
j = 0
|
||||
for i, line in enumerate(ori_lines):
|
||||
if line.isspace():
|
||||
continue
|
||||
if line.strip()[0] == '#':
|
||||
res.append(line)
|
||||
continue
|
||||
while j < len_com_lines and line[:-1] not in com_lines[j]:
|
||||
j += 1
|
||||
if j < len_com_lines:
|
||||
p = j - 1
|
||||
up_comments = []
|
||||
triple_dot_flag = 0
|
||||
while p < j:
|
||||
if p < 0 or (res and res[-1] and com_lines[p] == res[-1]):
|
||||
break
|
||||
if com_lines[p].strip() and (len(com_lines[p].strip())>3 and com_lines[p].strip()[-3:] == '"""' and com_lines[p].strip()[:3] == '"""') or (len(com_lines[p].strip())>3 and com_lines[p].strip()[-3:] == "'''" and com_lines[p].strip()[:3] == "'''"):
|
||||
up_comments.append(com_lines[p])
|
||||
p -= 1
|
||||
continue
|
||||
if com_lines[p].strip() and (com_lines[p].strip()[-3:] == '"""' or com_lines[p].strip()[:3] == '"""' or com_lines[p].strip()[-3:] == "'''" or com_lines[p].strip()[:3] == "'''"):
|
||||
triple_dot_flag = (triple_dot_flag + 1)%2
|
||||
up_comments.append(com_lines[p])
|
||||
p -= 1
|
||||
continue
|
||||
if triple_dot_flag:
|
||||
up_comments.append(com_lines[p])
|
||||
p -= 1
|
||||
continue
|
||||
if (com_lines[p].strip()=="") or (com_lines[p].strip() and com_lines[p].strip()[0] == '#' and "省略部分内容" not in com_lines[p]):
|
||||
up_comments.append(com_lines[p])
|
||||
else:
|
||||
break
|
||||
p -= 1
|
||||
if up_comments:
|
||||
res.extend(reversed(up_comments))
|
||||
if "#" in com_lines[j] and "#" not in line:
|
||||
in_line_comments = " #" + com_lines[j].split("#")[-1]
|
||||
res.append(line[:-1]+in_line_comments)
|
||||
else:
|
||||
res.append(line)
|
||||
p = j+1
|
||||
else:
|
||||
res.append(line)
|
||||
j = p
|
||||
|
||||
write_file(comments_path, "".join(res))
|
||||
|
||||
# 处理单个文件
|
||||
def deal_one_file(model, path, args):
|
||||
context = read_file(path)
|
||||
|
||||
fname = path.split("/")[-1]
|
||||
fpath = "/".join(path.split("/")[:-1])
|
||||
outfname = fname.split(".")[0]+"_comments."+fname.split(".")[-1]
|
||||
|
||||
comments_path = os.path.join(fpath, outfname)
|
||||
if (not args.regenerate) and os.path.exists(comments_path):
|
||||
print("use cache: ", comments_path)
|
||||
return
|
||||
|
||||
context_line = len(context.split("\n"))
|
||||
if context_line < MaxLine:
|
||||
res = gen_code_comments(context, model = model)
|
||||
elif SplitKey[0] not in context:
|
||||
context_list = split_context_by_maxline(context)
|
||||
res = "\n".join([gen_code_comments(context_block, model = model) for context_block in context_list])
|
||||
else:
|
||||
context_list = split_context_by_splitkey(context)
|
||||
res = "\n".join([gen_code_comments(context_block, model = model) for context_block in context_list])
|
||||
|
||||
write_file(comments_path, res)
|
||||
merge_code_and_comments(path, comments_path)
|
||||
|
||||
# 处理文件夹
|
||||
def deal_folder(model, path, args):
|
||||
for fl in os.listdir(path):
|
||||
now_path = os.path.join(path, fl)
|
||||
if os.path.isfile(now_path):
|
||||
if (now_path.split(".")[-1] in CodeFileType) and ("_comments" not in now_path):
|
||||
deal_one_file(model, now_path, args)
|
||||
elif os.path.isdir(now_path):
|
||||
deal_folder(model, now_path, args)
|
||||
else:
|
||||
print("Please specify a correct path!")
|
||||
|
||||
def transfer(args):
|
||||
model = QWenChat()
|
||||
|
||||
if os.path.isfile(args.path):
|
||||
if (args.path.split(".")[-1] in CodeFileType) and ("_comments" not in args.path):
|
||||
deal_one_file(model, args.path, args)
|
||||
elif os.path.isdir(args.path):
|
||||
deal_folder(model, args.path, args)
|
||||
else:
|
||||
print("Please specify a correct path!")
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
print(args)
|
||||
transfer(args)
|
||||
Reference in New Issue
Block a user