import logging

import torch
from peft import PeftModel  # only needed if a LoRA adapter is loaded on top of the base model
from transformers import AutoModelForCausalLM, AutoTokenizer

# device = 'cuda:0'
model_name = "DUTIR-BioNLP/Taiyi-LLM"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    load_in_8bit=True,
    # device_map=device
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True
)

logging.disable(logging.WARNING)

# The Qwen-style tokenizer only defines an end-of-document token,
# so reuse it as the pad/bos/eos token.
tokenizer.pad_token_id = tokenizer.eod_id
tokenizer.bos_token_id = tokenizer.eod_id
tokenizer.eos_token_id = tokenizer.eod_id

# Start the conversation
history_max_len = 1000  # maximum number of context tokens fed to the model


def run(message: str,
        history: list,
        max_new_tokens: int = 500,
        temperature: float = 0.10,
        top_p: float = 0.9,
        repetition_penalty: float = 1.0):
    """Generate a reply to `message`, conditioning on `history`,
    a list of (question, response) string pairs."""
    eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long)
    bos_token_id = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long)

    # Re-encode the dialogue history as <question><eos><response><eos>...
    history_turns = []
    for question, response in history:
        question_ids = tokenizer(question, return_tensors="pt",
                                 add_special_tokens=False).input_ids
        response_ids = tokenizer(response, return_tensors="pt",
                                 add_special_tokens=False).input_ids
        history_turns.append(torch.concat(
            (question_ids, eos_token_id, response_ids, eos_token_id), dim=1))
    history_token_ids = (torch.concat(history_turns, dim=1) if history_turns
                         else torch.empty((1, 0), dtype=torch.long))

    # Encode the new user message as <bos><message><eos> and append it.
    input_ids = tokenizer(message, return_tensors="pt",
                          add_special_tokens=False).input_ids
    user_input_ids = torch.concat([bos_token_id, input_ids, eos_token_id], dim=1)
    input_token_ids = torch.concat((history_token_ids, user_input_ids), dim=1)

    # Keep only the most recent history_max_len tokens as the prompt.
    model_input_ids = input_token_ids[:, -history_max_len:].to('cuda')

    with torch.no_grad():
        outputs = model.generate(
            input_ids=model_input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            eos_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens (everything after the prompt).
    response_ids = outputs[:, model_input_ids.size(1):]
    response = tokenizer.batch_decode(response_ids)
    return response[0].strip().replace(tokenizer.eos_token, "").replace("\n", "\n\n")
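
# --- Usage sketch (not part of the original example) ---
# A minimal, hypothetical two-turn chat showing how `run` is meant to be
# called: the caller owns `history`, starting from an empty list and
# appending each (question, answer) pair so the next call sees it as
# context. The example questions below are illustrative only.
if __name__ == "__main__":
    history = []
    question = "感冒了应该怎么办？"  # "What should I do about a cold?"
    answer = run(question, history)
    print(answer)

    history.append((question, answer))  # carry the turn forward as context
    follow_up = "需要吃什么药？"  # "What medicine should I take?"
    print(run(follow_up, history))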