import os

# Must be set before `transformers` is imported so the cache path takes effect.
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers releases
# in favor of HF_HOME — confirm the installed version before changing it.
os.environ['TRANSFORMERS_CACHE'] = "data/parietal/store3/soda/lihu/hf_model/"

from transformers import AutoTokenizer
import transformers
import torch

# Chat-tuned TinyLlama checkpoint; the same identifier is used for both the
# tokenizer and the text-generation pipeline.
model = "PY007/TinyLlama-1.1B-Chat-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model)
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    torch_dtype=torch.float32,
    device_map="auto",
)

# Token id used to stop generation — presumably the id of <|im_end|> for this
# model's ChatML vocabulary; verify via tokenizer.convert_tokens_to_ids.
CHAT_EOS_TOKEN_ID = 32002


def generate_answer(query, sample_num=3):
    """Sample `sample_num` answers to `query` from the chat pipeline.

    Parameters
    ----------
    query : str
        User question, inserted into the ChatML prompt template.
    sample_num : int, optional
        Number of independently sampled completions to return (default 3).

    Returns
    -------
    list[str]
        Generated answers with the echoed prompt prefix removed.
    """
    formatted_prompt = (
        f"<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"
    )
    sequences = pipeline(
        formatted_prompt,
        do_sample=True,
        top_k=50,
        top_p=0.9,
        num_return_sequences=sample_num,
        repetition_penalty=1.1,
        max_new_tokens=150,
        eos_token_id=CHAT_EOS_TOKEN_ID,
    )
    answers = []
    for seq in sequences:
        text = seq["generated_text"]
        # Fix: the original used str.replace(prompt, ""), which deletes every
        # occurrence of the prompt anywhere in the output. Only the leading
        # echo of the prompt (prepended by the text-generation pipeline)
        # should be stripped.
        if text.startswith(formatted_prompt):
            text = text[len(formatted_prompt):]
        answers.append(text)
    return answers