ClueAI commited on
Commit
2ee231f
1 Parent(s): 9f52142

Update modeling_t5.py

Browse files
Files changed (1) hide show
  1. modeling_t5.py +4 -13
modeling_t5.py CHANGED
@@ -1,12 +1,3 @@
1
- from transformers import T5ForConditionalGeneration as t5FCG
2
- from transformers.models.t5.configuration_t5 import T5Config
3
- from typing import Optional, Tuple, Union, List, Callable
4
-
5
-
6
-
7
-
8
-
9
-
10
  class T5ForConditionalGeneration(t5FCG):
11
 
12
  def __init__(self, config: T5Config):
@@ -21,7 +12,7 @@ class T5ForConditionalGeneration(t5FCG):
21
  return text.replace("\\n", "\n").replace("\\t", "\t").replace('%20',' ')
22
 
23
 
24
- def get_response(self,tokenizer,text, sample=True, top_p=0.9, temperature=0.7,max_length=1024,no_repeat_ngram_size=12,num_beams=1, length_penalty=0.6,):
25
  base_info = "用户:你是谁?\n小元:我是元语智能公司研发的AI智能助手, 在不违反原则的情况下,我可以回答你的任何问题。\n"
26
  text=base_info+text
27
  text = self.preprocess(text)
@@ -36,7 +27,7 @@ class T5ForConditionalGeneration(t5FCG):
36
  return self.postprocess(out_text[0])
37
 
38
 
39
- def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, sample=True, top_p=0.9, temperature=0.7,max_length=1024):
40
 
41
 
42
  history = history or []
@@ -48,7 +39,7 @@ class T5ForConditionalGeneration(t5FCG):
48
 
49
  input_text = context + "\n用户:" + query + "\n小元:"
50
  input_text = input_text.strip()
51
- response = self.get_response(tokenizer,input_text,sample, top_p, temperature,max_length)
52
 
53
  history.append((query, response))
54
- return response,history
 
 
 
 
 
 
 
 
 
 
1
  class T5ForConditionalGeneration(t5FCG):
2
 
3
  def __init__(self, config: T5Config):
 
12
  return text.replace("\\n", "\n").replace("\\t", "\t").replace('%20',' ')
13
 
14
 
15
+ def get_response(self,tokenizer,text, sample=True, top_p=0.9, temperature=0.7,max_length=1024,no_repeat_ngram_size=12,num_beams=1, length_penalty=0.6):
16
  base_info = "用户:你是谁?\n小元:我是元语智能公司研发的AI智能助手, 在不违反原则的情况下,我可以回答你的任何问题。\n"
17
  text=base_info+text
18
  text = self.preprocess(text)
 
27
  return self.postprocess(out_text[0])
28
 
29
 
30
+ def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, sample=True, top_p=0.9, temperature=0.7,max_length=1024,no_repeat_ngram_size=12,num_beams=1, length_penalty=0.6):
31
 
32
 
33
  history = history or []
 
39
 
40
  input_text = context + "\n用户:" + query + "\n小元:"
41
  input_text = input_text.strip()
42
+ response = self.get_response(tokenizer,input_text,sample=sample, top_p=top_p, temperature=temperature,max_length=max_length,no_repeat_ngram_size=no_repeat_ngram_size,num_beams=num_beams, length_penalty=length_penalty)
43
 
44
  history.append((query, response))
45
+ return response,history