# NOTE(review): the lines "Spaces:" / "Runtime error" below were banner text
# from the web page this script was pasted from (the author reported a runtime
# error); kept here as comments so the file remains valid Python.
import os

# The cache location must be set BEFORE importing transformers, since the
# library reads it at import time.
# NOTE(review): TRANSFORMERS_CACHE is deprecated in transformers >= 4.36 in
# favour of HF_HOME; the path also looks like it may be missing a leading "/"
# ("/data/parietal/...") -- confirm against the actual filesystem layout.
os.environ['TRANSFORMERS_CACHE'] = "data/parietal/store3/soda/lihu/hf_model/"

from transformers import AutoTokenizer
import transformers
import torch

# TinyLlama chat checkpoint used for both the tokenizer and the pipeline.
model = "PY007/TinyLlama-1.1B-Chat-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model)

# Text-generation pipeline; device_map="auto" lets accelerate pick device
# placement, torch.float32 keeps it CPU-safe.
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    torch_dtype=torch.float32,
    device_map="auto",
)

# Token id used to stop generation in the chat format -- presumably the
# <|im_end|> token for this checkpoint; verify against the tokenizer.
CHAT_EOS_TOKEN_ID = 32002
def generate_answer(query, sample_num=3, generator=None):
    """Sample chat-style answers to *query* from the TinyLlama pipeline.

    Parameters
    ----------
    query : str
        User question; wrapped in the <|im_start|>/<|im_end|> chat template.
    sample_num : int, optional
        Number of sampled completions to return (default 3).
    generator : callable, optional
        Text-generation callable with the transformers-pipeline interface.
        Defaults to the module-level ``pipeline``; injectable for testing.

    Returns
    -------
    list[str]
        One answer per sampled sequence, with the echoed prompt removed.
    """
    if generator is None:
        generator = pipeline  # module-level transformers text-generation pipeline

    formatted_prompt = (
        f"<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"
    )
    sequences = generator(
        formatted_prompt,
        do_sample=True,
        top_k=50,
        top_p=0.9,
        num_return_sequences=sample_num,
        repetition_penalty=1.1,
        max_new_tokens=150,
        eos_token_id=CHAT_EOS_TOKEN_ID,
    )
    # The pipeline echoes the prompt at the start of generated_text; strip only
    # that leading prefix. (The original used str.replace, which would also
    # delete any later verbatim occurrence of the prompt inside the answer.)
    return [seq["generated_text"].removeprefix(formatted_prompt)
            for seq in sequences]