File size: 2,406 Bytes
6b9e4d0 9f52142 7fee7e9 9f52142 7fee7e9 2ee231f e080748 7fee7e9 b1373ff 7fee7e9 2ee231f 7fee7e9 6b9e4d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
from transformers import T5ForConditionalGeneration as t5FCG
from transformers.models.t5.configuration_t5 import T5Config
from typing import Optional, Tuple, Union, List, Callable
class T5ForConditionalGeneration(t5FCG):
def __init__(self, config: T5Config):
super().__init__(config)
def preprocess(self,text):
text = text.replace("\n", "\\n").replace("\t", "\\t")
return text
def postprocess(self,text):
return text.replace("\\n", "\n").replace("\\t", "\t").replace('%20',' ')
def get_response(self,tokenizer,text, sample=True, top_p=0.9, temperature=0.7,max_length=1024,no_repeat_ngram_size=12,num_beams=1, length_penalty=0.6):
base_info = ""
text=base_info+text
text = self.preprocess(text)
encoding = tokenizer(text=[text], truncation=True, padding=True, max_length=max_length, return_tensors="pt").to(self.device)
if not sample:
out = self.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=max_length, num_beams=num_beams, length_penalty=length_penalty,do_sample=False)
else:
out = self.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=max_length, do_sample=True, top_p=top_p, temperature=temperature, no_repeat_ngram_size=no_repeat_ngram_size)
out_text = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
return self.postprocess(out_text[0])
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, sample=True, top_p=0.9, temperature=0.7,max_length=2048,no_repeat_ngram_size=12,num_beams=1, length_penalty=0.6):
history = history or []
if len(history) > 5:
history = history[-5:]
context = "\n".join([f"用户:{input_text}\n小元:{answer_text}" for input_text, answer_text in history])
#print(context)
input_text = context + "\n用户:" + query + "\n小元:"
input_text = input_text.strip()
response = self.get_response(tokenizer,input_text,sample=sample, top_p=top_p, temperature=temperature,max_length=max_length,no_repeat_ngram_size=no_repeat_ngram_size,num_beams=num_beams, length_penalty=length_penalty)
history.append((query, response))
return response,history
|