mirror of
https://github.com/QwenLM/Qwen.git
synced 2026-05-20 16:35:47 +08:00
update evaluate scripts
This commit is contained in:
@@ -20,13 +20,22 @@ python evaluate_ceval.py -d data/ceval/
|
|||||||
|
|
||||||
def load_models_tokenizer(args):
|
def load_models_tokenizer(args):
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
args.checkpoint_path, trust_remote_code=True
|
args.checkpoint_path,
|
||||||
|
pad_token='<|extra_0|>',
|
||||||
|
eos_token='<|endoftext|>',
|
||||||
|
padding_side='left',
|
||||||
|
trust_remote_code=True
|
||||||
)
|
)
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
args.checkpoint_path, device_map="auto", trust_remote_code=True
|
args.checkpoint_path,
|
||||||
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
|
device_map="auto",
|
||||||
|
trust_remote_code=True
|
||||||
).eval()
|
).eval()
|
||||||
model.generation_config = GenerationConfig.from_pretrained(
|
model.generation_config = GenerationConfig.from_pretrained(
|
||||||
args.checkpoint_path, trust_remote_code=True
|
args.checkpoint_path,
|
||||||
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
|
trust_remote_code=True
|
||||||
)
|
)
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
@@ -56,11 +65,12 @@ def generate_few_shot_prompt(k, subject, dev_df):
|
|||||||
|
|
||||||
|
|
||||||
def get_logits(tokenizer, model, inputs: List[str]):
|
def get_logits(tokenizer, model, inputs: List[str]):
|
||||||
input_ids = tokenizer(inputs, padding=False)["input_ids"]
|
input_ids = tokenizer(inputs, padding='longest')["input_ids"]
|
||||||
input_ids = torch.tensor(input_ids, device=model.device)
|
input_ids = torch.tensor(input_ids, device=model.device)
|
||||||
tokens = {"input_ids": input_ids}
|
tokens = {"input_ids": input_ids}
|
||||||
|
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
||||||
|
|
||||||
outputs = model(input_ids)["logits"]
|
outputs = model(input_ids, attention_mask=attention_mask)["logits"]
|
||||||
logits = outputs[:, -1, :]
|
logits = outputs[:, -1, :]
|
||||||
log_probs = torch.nn.functional.softmax(logits, dim=-1)
|
log_probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||||
return log_probs, {"tokens": tokens}
|
return log_probs, {"tokens": tokens}
|
||||||
@@ -76,6 +86,7 @@ def eval_subject(
|
|||||||
dev_df=None,
|
dev_df=None,
|
||||||
few_shot=False,
|
few_shot=False,
|
||||||
save_result_dir=None,
|
save_result_dir=None,
|
||||||
|
batch_size=1,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
result = []
|
result = []
|
||||||
@@ -88,38 +99,38 @@ def eval_subject(
|
|||||||
if args.debug:
|
if args.debug:
|
||||||
print(f"few_shot_prompt: {few_shot_prompt}")
|
print(f"few_shot_prompt: {few_shot_prompt}")
|
||||||
|
|
||||||
for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
|
choices_ids = torch.tensor(
|
||||||
|
tokenizer("A")["input_ids"] + tokenizer("B")["input_ids"] +
|
||||||
|
tokenizer("C")["input_ids"] + tokenizer("D")["input_ids"]
|
||||||
|
).unsqueeze(0).to(model.device)
|
||||||
|
|
||||||
|
idx_list = list(range(0, len(test_df), batch_size))
|
||||||
|
for i in tqdm(idx_list):
|
||||||
|
full_prompt_list = []
|
||||||
|
answer_list = []
|
||||||
|
for row in test_df.iloc[i:i+batch_size].to_dict(orient='records'):
|
||||||
question = format_example(row, include_answer=False)
|
question = format_example(row, include_answer=False)
|
||||||
full_prompt = few_shot_prompt + question
|
full_prompt = few_shot_prompt + question
|
||||||
|
full_prompt_list.append(full_prompt)
|
||||||
|
if 'answer' in row:
|
||||||
|
answer_list.append(row['answer'])
|
||||||
|
|
||||||
output, input_info = get_logits(tokenizer, model, [full_prompt])
|
logits, input_info = get_logits(tokenizer, model, full_prompt_list)
|
||||||
assert output.shape[0] == 1
|
softval = logits.gather(1, choices_ids.expand(logits.size(0), -1)).softmax(1)
|
||||||
logits = output.flatten()
|
|
||||||
|
|
||||||
softval = torch.nn.functional.softmax(
|
|
||||||
torch.tensor(
|
|
||||||
[
|
|
||||||
logits[tokenizer("A")["input_ids"]],
|
|
||||||
logits[tokenizer("B")["input_ids"]],
|
|
||||||
logits[tokenizer("C")["input_ids"]],
|
|
||||||
logits[tokenizer("D")["input_ids"]],
|
|
||||||
]
|
|
||||||
),
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
if softval.dtype in {torch.bfloat16, torch.float16}:
|
if softval.dtype in {torch.bfloat16, torch.float16}:
|
||||||
softval = softval.to(dtype=torch.float32)
|
softval = softval.to(dtype=torch.float32)
|
||||||
probs = softval.detach().cpu().numpy()
|
probs = softval.detach().cpu().numpy()
|
||||||
|
|
||||||
for i, choice in enumerate(choices):
|
for i in range(len(probs)):
|
||||||
all_probs[f"prob_{choice}"].append(probs[i])
|
for j, choice in enumerate(choices):
|
||||||
pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)]
|
all_probs[f"prob_{choice}"].append(probs[i][j])
|
||||||
|
pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs[i])]
|
||||||
|
|
||||||
if "answer" in row:
|
if answer_list != []:
|
||||||
correct = 1 if pred == row["answer"] else 0
|
correct = 1 if pred == answer_list[i] else 0
|
||||||
score.append(correct)
|
score.append(correct)
|
||||||
if args.debug:
|
if args.debug:
|
||||||
print(f'{question} pred: {pred} ref: {row["answer"]}')
|
print(f'{question} pred: {pred} ref: {answer_list[i]}')
|
||||||
result.append(pred)
|
result.append(pred)
|
||||||
|
|
||||||
if score:
|
if score:
|
||||||
@@ -395,6 +406,7 @@ def main(args):
|
|||||||
k=5,
|
k=5,
|
||||||
few_shot=True,
|
few_shot=True,
|
||||||
save_result_dir=f"outs/ceval_eval_result",
|
save_result_dir=f"outs/ceval_eval_result",
|
||||||
|
batch_size=args.batch_size
|
||||||
)
|
)
|
||||||
dev_result[subject_name] = score
|
dev_result[subject_name] = score
|
||||||
cal_ceval(dev_result)
|
cal_ceval(dev_result)
|
||||||
@@ -425,6 +437,12 @@ if __name__ == "__main__":
|
|||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--debug", action="store_true", default=False, help="Print infos."
|
"--debug", action="store_true", default=False, help="Print infos."
|
||||||
)
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="batch size",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
set_seed(args.seed)
|
set_seed(args.seed)
|
||||||
|
|||||||
@@ -26,13 +26,22 @@ def load_models_tokenizer(args):
|
|||||||
from transformers.generation import GenerationConfig
|
from transformers.generation import GenerationConfig
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
args.checkpoint_path, trust_remote_code=True
|
args.checkpoint_path,
|
||||||
|
pad_token='<|extra_0|>',
|
||||||
|
eos_token='<|endoftext|>',
|
||||||
|
padding_side='left',
|
||||||
|
trust_remote_code=True
|
||||||
)
|
)
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
args.checkpoint_path, device_map="auto", trust_remote_code=True
|
args.checkpoint_path,
|
||||||
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
|
device_map="auto",
|
||||||
|
trust_remote_code=True
|
||||||
).eval()
|
).eval()
|
||||||
model.generation_config = GenerationConfig.from_pretrained(
|
model.generation_config = GenerationConfig.from_pretrained(
|
||||||
args.checkpoint_path, trust_remote_code=True
|
args.checkpoint_path,
|
||||||
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
|
trust_remote_code=True
|
||||||
)
|
)
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
@@ -62,11 +71,12 @@ def generate_few_shot_prompt(k, subject, dev_df):
|
|||||||
|
|
||||||
|
|
||||||
def get_logits(tokenizer, model, inputs: List[str]):
|
def get_logits(tokenizer, model, inputs: List[str]):
|
||||||
input_ids = tokenizer(inputs, padding=False)["input_ids"]
|
input_ids = tokenizer(inputs, padding='longest')["input_ids"]
|
||||||
input_ids = torch.tensor(input_ids, device=model.device)
|
input_ids = torch.tensor(input_ids, device=model.device)
|
||||||
tokens = {"input_ids": input_ids}
|
tokens = {"input_ids": input_ids}
|
||||||
|
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
||||||
|
|
||||||
outputs = model(input_ids)["logits"]
|
outputs = model(input_ids, attention_mask=attention_mask)["logits"]
|
||||||
logits = outputs[:, -1, :]
|
logits = outputs[:, -1, :]
|
||||||
log_probs = torch.nn.functional.softmax(logits, dim=-1)
|
log_probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||||
return log_probs, {"tokens": tokens}
|
return log_probs, {"tokens": tokens}
|
||||||
@@ -82,6 +92,7 @@ def eval_subject(
|
|||||||
dev_df=None,
|
dev_df=None,
|
||||||
few_shot=False,
|
few_shot=False,
|
||||||
save_result_dir=None,
|
save_result_dir=None,
|
||||||
|
batch_size=1,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
result = []
|
result = []
|
||||||
@@ -94,38 +105,38 @@ def eval_subject(
|
|||||||
if args.debug:
|
if args.debug:
|
||||||
print(f"few_shot_prompt: {few_shot_prompt}")
|
print(f"few_shot_prompt: {few_shot_prompt}")
|
||||||
|
|
||||||
for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
|
choices_ids = torch.tensor(
|
||||||
|
tokenizer("A")["input_ids"] + tokenizer("B")["input_ids"] +
|
||||||
|
tokenizer("C")["input_ids"] + tokenizer("D")["input_ids"]
|
||||||
|
).unsqueeze(0).to(model.device)
|
||||||
|
|
||||||
|
idx_list = list(range(0, len(test_df), batch_size))
|
||||||
|
for i in tqdm(idx_list):
|
||||||
|
full_prompt_list = []
|
||||||
|
answer_list = []
|
||||||
|
for row in test_df.iloc[i:i+batch_size].to_dict(orient='records'):
|
||||||
question = format_example(row, include_answer=False)
|
question = format_example(row, include_answer=False)
|
||||||
full_prompt = few_shot_prompt + question
|
full_prompt = few_shot_prompt + question
|
||||||
|
full_prompt_list.append(full_prompt)
|
||||||
|
if 'Answer' in row:
|
||||||
|
answer_list.append(row['Answer'])
|
||||||
|
|
||||||
output, input_info = get_logits(tokenizer, model, [full_prompt])
|
logits, input_info = get_logits(tokenizer, model, full_prompt_list)
|
||||||
assert output.shape[0] == 1
|
softval = logits.gather(1, choices_ids.expand(logits.size(0), -1)).softmax(1)
|
||||||
logits = output.flatten()
|
|
||||||
|
|
||||||
softval = torch.nn.functional.softmax(
|
|
||||||
torch.tensor(
|
|
||||||
[
|
|
||||||
logits[tokenizer("A")["input_ids"]],
|
|
||||||
logits[tokenizer("B")["input_ids"]],
|
|
||||||
logits[tokenizer("C")["input_ids"]],
|
|
||||||
logits[tokenizer("D")["input_ids"]],
|
|
||||||
]
|
|
||||||
),
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
if softval.dtype in {torch.bfloat16, torch.float16}:
|
if softval.dtype in {torch.bfloat16, torch.float16}:
|
||||||
softval = softval.to(dtype=torch.float32)
|
softval = softval.to(dtype=torch.float32)
|
||||||
probs = softval.detach().cpu().numpy()
|
probs = softval.detach().cpu().numpy()
|
||||||
|
|
||||||
for i, choice in enumerate(choices):
|
for i in range(len(probs)):
|
||||||
all_probs[f"prob_{choice}"].append(probs[i])
|
for j, choice in enumerate(choices):
|
||||||
pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)]
|
all_probs[f"prob_{choice}"].append(probs[i][j])
|
||||||
|
pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs[i])]
|
||||||
|
|
||||||
if "Answer" in row:
|
if answer_list != []:
|
||||||
correct = 1 if pred == row["Answer"] else 0
|
correct = 1 if pred == answer_list[i] else 0
|
||||||
score.append(correct)
|
score.append(correct)
|
||||||
if args.debug:
|
if args.debug:
|
||||||
print(f'{question} pred: {pred} ref: {row["Answer"]}')
|
print(f'{question} pred: {pred} ref: {answer_list[i]}')
|
||||||
result.append(pred)
|
result.append(pred)
|
||||||
|
|
||||||
if score:
|
if score:
|
||||||
@@ -288,6 +299,7 @@ def main(args):
|
|||||||
k=5,
|
k=5,
|
||||||
few_shot=True,
|
few_shot=True,
|
||||||
save_result_dir=f"outs/cmmlu_eval_result",
|
save_result_dir=f"outs/cmmlu_eval_result",
|
||||||
|
batch_size=args.batch_size
|
||||||
)
|
)
|
||||||
test_result[subject_name] = score
|
test_result[subject_name] = score
|
||||||
cal_cmmlu(test_result)
|
cal_cmmlu(test_result)
|
||||||
@@ -318,6 +330,12 @@ if __name__ == "__main__":
|
|||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--debug", action="store_true", default=False, help="Print infos."
|
"--debug", action="store_true", default=False, help="Print infos."
|
||||||
)
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="batch size",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
set_seed(args.seed)
|
set_seed(args.seed)
|
||||||
|
|||||||
@@ -21,13 +21,22 @@ python eval/evaluate_mmlu.py -d data/mmlu/data/
|
|||||||
|
|
||||||
def load_models_tokenizer(args):
|
def load_models_tokenizer(args):
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
args.checkpoint_path, trust_remote_code=True
|
args.checkpoint_path,
|
||||||
|
pad_token='<|extra_0|>',
|
||||||
|
eos_token='<|endoftext|>',
|
||||||
|
padding_side='left',
|
||||||
|
trust_remote_code=True
|
||||||
)
|
)
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
args.checkpoint_path, device_map="auto", trust_remote_code=True
|
args.checkpoint_path,
|
||||||
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
|
device_map="auto",
|
||||||
|
trust_remote_code=True
|
||||||
).eval()
|
).eval()
|
||||||
model.generation_config = GenerationConfig.from_pretrained(
|
model.generation_config = GenerationConfig.from_pretrained(
|
||||||
args.checkpoint_path, trust_remote_code=True
|
args.checkpoint_path,
|
||||||
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
|
trust_remote_code=True
|
||||||
)
|
)
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
@@ -67,14 +76,15 @@ def generate_few_shot_prompt(k, subject, dev_df):
|
|||||||
|
|
||||||
|
|
||||||
def get_logits(tokenizer, model, inputs: List[str]):
|
def get_logits(tokenizer, model, inputs: List[str]):
|
||||||
input_ids = tokenizer(inputs, padding=False)["input_ids"]
|
input_ids = tokenizer(inputs, padding='longest')["input_ids"]
|
||||||
input_ids = torch.tensor(input_ids, device=model.device)
|
input_ids = torch.tensor(input_ids, device=model.device)
|
||||||
|
|
||||||
if input_ids.shape[1] > args.max_seq_len:
|
if input_ids.shape[1] > args.max_seq_len:
|
||||||
input_ids = input_ids[:, input_ids.shape[1] - args.max_seq_len + 1 :]
|
input_ids = input_ids[:, input_ids.shape[1] - args.max_seq_len + 1 :]
|
||||||
tokens = {"input_ids": input_ids}
|
tokens = {"input_ids": input_ids}
|
||||||
|
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
||||||
|
|
||||||
outputs = model(input_ids)["logits"]
|
outputs = model(input_ids, attention_mask=attention_mask)["logits"]
|
||||||
logits = outputs[:, -1, :]
|
logits = outputs[:, -1, :]
|
||||||
log_probs = torch.nn.functional.softmax(logits, dim=-1)
|
log_probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||||
return log_probs, {"tokens": tokens}
|
return log_probs, {"tokens": tokens}
|
||||||
@@ -90,6 +100,7 @@ def eval_subject(
|
|||||||
dev_df=None,
|
dev_df=None,
|
||||||
few_shot=False,
|
few_shot=False,
|
||||||
save_result_dir=None,
|
save_result_dir=None,
|
||||||
|
batch_size=1,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
result = []
|
result = []
|
||||||
@@ -102,38 +113,38 @@ def eval_subject(
|
|||||||
if args.debug:
|
if args.debug:
|
||||||
print(f"few_shot_prompt: {few_shot_prompt}")
|
print(f"few_shot_prompt: {few_shot_prompt}")
|
||||||
|
|
||||||
for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
|
choices_ids = torch.tensor(
|
||||||
|
tokenizer(" A")["input_ids"] + tokenizer(" B")["input_ids"] +
|
||||||
|
tokenizer(" C")["input_ids"] + tokenizer(" D")["input_ids"]
|
||||||
|
).unsqueeze(0).to(model.device)
|
||||||
|
|
||||||
|
idx_list = list(range(0, len(test_df), batch_size))
|
||||||
|
for i in tqdm(idx_list):
|
||||||
|
full_prompt_list = []
|
||||||
|
answer_list = []
|
||||||
|
for row in test_df.iloc[i:i+batch_size].to_dict(orient='records'):
|
||||||
question = format_example(row, include_answer=False)
|
question = format_example(row, include_answer=False)
|
||||||
full_prompt = few_shot_prompt + question
|
full_prompt = few_shot_prompt + question
|
||||||
|
full_prompt_list.append(full_prompt)
|
||||||
|
if 'answer' in row:
|
||||||
|
answer_list.append(row['answer'])
|
||||||
|
|
||||||
output, input_info = get_logits(tokenizer, model, [full_prompt])
|
logits, input_info = get_logits(tokenizer, model, full_prompt_list)
|
||||||
assert output.shape[0] == 1
|
softval = logits.gather(1, choices_ids.expand(logits.size(0), -1)).softmax(1)
|
||||||
logits = output.flatten()
|
|
||||||
|
|
||||||
softval = torch.nn.functional.softmax(
|
|
||||||
torch.tensor(
|
|
||||||
[
|
|
||||||
logits[tokenizer(" A")["input_ids"]],
|
|
||||||
logits[tokenizer(" B")["input_ids"]],
|
|
||||||
logits[tokenizer(" C")["input_ids"]],
|
|
||||||
logits[tokenizer(" D")["input_ids"]],
|
|
||||||
]
|
|
||||||
),
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
if softval.dtype in {torch.bfloat16, torch.float16}:
|
if softval.dtype in {torch.bfloat16, torch.float16}:
|
||||||
softval = softval.to(dtype=torch.float32)
|
softval = softval.to(dtype=torch.float32)
|
||||||
probs = softval.detach().cpu().numpy()
|
probs = softval.detach().cpu().numpy()
|
||||||
|
|
||||||
for i, choice in enumerate(choices):
|
for i in range(len(probs)):
|
||||||
all_probs[f"prob_{choice}"].append(probs[i])
|
for j, choice in enumerate(choices):
|
||||||
pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)]
|
all_probs[f"prob_{choice}"].append(probs[i][j])
|
||||||
|
pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs[i])]
|
||||||
|
|
||||||
if "answer" in row:
|
if answer_list != []:
|
||||||
correct = 1 if pred == row["answer"] else 0
|
correct = 1 if pred == answer_list[i] else 0
|
||||||
score.append(correct)
|
score.append(correct)
|
||||||
if args.debug:
|
if args.debug:
|
||||||
print(f'{question} pred: {pred} ref: {row["answer"]}')
|
print(f'{question} pred: {pred} ref: {answer_list[i]}')
|
||||||
result.append(pred)
|
result.append(pred)
|
||||||
|
|
||||||
if save_result_dir:
|
if save_result_dir:
|
||||||
@@ -209,6 +220,7 @@ def main(args):
|
|||||||
k=5,
|
k=5,
|
||||||
few_shot=True,
|
few_shot=True,
|
||||||
save_result_dir=f"outs/mmlu_eval_result",
|
save_result_dir=f"outs/mmlu_eval_result",
|
||||||
|
batch_size=args.batch_size
|
||||||
)
|
)
|
||||||
dev_result[subject_name] = score
|
dev_result[subject_name] = score
|
||||||
cal_mmlu(dev_result)
|
cal_mmlu(dev_result)
|
||||||
@@ -308,6 +320,12 @@ if __name__ == "__main__":
|
|||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--debug", action="store_true", default=False, help="Print infos."
|
"--debug", action="store_true", default=False, help="Print infos."
|
||||||
)
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="batch size",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
set_seed(args.seed)
|
set_seed(args.seed)
|
||||||
|
|||||||
Reference in New Issue
Block a user