Spaces:
Runtime error
Runtime error
import random | |
import tqdm | |
import os | |
import sys | |
import torch | |
import jsonlines | |
import argparse | |
import jsonlines | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from transformers.generation import GenerationConfig | |
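"""
Generate HumanEval completions with a causal LM, then score them with
OpenAI's human-eval harness:

    $ git clone https://github.com/openai/human-eval
    $ pip install -e human-eval
    $ evaluate_functional_correctness <sample-output-file>

where <sample-output-file> is the JSONL written via -o/--sample-output-file
(default: HumanEval_res.jsonl).
"""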
def decode(tokens_list, tokenizer, raw_text_len):
    """Turn generated token sequences into truncated completion strings.

    Each sequence is expected to hold the prompt ids followed by the
    generated ids, which is how ``generate_sample`` calls this function
    (``raw_text_len`` = number of prompt ids). The first ``raw_text_len``
    ids are dropped, the rest is decoded, and the text is cut at the first
    stop marker so that only the body of the prompted function remains.

    Args:
        tokens_list: Iterable of 1-D token-id sequences exposing
            ``.tolist()`` (e.g. a 2-D tensor, iterated row by row).
        tokenizer: An object whose ``.tokenizer`` attribute has a
            ``decode(ids)`` method (the layout this script always relied
            on), or, when no such attribute exists, an object that has
            ``decode(ids)`` itself.
        raw_text_len: Number of leading prompt ids to skip per sequence.

    Returns:
        list[str]: One truncated completion per input sequence.
    """
    # Use the wrapped encoder when present (original behaviour); otherwise
    # fall back to the tokenizer itself so non-wrapping tokenizers work too.
    decoder = getattr(tokenizer, "tokenizer", tokenizer)
    # Applied in order; each keeps only the text before the marker.
    # "\n\n" (first blank line) subsumes a separate "\n\n\n" check, and
    # "def " drops any further function the model starts writing.
    # NOTE: "def " also cuts at a nested helper defined inside the body.
    stop_markers = ("<|endoftext|>", "\n\n", "def ")
    sents = []
    for tokens in tokens_list:
        sent = decoder.decode(tokens.tolist()[raw_text_len:])
        for marker in stop_markers:
            sent = sent.partition(marker)[0]
        sents.append(sent)
    return sents
def generate_sample(model, tokenizer, input_txt):
    """Generate a completion for one prompt and return the truncated text.

    The prompt is encoded, run through ``model.generate`` with the model's
    current ``generation_config``, and the output is passed to ``decode``
    together with the prompt length so the prompt ids are stripped. Both
    the prompt and the completion are printed for progress tracking.

    Args:
        model: A causal LM exposing ``.device`` and ``.generate``.
        tokenizer: Object whose ``.tokenizer`` attribute has
            ``encode(text)``; when there is no ``.tokenizer`` attribute the
            object's own ``encode(text)`` is used.
        input_txt: Prompt string (the ``prompt`` field of a task record).

    Returns:
        str: The (truncated) completion for ``input_txt``.
    """
    # Use the wrapped encoder when present (original behaviour); otherwise
    # fall back to the tokenizer itself.
    encoder = getattr(tokenizer, "tokenizer", tokenizer)
    input_ids = encoder.encode(input_txt)
    raw_text_len = len(input_ids)
    context_enc = torch.tensor([input_ids]).to(model.device)
    # A single unpadded sequence: every position is a real token, so an
    # all-ones mask is exact; pass it rather than letting generate() infer one.
    attention_mask = torch.ones_like(context_enc)
    print(f"Input text: {input_txt}\n")
    outputs = model.generate(context_enc, attention_mask=attention_mask)
    output_text = decode(outputs, tokenizer, raw_text_len)[0]
    print(f"\nOutput text: \n{output_text}\n")
    return output_text
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Test HF checkpoint.')
    parser.add_argument("-c", "--checkpoint-path", type=str, help='Checkpoint path', default="Qwen/Qwen-7B")
    # Required: without it the script would only crash (on a None path)
    # after the slow model load.
    parser.add_argument("-f", "--sample-input-file", type=str, required=True, help="data path to HumanEval.jsonl")
    parser.add_argument("-o", "--sample-output-file", type=str, default="HumanEval_res.jsonl")
    args = parser.parse_args()
    # Fail fast on a bad input path, before loading the model.
    if not os.path.isfile(args.sample_input_file):
        parser.error(f"input file not found: {args.sample_input_file}")

    print('Loading tokenizer ...')
    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)
    print('Loading model ...')
    model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True).eval()
    model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
    # Disable sampling so generation is deterministic across runs.
    model.generation_config.do_sample = False

    # Open the input first so an unreadable input never truncates an existing
    # output file. jsonlines.Writer does not close a file object it was
    # handed, so the output handle gets its own ``with`` to guarantee it is
    # flushed and closed even if generation fails part-way.
    with jsonlines.open(args.sample_input_file) as reader, \
            open(args.sample_output_file, 'w', encoding='utf-8') as fp_out, \
            jsonlines.Writer(fp_out) as output:
        for jobj in tqdm.tqdm(reader, desc='task_idx'):
            prompt = jobj['prompt']
            task_id = jobj['task_id']
            gen_sents = generate_sample(model, tokenizer, prompt)
            gen_jobjs = {'task_id': task_id, "completion": gen_sents}
            output.write(gen_jobjs)