# Snippet captured from a Hugging Face Space (the Space page itself
# reported "Runtime error" at capture time).
# Point the Hugging Face cache at shared storage. This must happen
# before `transformers` is imported so the library reads it at import time.
import os

os.environ['TRANSFORMERS_CACHE'] = "data/parietal/store3/soda/lihu/hf_model/"

import torch
import transformers
from transformers import AutoTokenizer

# TinyLlama chat checkpoint, used for both the tokenizer and generation.
model = "PY007/TinyLlama-1.1B-Chat-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model)

# Text-generation pipeline; device_map="auto" lets accelerate place the
# weights on whatever devices are available, in full float32 precision.
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    torch_dtype=torch.float32,
    device_map="auto",
)

# End-of-turn token id used to stop generation — presumably the ChatML
# <|im_end|> token for this tokenizer; TODO confirm against the vocab.
CHAT_EOS_TOKEN_ID = 32002
def generate_answer(query, sample_num=3):
    """Sample chat answers for *query* from the TinyLlama pipeline.

    Parameters
    ----------
    query : str
        User question, inserted into a ChatML-style prompt.
    sample_num : int, optional
        Number of independent completions to return (default 3).

    Returns
    -------
    list[str]
        One answer per sampled sequence, with the echoed prompt removed.
    """
    # ChatML-style prompt format expected by the TinyLlama chat checkpoint.
    formatted_prompt = (
        f"<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"
    )
    sequences = pipeline(
        formatted_prompt,
        do_sample=True,
        top_k=50,
        top_p=0.9,
        num_return_sequences=sample_num,
        repetition_penalty=1.1,
        max_new_tokens=150,
        eos_token_id=CHAT_EOS_TOKEN_ID,
    )
    # The pipeline echoes the prompt at the start of 'generated_text'.
    # Strip only that leading prefix (removeprefix), not every occurrence:
    # the original .replace(formatted_prompt, "") would also delete the
    # prompt text if it happened to reappear inside the answer body.
    return [
        seq['generated_text'].removeprefix(formatted_prompt)
        for seq in sequences
    ]