lewiswu1209 commited on
Commit
f3c6b77
1 Parent(s): a0ed808

Refactoring

Browse files
.gitattributes CHANGED
@@ -9,13 +9,9 @@
9
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
  *.model filter=lfs diff=lfs merge=lfs -text
11
  *.msgpack filter=lfs diff=lfs merge=lfs -text
12
- *.npy filter=lfs diff=lfs merge=lfs -text
13
- *.npz filter=lfs diff=lfs merge=lfs -text
14
  *.onnx filter=lfs diff=lfs merge=lfs -text
15
  *.ot filter=lfs diff=lfs merge=lfs -text
16
  *.parquet filter=lfs diff=lfs merge=lfs -text
17
- *.pickle filter=lfs diff=lfs merge=lfs -text
18
- *.pkl filter=lfs diff=lfs merge=lfs -text
19
  *.pb filter=lfs diff=lfs merge=lfs -text
20
  *.pt filter=lfs diff=lfs merge=lfs -text
21
  *.pth filter=lfs diff=lfs merge=lfs -text
 
9
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
  *.model filter=lfs diff=lfs merge=lfs -text
11
  *.msgpack filter=lfs diff=lfs merge=lfs -text
 
 
12
  *.onnx filter=lfs diff=lfs merge=lfs -text
13
  *.ot filter=lfs diff=lfs merge=lfs -text
14
  *.parquet filter=lfs diff=lfs merge=lfs -text
 
 
15
  *.pb filter=lfs diff=lfs merge=lfs -text
16
  *.pt filter=lfs diff=lfs merge=lfs -text
17
  *.pth filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
  title: Winnie
3
- emoji: 🔥
4
- colorFrom: indigo
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 3.1.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
1
  ---
2
  title: Winnie
3
+ emoji:
4
+ colorFrom: pink
5
+ colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 3.0.24
8
  app_file: app.py
9
  pinned: false
10
  license: mit
app.py CHANGED
@@ -1,25 +1,78 @@
1
 
 
 
2
  import gradio as gr
3
 
4
- from bot.interface import Chatbot
 
 
 
5
 
6
- bot = Chatbot()
 
 
 
 
 
 
 
7
 
8
- def greet(input_txt, history = []):
9
  global bot
10
 
11
- if bot is None:
12
- bot = Chatbot()
 
 
 
 
 
 
 
 
 
13
 
14
- history.append(input_txt)
15
- response = bot.chat(history)
16
- history.append(response)
17
 
18
- return response, history
 
 
 
 
 
19
 
20
- if __name__ == "__main__":
21
- gr.Interface(fn=greet,
22
- # title="使用中文和脑子瓦特了的Vicky聊天",
23
- inputs=["text", "state"],
24
- outputs=["text", "state"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  ).launch()
 
 
 
 
1
 
2
+ import os
3
+
4
  import gradio as gr
5
 
6
+ from bot.chatbot import ChatBot
7
+ from bot.config import special_token_list
8
+
9
+ bot:ChatBot = None
10
 
11
+ def get_skill_list() -> list:
12
+ path:str = os.path.split( os.path.realpath(__file__) )[0]
13
+ file_list:list[str] = os.listdir( path + "/bot/skills/" )
14
+ plugin_list:list[str] = []
15
+ for file in file_list:
16
+ if file.endswith(".py"):
17
+ plugin_list.append( file[:-3] )
18
+ return plugin_list
19
 
20
+ def general(input_txt:str, state:dict = {}):
21
  global bot
22
 
23
+ history_list:list = state.get("history", [])
24
+ role_card:dict[str, str] = state.get("role_card", {
25
+ "<NAME>": "Winnie",
26
+ "<GENDER>": "女",
27
+ "<YEAROFBIRTH>":"1995",
28
+ "<MONTHOFBIRTH>":"5",
29
+ "<DAYOFBIRTH>":"6",
30
+ "<ZODIAC>":"金牛座",
31
+ "<AGE>":"27"
32
+ }
33
+ )
34
 
35
+ output_txt:str = None
 
 
36
 
37
+ for skill_name in get_skill_list():
38
+ if output_txt is None:
39
+ plugin = __import__("bot.skills."+skill_name, fromlist=[skill_name])
40
+ plugin_class = getattr(plugin, "Skill")
41
+ p = plugin_class()
42
+ output_txt, history_list, role_card = p.process(input_txt, history_list, role_card)
43
 
44
+ if output_txt is None:
45
+ res, history_list = bot.chat(input_txt, history_list, role_card=role_card)
46
+ output_txt = "".join(res)
47
+
48
+ state["history"] = history_list
49
+ state["role_card"] = role_card
50
+
51
+ return output_txt, state
52
+
53
+ def main() -> None:
54
+ global bot
55
+
56
+ bot = ChatBot.get_chat_bot("lewiswu1209/Winnie", special_token_list=special_token_list)
57
+
58
+ title:str = "使用中文和Winnie聊天"
59
+
60
+ description:str = "输入任意文字,Winnie会和你对话<br>"
61
+ description += "输入ERASE MEMORY,会清空Winnie的记忆<br>"
62
+ description += "输入\"&lt;TAG&gt;=&lt;VALUE&gt;\",可以修改Winnie的角色信息<br>"
63
+ description += "例如:&lt;NAME&gt;=Vicky,会修改Winnie的名字<br>"
64
+ description += "可以修改的角色信息有:<br>"
65
+ description += "&lt;NAME&gt;, &lt;GENDER&gt;, &lt;YEAROFBIRTH&gt;, &lt;MONTHOFBIRTH&gt;, &lt;DAYOFBIRTH&gt;, &lt;ZODIAC&gt;, &lt;AGE&gt;<br>"
66
+ description += "输入“上联:XXXXXXX”,Winnie会和你对对联<br>"
67
+ description += "输入“写诗:XXXXXXX”,Winnie会以XXXXXXX为开头写诗"
68
+
69
+ gr.Interface(
70
+ fn = general,
71
+ title = title,
72
+ description = description,
73
+ inputs = ["text", "state"],
74
+ outputs = ["text", "state"]
75
  ).launch()
76
+
77
+ if __name__ == "__main__":
78
+ main()
bot/chatbot.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from tokenize import tokenize
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from transformers import BertTokenizer, GPT2LMHeadModel
7
+
8
+ def replace_special_tokens(tokens:list, special_token_dict:dict)->list:
9
+ replaced_tokens:list = []
10
+
11
+ for token in tokens:
12
+ if token in special_token_dict.keys():
13
+ replaced_tokens.append( special_token_dict[token] )
14
+ else:
15
+ replaced_tokens.append( token )
16
+
17
+ return replaced_tokens
18
+
19
+ def top_k_top_p_filtering(logits, top_k:int=0, top_p:float=0.0, filter_value:float=-float('Inf')):
20
+ top_k = min( top_k, logits.size(-1) )
21
+
22
+ if top_k > 0:
23
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
24
+ logits[indices_to_remove] = filter_value
25
+
26
+ if top_p > 0.0:
27
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
28
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
29
+ sorted_indices_to_remove = cumulative_probs > top_p
30
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
31
+ sorted_indices_to_remove[..., 0] = 0
32
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
33
+ logits[indices_to_remove] = filter_value
34
+
35
+ return logits
36
+
37
+ class ChatBot():
38
+
39
+ def get_chat_bot(checkpoint:str, vocab_path:str = None, special_token_list:list = [])->object:
40
+ tokenizer = ChatBot.get_tokenizer(checkpoint, vocab_path, special_token_list)
41
+ model = GPT2LMHeadModel.from_pretrained(checkpoint)
42
+
43
+ return ChatBot(tokenizer, model)
44
+
45
+ def get_tokenizer(checkpoint:str, vocab_path:str = None, special_token_list:list = [])->object:
46
+ if vocab_path is None:
47
+ tokenizer = BertTokenizer.from_pretrained(checkpoint, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
48
+ else:
49
+ tokenizer = BertTokenizer(vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
50
+ tokenizer.add_special_tokens( {'additional_special_tokens':special_token_list} )
51
+
52
+ return tokenizer
53
+
54
+ def __init__(self, tokenizer:object, model:object)->None:
55
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
56
+ self.tokenizer = tokenizer
57
+ self.model = model
58
+ self.model.to(self.device)
59
+
60
+ def convert_ids_to_tokens(self, ids:list):
61
+ return self.tokenizer.convert_ids_to_tokens(ids)
62
+
63
+ def convert_ids_to_text(self, ids):
64
+ return "".join( self.convert_ids_to_tokens(ids) )
65
+
66
+ def convert_text_to_ids(self, text, add_special_tokens=False):
67
+ return self.tokenizer.encode(text, add_special_tokens=add_special_tokens)
68
+
69
+ def get_prediction(self, input_tensor, input_ids, repetition_penalty, temperature, top_k, top_p):
70
+ self.model.eval()
71
+
72
+ generated_ids = []
73
+ for _ in range(64):
74
+ output_pt = self.model(input_tensor)
75
+
76
+ next_token_logits = output_pt.logits[0, -1, :]
77
+ for id in set(input_ids):
78
+ if id != self.tokenizer.sep_token_id:
79
+ next_token_logits[id] /= repetition_penalty
80
+ for id in set(generated_ids):
81
+ next_token_logits[id] /= repetition_penalty
82
+ next_token_logits = next_token_logits / temperature
83
+ next_token_logits[self.tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
84
+ filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
85
+ next_token = torch.multinomial( F.softmax(filtered_logits, dim=-1), num_samples=1 )
86
+
87
+ if next_token == self.tokenizer.sep_token_id:
88
+ break
89
+
90
+ input_tensor = torch.cat( (input_tensor, next_token.unsqueeze(0)), dim=1 )
91
+ generated_ids.append( next_token.item() )
92
+
93
+ return generated_ids
94
+
95
+ def chat(self:object, text:str, history:list, role_card:dict={}) -> str:
96
+ text_ids = self.tokenizer.encode(text, add_special_tokens=False)
97
+ history.append(text_ids)
98
+ input_ids = [self.tokenizer.cls_token_id]
99
+ for history_utr in history[-50:]:
100
+ input_ids.extend(history_utr)
101
+ input_ids.append(self.tokenizer.sep_token_id)
102
+ input_tensor = torch.tensor(input_ids).to(self.device).unsqueeze(0)
103
+ generated_ids = self.get_prediction(input_tensor, input_ids, repetition_penalty=1.2, temperature=0.73, top_k=10, top_p=0.7)
104
+
105
+ history.append(generated_ids)
106
+
107
+ generated_tokens = replace_special_tokens( self.convert_ids_to_tokens(generated_ids), role_card )
108
+
109
+ return "".join(generated_tokens), history
bot/config.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ special_token_list:list = [
2
+ '<NAME>',
3
+ '<GENDER>',
4
+ '<YEAROFBIRTH>',
5
+ '<MONTHOFBIRTH>',
6
+ '<DAYOFBIRTH>',
7
+ '<ZODIAC>',
8
+ '<AGE>',
9
+ '<CMD>',
10
+ '<NICK>',
11
+ '<HEIGHT>',
12
+ '<WEIGHT>',
13
+ '<WORK>',
14
+ '<HOBBY>',
15
+ '<HOMETOWN>',
16
+ '<CITY>',
17
+ '<BUST>',
18
+ '<WAIST>',
19
+ '<HIP>',
20
+ '<CUP>'
21
+ ]
bot/interface.py DELETED
@@ -1,48 +0,0 @@
1
-
2
- from random import choice
3
- from random import randint
4
- from random import uniform
5
-
6
- from bot.simctgdialogue import SimCTGDialogue
7
-
8
- class Chatbot():
9
- def __init__(self):
10
- self.model = SimCTGDialogue("cambridgeltl/simctg_lccc_dialogue", [])
11
- self.tokenizer = self.model.tokenizer
12
- self.model.eval()
13
-
14
- def __contrastive_search(self, context_list):
15
- print("__contrastive_search")
16
- print(context_list)
17
- beam_width, alpha, decoding_len = randint(1, 8), uniform(0.10, 0.40), 64
18
- return self.model.contrastive_search(context_list, beam_width, alpha, decoding_len)
19
-
20
- def __diverse_contrastive_search(self, context_list):
21
- print("__diverse_contrastive_search")
22
- print(context_list)
23
- sample_step, nucleus_p = 1, uniform(0.10, 0.40)
24
- beam_width, alpha, decoding_len = randint(1, 5), uniform(0.10, 0.40), 64
25
- return self.model.diverse_contrastive_search(context_list, sample_step, nucleus_p, beam_width, alpha, decoding_len)
26
-
27
- def __greedy_search(self, context_list):
28
- print("__greedy_search")
29
- print(context_list)
30
- decoding_len = 64
31
- return self.model.greedy_search(context_list, decoding_len)
32
-
33
- def __beam_search(self, context_list):
34
- print("__beam_search")
35
- print(context_list)
36
- beam_width, decoding_len = randint(1, 9), 64
37
- return self.model.beam_search(context_list, beam_width, decoding_len)
38
-
39
- def chat(self, prefix = []):
40
- methods_for_sort_dialogue = [self.__contrastive_search, self.__greedy_search]
41
- methods_for_long_dialogue = [self.__beam_search, self.__diverse_contrastive_search, self.__greedy_search, self.__contrastive_search]
42
-
43
- if ( len(prefix) < 4 ):
44
- response = choice(methods_for_sort_dialogue)(prefix)
45
- else:
46
- response = choice(methods_for_long_dialogue)(prefix)
47
-
48
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bot/simctgdialogue.py DELETED
@@ -1,177 +0,0 @@
1
-
2
- import torch
3
-
4
- from torch import nn
5
-
6
- class SimCTGDialogue(nn.Module):
7
- def __init__(self, model_name, additional_special_tokens):
8
- super(SimCTGDialogue, self).__init__()
9
- from transformers import AutoTokenizer, GPT2LMHeadModel
10
- eos_token = '[SEP]'
11
- pad_token = '[PAD]'
12
- self.tokenizer = AutoTokenizer.from_pretrained(model_name, additional_special_tokens=additional_special_tokens)
13
- self.vocab_size = len(self.tokenizer)
14
- self.model = GPT2LMHeadModel.from_pretrained(model_name)
15
- self.embed_dim = self.model.config.hidden_size
16
- if pad_token in self.tokenizer.vocab:
17
- print ('PAD token exists.')
18
- else:
19
- print ('Add PAD token to the tokenizer.')
20
- print ('Original vocabulary size is {}'.format(len(self.tokenizer)))
21
- self.tokenizer.add_tokens([pad_token])
22
- print ('Vocabulary size after extension is {}'.format(len(self.tokenizer)))
23
- assert len(self.tokenizer.convert_tokens_to_ids([pad_token])) == 1
24
- self.model.resize_token_embeddings(len(self.tokenizer))
25
- self.pad_token_id = self.tokenizer.convert_tokens_to_ids([pad_token])[0]
26
- self.vocab_size = len(self.tokenizer)
27
- if 'e' in eos_token:
28
- self.eos_token = self.tokenizer.eos_token
29
- else:
30
- self.eos_token = eos_token
31
- print (self.eos_token)
32
-
33
- def parse_dialogue_context(self, context_list, cuda_available=False, device=0):
34
- # context_list: a list of utterances in the dialogue session
35
- uttr_num = len(context_list)
36
- context_text = self.eos_token.join(context_list).strip(self.eos_token) + self.eos_token
37
- #print (context_text)
38
- tokens = self.tokenizer.tokenize(context_text)
39
- input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
40
- input_ids = input_ids
41
- input_ids = torch.LongTensor(input_ids).view(1,-1)
42
- if cuda_available:
43
- input_ids = input_ids.cuda(device)
44
- return input_ids, uttr_num
45
-
46
- def extract_response(self, output_ids, uttr_num):
47
- output_text = self.tokenizer.decode(output_ids)
48
- # extract response
49
- item_list = output_text.split(self.eos_token)
50
- response = item_list[uttr_num].strip()
51
- if self.eos_token == '<|endoftext|>': # English GPT
52
- response = ' '.join(response.split())
53
- else:
54
- response = ''.join(response.split())
55
- return response
56
-
57
- def contrastive_search(self, context_list, beam_width, alpha, decoding_len,
58
- cuda_available=False, device=0):
59
- input_ids, uttr_num = self.parse_dialogue_context(context_list,
60
- cuda_available=cuda_available, device=device)
61
- output = self.fast_contrastive_generation(input_ids, beam_width, alpha, decoding_len)
62
- return self.extract_response(output, uttr_num)
63
-
64
- def diverse_contrastive_search(self, context_list, sample_step, nucleus_p,
65
- beam_width, alpha, decoding_len, cuda_available=False, device=0):
66
- input_ids, uttr_num = self.parse_dialogue_context(context_list,
67
- cuda_available=cuda_available, device=device)
68
- output = self.diverse_contrastive_generation(input_ids, sample_step, nucleus_p,
69
- beam_width, alpha, decoding_len)
70
- return self.extract_response(output, uttr_num)
71
-
72
- def greedy_search(self, context_list, decoding_len, cuda_available=False, device=0):
73
- input_ids, uttr_num = self.parse_dialogue_context(context_list,
74
- cuda_available=cuda_available, device=device)
75
- output = self.greedy_generation(input_ids, decoding_len)
76
- return self.extract_response(output, uttr_num)
77
-
78
- def beam_search(self, context_list, beam_width, decoding_len,
79
- cuda_available=False, device=0):
80
- input_ids, uttr_num = self.parse_dialogue_context(context_list,
81
- cuda_available=cuda_available, device=device)
82
- output = self.beam_generation(input_ids, beam_width, decoding_len)
83
- return self.extract_response(output, uttr_num)
84
-
85
- def nucleus_sampling(self, context_list, nucleus_p, decoding_len,
86
- cuda_available=False, device=0):
87
- input_ids, uttr_num = self.parse_dialogue_context(context_list,
88
- cuda_available=cuda_available, device=device)
89
- output = self.nucleus_generation(input_ids, nucleus_p, decoding_len)
90
- return self.extract_response(output, uttr_num)
91
-
92
- def fast_contrastive_generation(self, input_ids, beam_width, alpha, decoding_len):
93
- '''
94
- input_ids: prefix input; 1 x prefix_len
95
- decoding_len: how many tokens to generate
96
- beam_width: size of candidate pool during decoding
97
- alpha: regulates importance of model confidence and degeneration penalty
98
- '''
99
- self.model.eval()
100
- from bot.utlis import ContrastiveDecodingOneStepFast
101
- # sanity check
102
- assert alpha >= 0. and alpha <= 1.0
103
-
104
- # fast mode
105
- batch_size, seqlen = input_ids.size()
106
- #generated = [[] for _ in range(batch_size)]
107
- generated = [item for item in input_ids.tolist()]
108
- past_key_values = None
109
- last_hidden_states = None
110
- logits = None
111
- for step in range(decoding_len):
112
- input_ids, past_key_values, last_hidden_states, logits = ContrastiveDecodingOneStepFast(
113
- self.model,
114
- input_ids,
115
- beam_width,
116
- alpha,
117
- past_key_values,
118
- last_hidden_states,
119
- self.tokenizer,
120
- logits,
121
- first_step=step == 0,
122
- )
123
- tokens = input_ids.squeeze(dim=-1).tolist()
124
- for idx, t in enumerate(tokens):
125
- generated[idx].append(t)
126
- return generated[0]
127
-
128
- def diverse_contrastive_generation(self, input_ids, sample_step, nucleus_p, beam_width, alpha, decoding_len):
129
- '''
130
- sample_step:
131
- number of steps to decode with nucleus sampling,
132
- for the remaining steps we use contrastive search
133
- decoding_len:
134
- the total number of generated tokens
135
- beam_width:
136
- size of candidate pool during decoding
137
- alpha:
138
- regulates importance of model confidence and degeneration penalty
139
-
140
- '''
141
- contrastive_step = decoding_len - sample_step
142
- _, prefix_len = input_ids.size()
143
- # first do sample
144
- input_ids = self.model.generate(
145
- input_ids,
146
- do_sample=True,
147
- max_length=prefix_len+sample_step,
148
- top_p=nucleus_p,
149
- top_k=0)
150
- # then do contrastive search
151
- output = self.fast_contrastive_generation(input_ids, beam_width, alpha, contrastive_step)
152
- return output
153
-
154
- def greedy_generation(self, input_ids, decoding_len):
155
- _, prefix_len = input_ids.size()
156
- output = self.model.generate(
157
- input_ids,
158
- max_length=prefix_len+decoding_len)
159
- return output[0]
160
-
161
- def beam_generation(self, input_ids, beam_width, decoding_len):
162
- _, prefix_len = input_ids.size()
163
- output = self.model.generate(
164
- input_ids,
165
- max_length=prefix_len+decoding_len,
166
- num_beams=beam_width)
167
- return output[0]
168
-
169
- def nucleus_generation(self, input_ids, nucleus_p, decoding_len):
170
- _, prefix_len = input_ids.size()
171
- output = self.model.generate(
172
- input_ids,
173
- do_sample=True,
174
- max_length=prefix_len+decoding_len,
175
- top_p=nucleus_p,
176
- top_k=0)
177
- return output[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bot/skills/couplet.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import requests
3
+
4
+ class Skill:
5
+ def __init__(self:object) -> None:
6
+ pass
7
+
8
+ def process(self:object, input_txt:str, history_list:list, role_card:dict):
9
+ output_text:str = None
10
+ if input_txt.startswith("上联:") or input_txt.startswith("上联:"):
11
+ output_text = requests.post(
12
+ url='https://hf.space/embed/lewiswu1209/gpt2-chinese-couplet/+/api/predict/',
13
+ json={"data": [input_txt[3:]]}
14
+ ).json()["data"][0]
15
+ output_text = "我对下联:" + output_text
16
+ return output_text, history_list, role_card
bot/skills/delete_memory.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ class Skill:
3
+ def __init__(self:object) -> None:
4
+ pass
5
+
6
+ def process(self:object, input_txt:str, history_list:list, role_card:dict):
7
+ output_txt:str = None
8
+ if input_txt.upper()=="ERASE MEMORY":
9
+ history_list = []
10
+ output_txt = "我是谁?我在哪?我在干什么?"
11
+ return output_txt, history_list, role_card
bot/skills/give_role.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ class Skill:
3
+ def __init__(self:object) -> None:
4
+ pass
5
+
6
+ def process(self:object, input_txt:str, history_list:list, role_card:dict):
7
+ output_txt:str = None
8
+ for tag in role_card.keys():
9
+ prefix:str = "{}=".format(tag)
10
+ if input_txt.startswith( prefix ):
11
+ role_card[tag]=input_txt[len(prefix):]
12
+ output_txt = "已设置{}为{}".format(tag, role_card[tag])
13
+ break
14
+ return output_txt, history_list, role_card
bot/skills/poem.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import requests
3
+
4
+ class Skill:
5
+ def __init__(self:object) -> None:
6
+ pass
7
+
8
+ def process(self:object, input_txt:str, history_list:list, role_card:dict):
9
+ output_text:str = None
10
+ if input_txt.startswith("写诗:") or input_txt.startswith("写诗:"):
11
+ output_text = requests.post(
12
+ url='https://hf.space/embed/lewiswu1209/gpt2-chinese-poem/+/api/predict/',
13
+ json={"data": [input_txt[3:]]}
14
+ ).json()["data"][0]
15
+ return output_text, history_list, role_card
bot/utlis.py DELETED
@@ -1,174 +0,0 @@
1
-
2
- import torch
3
- import random
4
- import torch.nn.functional as F
5
-
6
- def ranking(context_hidden, next_hidden, next_top_k_ids, next_top_k_probs, alpha):
7
- '''
8
- context_hidden: beam_width x context_len x embed_dim
9
- next_hidden: beam_width x 1 x embed_dim
10
- next_top_k_ids: beam_width x 1
11
- '''
12
- beam_width, context_len, embed_dim = context_hidden.size()
13
- assert next_hidden.size() == torch.Size([beam_width, 1, embed_dim])
14
- norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
15
- norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
16
- cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1,2)).squeeze(-1)
17
- assert cosine_matrix.size() == torch.Size([beam_width, context_len])
18
- scores, _ = torch.max(cosine_matrix, dim = -1)
19
- assert scores.size() == torch.Size([beam_width])
20
- next_top_k_probs = next_top_k_probs.view(-1)
21
- scores = (1.0 - alpha) * next_top_k_probs - alpha * scores
22
- _, selected_idx = torch.topk(scores, k = 1)
23
- assert selected_idx.size() == torch.Size([1])
24
- selected_idx = selected_idx.unsqueeze(0)
25
- assert selected_idx.size() == torch.Size([1,1])
26
- next_id = torch.gather(next_top_k_ids, dim = 0, index=selected_idx)
27
- assert next_id.size() == torch.Size([1,1])
28
- return next_id
29
-
30
- def ContrastiveDecodingOneStep(model, input_ids, beam_width, alpha):
31
- '''
32
- model: the generation model, e.g., gpt2
33
- input_ids: 1 x seqlen
34
- '''
35
- prev_hidden_states, logits = model.compute_logits_and_hidden_states(input_ids)
36
- _, seqlen, embed_dim = prev_hidden_states.size()
37
- _, _, vocab_size = logits.size()
38
- p = random.uniform(0, 1)
39
-
40
- logit_for_next_step = logits[:,-1,:]
41
- assert logit_for_next_step.size() == torch.Size([1, vocab_size])
42
-
43
- next_probs = F.softmax(logit_for_next_step, dim = -1)
44
- assert next_probs.size() == logit_for_next_step.size()
45
-
46
- _, top_k_ids = torch.topk(logit_for_next_step, dim = -1, k = beam_width)
47
- assert top_k_ids.size() == torch.Size([1, beam_width])
48
-
49
- top_k_probs = torch.gather(next_probs, dim = 1, index=top_k_ids)
50
-
51
- assert top_k_probs.size() == top_k_ids.size()
52
- # compute new hidden
53
- expanded_context = [input_ids for _ in range(beam_width)]
54
- expanded_context = torch.cat(expanded_context, dim = 0)
55
- assert expanded_context.size() == torch.Size([beam_width, seqlen])
56
- top_k_ids = top_k_ids.view(beam_width, 1)
57
- next_input_ids = torch.cat([expanded_context, top_k_ids], dim = -1)
58
- assert next_input_ids.size() == torch.Size([beam_width, seqlen+1])
59
- new_hidden_states, next_logits = model.compute_logits_and_hidden_states(next_input_ids)
60
- assert new_hidden_states.size() == torch.Size([beam_width, seqlen+1, embed_dim])
61
- context_hidden = new_hidden_states[:,:seqlen,:]
62
- assert context_hidden.size() == torch.Size([beam_width, seqlen, embed_dim])
63
- next_hidden = new_hidden_states[:,seqlen:,:]
64
- assert next_hidden.size() == torch.Size([beam_width, 1, embed_dim])
65
-
66
- next_id = ranking(context_hidden, next_hidden, top_k_ids, top_k_probs, alpha)
67
-
68
- next_input_ids = torch.cat([input_ids, next_id], dim = -1)
69
- assert next_input_ids.size() == torch.Size([1, seqlen+1])
70
- return next_input_ids
71
-
72
- # ========== batch version ========= #
73
- def ranking_fast(context_hidden, next_hidden, next_top_k_probs, alpha, beam_width):
74
- '''
75
- context_hidden: bsz*beam x seqlen x embed_dim
76
- next_hidden: bsz*beam x 1 x embed_dim
77
- next_top_k_probs: bsz x beam
78
- '''
79
- _, context_len, embed_dim = context_hidden.size()
80
- norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
81
- norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
82
- cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1,2)).squeeze(-1) # [B*K, S]
83
- scores, _ = torch.max(cosine_matrix, dim=-1) # [B*K]
84
- next_top_k_probs = next_top_k_probs.view(-1) # [B*K]
85
- scores = (1.0 - alpha) * next_top_k_probs - alpha * scores
86
- scores = torch.stack(torch.split(scores, beam_width)) # [B, K]
87
- selected_idx = scores.max(dim=-1)[1] # [B]
88
- return selected_idx
89
-
90
- def ContrastiveDecodingOneStepFast(
91
- model,
92
- ids,
93
- beam_width,
94
- alpha,
95
- past_key_values,
96
- last_hidden_states,
97
- vocab,
98
- logit_for_next_step,
99
- first_step=False,
100
- ):
101
- # input_ids: [B, S]
102
- if first_step:
103
- output = model(
104
- input_ids=ids,
105
- past_key_values=past_key_values,
106
- use_cache=True,
107
- output_hidden_states=True
108
- )
109
- past_key_values = output.past_key_values
110
- last_hidden_states = output.hidden_states[-1] # [B, S, E]
111
- logit_for_next_step = output.logits[:, -1, :] # [B, V]
112
- bsz, seqlen, embed_dim = last_hidden_states.size()
113
- p = random.uniform(0, 1)
114
-
115
- next_probs = F.softmax(logit_for_next_step, dim=-1)
116
- _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=beam_width) # [B, K]
117
- top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids) # [B, K]
118
- # compute new hidden
119
- past_key_values = enlarge_past_key_values(past_key_values, beam_width)
120
- output = model(
121
- input_ids=top_k_ids.view(-1, 1),
122
- attention_mask=torch.ones_like(top_k_ids.view(-1, 1)),
123
- past_key_values=past_key_values,
124
- output_hidden_states=True,
125
- use_cache=True,
126
- )
127
- past_key_values = output.past_key_values
128
- logits = output.logits[:, -1, :] # [B*K, V]
129
- next_hidden = output.hidden_states[-1] # [B*K, 1, E]
130
- context_hidden = last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz*beam_width, seqlen, embed_dim) # [B*K, S, E]
131
-
132
- selected_idx = ranking_fast(
133
- context_hidden,
134
- next_hidden,
135
- top_k_probs, # [B, K]
136
- alpha,
137
- beam_width,
138
- ) # [B]
139
- # prepare for the next step
140
- next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1) # [B, 1]
141
- next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width)) # [B, K, E]
142
- next_hidden = next_hidden[range(bsz), selected_idx, :] # [B, E]
143
- last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1) # [B, S, E]
144
- past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx)
145
- logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :] # [B, V]
146
- # next_id: [B, 1]
147
- return next_id, past_key_values, last_hidden_states, logits
148
-
149
- def enlarge_past_key_values(past_key_values, beam_width):
150
- # from [B, num_head, seq_len, esz] to [B*K, num_head, seq_len, esz]
151
- new_key_values = []
152
- for layer in past_key_values:
153
- items = []
154
- for item in layer:
155
- # item is the key and value matrix
156
- bsz, num_head, seq_len, esz = item.size()
157
- item = item.unsqueeze(1).expand(-1, beam_width, -1, -1, -1).reshape(bsz*beam_width, num_head, seq_len, esz) # [bsz*beam, num_head, seq_len, esz]
158
- items.append(item)
159
- new_key_values.append(items)
160
- return new_key_values
161
-
162
- def select_past_key_values(past_key_values, beam_width, selected_idx):
163
- '''select_idx: [B]'''
164
- new_key_values = []
165
- for layer in past_key_values:
166
- items = []
167
- for item in layer:
168
- bsz_and_beam, num_head, seq_len, esz = item.size()
169
- bsz = int(bsz_and_beam//beam_width)
170
- item = torch.stack(torch.split(item, beam_width, dim=0)) # [B, K, num_head, seq_len, esz]
171
- item = item[range(bsz), selected_idx, :, :, :] # [B, num_head, seq_len, esz]
172
- items.append(item)
173
- new_key_values.append(items)
174
- return new_key_values
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data/.gitkeep ADDED
File without changes
data_parallel.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.nn.parallel import DataParallel
2
+ import torch
3
+ from torch.nn.parallel._functions import Scatter
4
+ from torch.nn.parallel.parallel_apply import parallel_apply
5
+
6
+
7
+ def scatter(inputs, target_gpus, chunk_sizes, dim=0):
8
+ r"""
9
+ Slices tensors into approximately equal chunks and
10
+ distributes them across given GPUs. Duplicates
11
+ references to objects that are not tensors.
12
+ """
13
+ def scatter_map(obj):
14
+ if isinstance(obj, torch.Tensor):
15
+ try:
16
+ return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
17
+ except:
18
+ print('obj', obj.size())
19
+ print('dim', dim)
20
+ print('chunk_sizes', chunk_sizes)
21
+ quit()
22
+ if isinstance(obj, tuple) and len(obj) > 0:
23
+ return list(zip(*map(scatter_map, obj)))
24
+ if isinstance(obj, list) and len(obj) > 0:
25
+ return list(map(list, zip(*map(scatter_map, obj))))
26
+ if isinstance(obj, dict) and len(obj) > 0:
27
+ return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
28
+ return [obj for targets in target_gpus]
29
+
30
+ # After scatter_map is called, a scatter_map cell will exist. This cell
31
+ # has a reference to the actual function scatter_map, which has references
32
+ # to a closure that has a reference to the scatter_map cell (because the
33
+ # fn is recursive). To avoid this reference cycle, we set the function to
34
+ # None, clearing the cell
35
+ try:
36
+ return scatter_map(inputs)
37
+ finally:
38
+ scatter_map = None
39
+
40
+
41
+ def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
42
+ r"""Scatter with support for kwargs dictionary"""
43
+ inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
44
+ kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
45
+ if len(inputs) < len(kwargs):
46
+ inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
47
+ elif len(kwargs) < len(inputs):
48
+ kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
49
+ inputs = tuple(inputs)
50
+ kwargs = tuple(kwargs)
51
+ return inputs, kwargs
52
+
53
+
54
+ class BalancedDataParallel(DataParallel):
55
+ def __init__(self, gpu0_bsz, *args, **kwargs):
56
+ self.gpu0_bsz = gpu0_bsz
57
+ super().__init__(*args, **kwargs)
58
+
59
+ def forward(self, *inputs, **kwargs):
60
+ if not self.device_ids:
61
+ return self.module(*inputs, **kwargs)
62
+ if self.gpu0_bsz == 0:
63
+ device_ids = self.device_ids[1:]
64
+ else:
65
+ device_ids = self.device_ids
66
+ inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
67
+ # print('len(inputs)1: ', str(len(inputs)))
68
+ # print('self.device_ids[:len(inputs)]', str(self.device_ids[:len(inputs)]))
69
+ if len(self.device_ids) == 1:
70
+ return self.module(*inputs[0], **kwargs[0])
71
+ replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
72
+ if self.gpu0_bsz == 0:
73
+ replicas = replicas[1:]
74
+ outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
75
+ return self.gather(outputs, self.output_device)
76
+
77
+ def parallel_apply(self, replicas, device_ids, inputs, kwargs):
78
+ return parallel_apply(replicas, inputs, kwargs, device_ids[:len(inputs)])
79
+
80
+ def scatter(self, inputs, kwargs, device_ids):
81
+ bsz = inputs[0].size(self.dim)
82
+ num_dev = len(self.device_ids)
83
+ gpu0_bsz = self.gpu0_bsz
84
+ bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
85
+ if gpu0_bsz < bsz_unit:
86
+ chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
87
+ delta = bsz - sum(chunk_sizes)
88
+ for i in range(delta):
89
+ chunk_sizes[i + 1] += 1
90
+ if gpu0_bsz == 0:
91
+ chunk_sizes = chunk_sizes[1:]
92
+ else:
93
+ return super().scatter(inputs, kwargs, device_ids)
94
+
95
+ # print('bsz: ', bsz)
96
+ # print('num_dev: ', num_dev)
97
+ # print('gpu0_bsz: ', gpu0_bsz)
98
+ # print('bsz_unit: ', bsz_unit)
99
+ # print('chunk_sizes: ', chunk_sizes)
100
+ return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)
dataset.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ import torch
3
+
4
+
5
+ class MyDataset(Dataset):
6
+ """
7
+
8
+ """
9
+
10
+ def __init__(self, input_list, max_len):
11
+ self.input_list = input_list
12
+ self.max_len = max_len
13
+
14
+ def __getitem__(self, index):
15
+ input_ids = self.input_list[index]
16
+ input_ids = input_ids[:self.max_len]
17
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
18
+ return input_ids
19
+
20
+ def __len__(self):
21
+ return len(self.input_list)
preprocess.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tokenizers import BertWordPieceTokenizer
2
+ from transformers import BertTokenizer
3
+ from transformers import BertTokenizerFast
4
+ import argparse
5
+ import pandas as pd
6
+ import pickle
7
+ import jieba.analyse
8
+ from tqdm import tqdm
9
+ from transformers import GPT2TokenizerFast, GPT2LMHeadModel
10
+ import logging
11
+ import numpy as np
12
+ from chatbot.config import config
13
+
14
+
15
+ def create_logger(log_path):
16
+ """
17
+ 将日志输出到日志文件和控制台
18
+ """
19
+ logger = logging.getLogger(__name__)
20
+ logger.setLevel(logging.INFO)
21
+
22
+ formatter = logging.Formatter(
23
+ '%(asctime)s - %(levelname)s - %(message)s')
24
+
25
+ # 创建一个handler,用于写入日志文件
26
+ file_handler = logging.FileHandler(
27
+ filename=log_path)
28
+ file_handler.setFormatter(formatter)
29
+ file_handler.setLevel(logging.INFO)
30
+ logger.addHandler(file_handler)
31
+
32
+ # 创建一个handler,用于将日志输出到控制台
33
+ console = logging.StreamHandler()
34
+ console.setLevel(logging.DEBUG)
35
+ console.setFormatter(formatter)
36
+ logger.addHandler(console)
37
+
38
+ return logger
39
+
40
+
41
+ def preprocess():
42
+ """
43
+ 对原始语料进行tokenize,将每段对话处理成如下形式:"[CLS]utterance1[SEP]utterance2[SEP]utterance3[SEP]"
44
+ """
45
+ # 设置参数
46
+ parser = argparse.ArgumentParser()
47
+ parser.add_argument('--vocab_path', default='vocab/vocab.txt', type=str, required=False,
48
+ help='词表路径')
49
+ parser.add_argument('--log_path', default='data/preprocess.log', type=str, required=False, help='训练日志存放位置')
50
+ parser.add_argument('--train_path', default='data/train.txt', type=str, required=False, help='训练日志存放位置')
51
+ parser.add_argument('--save_path', default='data/train.pkl', type=str, required=False, help='tokenize的训练数据集')
52
+ args = parser.parse_args()
53
+
54
+ # 初始化日志对象
55
+ logger = create_logger(args.log_path)
56
+
57
+ # 初始化tokenizer
58
+ tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
59
+ special_tokens = []
60
+ for key in config["mask_token"].keys():
61
+ special_tokens.append(key)
62
+ tokenizer.add_special_tokens( {'additional_special_tokens':special_tokens} )
63
+ sep_id = tokenizer.sep_token_id
64
+ cls_id = tokenizer.cls_token_id
65
+ logger.info("preprocessing data,data path:{}, save path:{}".format(args.train_path, args.save_path))
66
+
67
+ # 读取训练数据集
68
+ with open(args.train_path, 'rb') as f:
69
+ data = f.read().decode("utf-8")
70
+
71
+ # 需要区分linux和windows环境下的换行符
72
+ if "\r\n" in data:
73
+ train_data = data.split("\r\n\r\n")
74
+ else:
75
+ train_data = data.split("\n\n")
76
+ logger.info("there are {} dialogue in dataset".format(len(train_data)))
77
+
78
+ # 开始进行tokenize
79
+ # 保存所有的对话数据,每条数据的格式为:"[CLS]utterance1[SEP]utterance2[SEP]utterance3[SEP]"
80
+ dialogue_len = [] # 记录所有对话tokenize之后的长度,用于统计中位数与均值
81
+ dialogue_list = []
82
+ with open(args.save_path, "w", encoding="utf-8") as f:
83
+ for index, dialogue in enumerate(tqdm(train_data)):
84
+ if "\r\n" in data:
85
+ utterances = dialogue.split("\r\n")
86
+ else:
87
+ utterances = dialogue.split("\n")
88
+
89
+ input_ids = [cls_id] # 每个dialogue以[CLS]开头
90
+ for utterance in utterances:
91
+ input_ids += tokenizer.encode(utterance, add_special_tokens=False)
92
+ input_ids.append(sep_id) # 每个utterance之后添加[SEP],表示utterance结束
93
+ dialogue_len.append(len(input_ids))
94
+ dialogue_list.append(input_ids)
95
+ len_mean = np.mean(dialogue_len)
96
+ len_median = np.median(dialogue_len)
97
+ len_max = np.max(dialogue_len)
98
+ with open(args.save_path, "wb") as f:
99
+ pickle.dump(dialogue_list, f)
100
+ logger.info("finish preprocessing data,the result is stored in {}".format(args.save_path))
101
+ logger.info("mean of dialogue len:{},median of dialogue len:{},max len:{}".format(len_mean, len_median, len_max))
102
+
103
+
104
+ if __name__ == '__main__':
105
+ preprocess()
pytorchtools.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from os.path import join
4
+ import os
5
+
6
+ class EarlyStopping:
7
+ """Early stops the training if validation loss doesn't improve after a given patience."""
8
+ def __init__(self, patience=7, verbose=False, delta=0, save_path="."):
9
+ """
10
+ Args:
11
+ patience (int): How long to wait after last time validation loss improved.
12
+ Default: 7
13
+ verbose (bool): If True, prints a message for each validation loss improvement.
14
+ Default: False
15
+ delta (float): Minimum change in the monitored quantity to qualify as an improvement.
16
+ Default: 0
17
+ """
18
+ self.patience = patience
19
+ self.verbose = verbose
20
+ self.counter = 0
21
+ self.best_score = None
22
+ self.early_stop = False
23
+ self.val_loss_min = np.Inf
24
+ self.delta = delta
25
+ self.save_path = save_path
26
+
27
+ def __call__(self, val_loss, model):
28
+
29
+ score = -val_loss
30
+
31
+ if self.best_score is None:
32
+ self.best_score = score
33
+ self.save_checkpoint(val_loss, model)
34
+ elif score < self.best_score + self.delta:
35
+ self.counter += 1
36
+ print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
37
+ if self.counter >= self.patience:
38
+ self.early_stop = True
39
+ else:
40
+ self.best_score = score
41
+ self.save_checkpoint(val_loss, model)
42
+ self.counter = 0
43
+
44
+ def save_checkpoint(self, val_loss, model):
45
+ '''Saves model when validation loss decrease.'''
46
+ if self.verbose:
47
+ print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
48
+ # save_path = join(self.save_path, "best_model")
49
+ # if not os.path.exists(save_path):
50
+ # os.mkdir(save_path)
51
+ # model_to_save = model.module if hasattr(model, 'module') else model
52
+ # model_to_save.save_pretrained(save_path)
53
+ self.val_loss_min = val_loss
requirements.txt CHANGED
@@ -1,18 +1,2 @@
1
- absl-py
2
- pytest
3
- sacrebleu==1.4.10
4
- six
5
- wheel
6
- progressbar
7
- sklearn
8
- torch==1.6.0
9
- torchvision==0.7.0
10
- transformers==4.7.0
11
- pyyaml
12
- nltk
13
- sentencepiece
14
- spacy
15
- gdown
16
- seaborn
17
- matplotlib
18
- pandas
 
1
+ torch
2
+ transformers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
templates/chat_template.html ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <html lang="zh">
2
+ <head>
3
+ <meta http-equiv="Content-Type" content="text/html; charset=UTF-8">
4
+ <title>聊天机器人</title>
5
+ <style>
6
+ body {
7
+ padding:0;
8
+ margin:0;
9
+ background:-moz-linear-gradient(-45deg,#183850 0,#183850 25%,#192C46 50%,#22254C 75%,#22254C 100%);
10
+ background:-webkit-linear-gradient(-45deg,#183850 0,#183850 25%,#192C46 50%,#22254C 75%,#22254C 100%);
11
+ background-repeat:no-repeat;
12
+ background-attachment:fixed
13
+ }
14
+ ::-webkit-scrollbar {
15
+ width:10px
16
+ }
17
+ ::-webkit-scrollbar-track {
18
+ border-radius:10px;
19
+ background-color:rgba(25,147,147,0.1)
20
+ }
21
+ ::-webkit-scrollbar-thumb {
22
+ border-radius:10px;
23
+ background-color:rgba(25,147,147,0.2)
24
+ }
25
+ .chat-thread {
26
+ margin:24px auto 0 auto;
27
+ padding:0 20px 0 0;
28
+ list-style:none;
29
+ overflow-y:scroll;
30
+ overflow-x:hidden
31
+ }
32
+ .chat-thread li {
33
+ position:relative;
34
+ clear:both;
35
+ display:inline-block;
36
+ padding:16px 40px 16px 20px;
37
+ margin:0 0 20px 0;
38
+ font:16px/20px "Noto Sans",sans-serif;
39
+ border-radius:10px;
40
+ background-color:rgba(25,147,147,0.2)
41
+ }
42
+ .chat-thread li:before {
43
+ position:absolute;
44
+ top:0;
45
+ width:50px;
46
+ height:50px;
47
+ border-radius:50px;
48
+ content:""
49
+ }
50
+ .chat-thread li:after {
51
+ position:absolute;
52
+ top:15px;
53
+ content:"";
54
+ width:0;
55
+ height:0;
56
+ border-top:15px solid rgba(25,147,147,0.2)
57
+ }
58
+ .chat-thread li:nth-child(odd) {
59
+ animation:show-chat-odd .15s 1 ease-in;
60
+ -moz-animation:show-chat-odd .15s 1 ease-in;
61
+ -webkit-animation:show-chat-odd .15s 1 ease-in;
62
+ float:right;
63
+ margin-right:80px;
64
+ color:#0AD5C1
65
+ }
66
+ .chat-thread li:nth-child(odd):before {
67
+ right:-80px;
68
+ background-image:url(data:image/jpeg;base64,/9j/4AAQSkZJRgABAQEAYABgAAD/4QAiRXhpZgAATU0AKgAAAAgAAQESAAMAAAABAAEAAAAAAAD/2wBDAAIBAQIBAQICAgICAgICAwUDAwMDAwYEBAMFBwYHBwcGBwcICQsJCAgKCAcHCg0KCgsMDAwMBwkODw0MDgsMDAz/2wBDAQICAgMDAwYDAwYMCAcIDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAz/wAARCAAwADADASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD9WfjH8RY/gz8G/FXi6SxuNUj8K6Pd6u9pbusctyIIWlKBm+VcheWbhRkngV/Ip+1z+1Rrv7ZH7RHij4i+I5IbO98Uag96LeKRpIrRGPyRqxwz7F2qGI6IAAAAo/rc/aa8ITfED9mH4keH7eSSG413wnq2nQyIwVo3mspo1YEggEFgQSCBX8o//BPf9k9P2s/jCuj3kzW1jHD58pUcsuR/PNVicRCjTdWeyNcJh516qpQ3Z4/B4wOht5Kw2l1GpKkqpQsPUHP69ah8Q6omtx281v50jIpMxcfNx/e+nY85Hev248D/APBAf4QeJNLjElncblKecwlkEijPODuwc+4r2+5/4IhfAPwZ4Cls7HwkvmX0IjaWR2eSNQQSQzE9eM4xXz8uJqDhzwi393+Z9J/qzXUuSc1+P+R+RX/BGX9rXVv2VP26fhnq0OuSafoOvapB4b8SLI5e3nsLuZY2Eirk/u3MUqnGQyDoM5/qKv7fAYHduXIr+XH9sH9jS7+CH/BQ7Q/h38OzfC417VtMtdFB/eNHPPcIkZTPQiQqeoHBr+pbWGzcTN+7ZixyyDap57DnA9BX0GDxEa1JVobSVz5zG0XRqulPeOhx/wC0s3iOf9nHxwPB+p3Gj+Ko9EuptKvYDtlguEjLoVPZiVwCMYJ6jqPwV/Yw+F+rfCU/FvxZ4V023/tJdaaCxN5ZStDHbeTHct+5g+fav2jACf3V9MV/QxC2VwQrA9mGQR6EV+fPxM/Z/tv2GPja81q1rd+GPG2pXeo2luEZTZoRCrQtuJyU+QAg4IxwOg8XPqdRUnUjqrWa6bpr9Ue/w5Uoyl7Cekua8X11TTV/ua+Z5b+zN/wUq+JkHiqz0jXvhTZtpNxcJpsmpWcWoW/7w+V8wjvIFBA82M5SQn5wBnBx137Sf/BQ/wCNHhj4qt4L8O/De1h0vfLBLq95pV3qUilVclhHCEjjXCOQ0suMqRjOAfV9YbwX4An8OXka+H9Bs9V1CDz7maSO2RlEilYkLY+ZmIIUdcH059Q8W6x4H+Ivi7XNWsX8PeIotNvZBBeWzR3Qhl3F3hYjO11ypK5yAwPevj1Wgpe0Ufd2tf8Ar8j7SWGly+zbbff57f079fI+J/AnwN8SePf+Ci37OPjfVGSx8RW9hqF1PNFYvYCVYrixXa0UpMke+0vL0AH5g+CAo6fqzqH+rYjjivF/2Z/Ctp8SfENx48uods2l3E2l6dFsG1V2Rs8mevJOMdMoD1Ax7JqDARt9K+34fo1I4ZSn12Xlv+rPgOIq1OWI5KfTd+e2/XRL8Sp4l8U6X4E8O3Wsa5qem6Lo+nrvur/ULqO1tbZfWSWQhEHuxFflH/wVU/4K7eA/iR+0X8Nfh78O9S0HxpodrbX1/rHiDT5TcRxXLKRHaQSLhGwtu8khG5W3whWBRgfyz/am/a8+JH7VOrW99468b+KvFqwP5kS6tfl4LUn/AJ4W64ggz38tQT615J4b1a+k+L1vdW5bzNLX7TFAnGUSMb1x/wBcw/4V7uZYJLCzUne6PHy3ESjioSj0aP2U1n40eIPiT4A023tdH8H+ItBkKtJb65bXlxGMfxbbYFmBGRxyPQ549x8CfF/XvDXwW8zWLXwlovh+0jCWVnpGnXlrsXAUbjckFvTPloWJHAr86vgn+1F48+A0EM3hO4WazviJI7e4G9Iy2CWUggr26HFWP25f2/8Axlr/AIANn4g1yO81O8zst7VRHHBJtIUYHzFgCSSTwucckZ/L44KpUaoU0tX8z9YlmFKnSdSa21uffv8AwR3/AOCr+l/Fzx18Ufhz401rQ9Hh07xNczeDLq4uEtYLmzUxW72gd8K0hlQzLli0hnkA4VVr9FtSDLvVlZWUcgjBFfyGaTdtb+DZLKQS3PmKwcSlW8yR+OmSPvkH2r7S/Ym/4LS/Fz9jfwvZeHG1mDxh4VsU2QaZr4e4S2jHGyGVSJYQvZFbyx12Hqf1qhlvJSjGD2SX3H4/iMX7SrKo18Tb+8//2Q==)
69
+ }
70
+ .chat-thread li:nth-child(odd):after {
71
+ border-right:15px solid transparent;
72
+ right:-15px
73
+ }
74
+ .chat-thread li:nth-child(even) {
75
+ animation:show-chat-even .15s 1 ease-in;
76
+ -moz-animation:show-chat-even .15s 1 ease-in;
77
+ -webkit-animation:show-chat-even .15s 1 ease-in;
78
+ float:left;
79
+ margin-left:80px;
80
+ color:#0EC879
81
+ }
82
+ .chat-thread li:nth-child(even):before {
83
+ left:-80px;
84
+ background-image:url(data:image/jpeg;base64,/9j/4AAQSkZJRgABAQEAYABgAAD/4QAiRXhpZgAATU0AKgAAAAgAAQESAAMAAAABAAEAAAAAAAD/2wBDAAIBAQIBAQICAgICAgICAwUDAwMDAwYEBAMFBwYHBwcGBwcICQsJCAgKCAcHCg0KCgsMDAwMBwkODw0MDgsMDAz/2wBDAQICAgMDAwYDAwYMCAcIDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAz/wAARCAAwADADASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD9aDlnwuSW6e1fiN/wcL/8FGNW+K/xsm+A/g/VZ7Pwn4PZR4ke3nKjWNQYbvKlI5MUK4+T+JzzwBX7SeOfGtv8M/AeveJrxo47Xw7plxqUrynCqIonk59iVA/Gv5Pzreq/G/4hXWrzLLda1401WXUJVLFnllupi4BPXChgPotZ4mdloehltHmlc0L/AMM3114JMOhxXVxbuB9peBA0l4wwcKoAYbSAAc7eRwvU7PxZ+FF/8Bf2g9f0f7RdWY0vU5re3ubiEtDKiNtDFSo+RuOVxxt+tfuF/wAE2f2IvDvwM+HmnRz6XY3msSIkl1czwK5LEdBuGML/APrr1z9r/wD4J1+Cf2svhRJZ6jZ20OqeU5tL2OELJAT0yQMkZPr714/1+cnotEfRSwVOLSb1fkfmT/wTA/4KL6t+xf8AE+w0zVtUvLj4Z+IriODXNIllaaHRZXIVb62VifLwSC+0gMhyRkZr9wt6OiSRyLNHIodJE+66sMhlPoRyMckdTwK/md+Ofwy179nv4n3XgvxJD5epaPKbMsVwLuBsiNwf4g3Sv2N/4Iaftmt+07+yzN4V1a6a48UfC+VNMnaRsvdWDA/ZpSepK4MZHUbQe4r0MHir+6eTmeDSXPH5/wCZof8ABe74sXvwz/4JgeP4dOby5vEyx6NNKr7TFbu4FwVH8R2Ky49OTivxg/4J4/BtdW8V3njfVpNasNB8HkSzXOkWouLyJ8Ar5SEH1Hbp1IFfrl/wcf6ZPrf/AATU1ZYGk8631SK7kZWO7yY1fzFb+8pTgk8nPtXwb/wRF/aA0X/hO7/w3NHEl5d20U95bOg8ssn7vzVGMYYYB59awzWUuT3ex0cPxg5LmP0A/YX/AGjo/Et60tr4h+JWp6d9sbSZ9L8aaZaW9w0yvEjzWlzbkpcRBpoUJBwGlVQSQceo/tfftRX/AOz98WtL0SXUPiRNDql6limm+D9Ctbnyy+8LLcTXBUQxsysN2dq7QGK7lLX/AIreMtI8J6j8PrWWS3tdP1TW7eWWST5YkEBUpk5xkeYyqAAdrNjjOfoZ9Q0vxFq9xNHHDcT6fdT2ouFwWUhyJEDDH8QIK+g75ryaWt0vI9rEbpvzPx//AOC9PwM/4Sn4I+DfjBp7a5JeRzLZzf2ppwsdQMMnzRmWIKCHVgRyDwcjjBPk/wDwQD+KzeD/APgoEtrAzJZ+P9Ams7iHcQi3EbBs49AwZuTkfKeARX3t/wAF8vij4V8JfsjGz8Sala2Md9dl7K1aJJLjVZ0UlYYwzDHJBLD7oFflj/wQxW/vP+Cl3wzt7WQx26T6hdSxDG3yTp03mbcYGSfLz3yvpXdhLp3OHGOM6fqmfuJ/wUA/Z/j/AGnP2SvHXhXy2kurrRrs2xHVX+zuAR745Hvkda/mj/Zf+OU/7LPx30PxZNBJPZ2f+i6rAv3pIX4kx6srZYfSv6lPjvDfXfw9utLsZJrdtXcWU8yY3RQyHbIRn+EpkZ49ODX8vvxE+CGteO/F3xRt/Cfh+61TTfh9eXlzq0thG08djZJcPGrFlG0IgwuSynI4zXr4qKl7rPDy/mS5o9Hc/aVfjL4u+NvhbwBqXw1l0PWtLup0uLt72zF4UUYMbLGeCpG9WHUEj3r61+A3iXx3aeDobjx1Joq7YlKfZrVrZ4lUAfvixIYgADd1OM9a/Mz/AIJRL4s8K+CvCb6HNNHJdW+yWCVN8TFF3BuowSpPQjoO/Nfon8bNfutL+HdxfeItWiVbG0aZwxEMAYLnkD+FevJ/GvjYS5JSR+hVpqeHjScVfe/XU/Gv/guj+3bZ/tp/tJr4f0GCT/hG/he8ul2M0h2i9vJCPtE/+6OFXPoTXgv/AATw/aGX9lL9sb4d+NpvMls9B1X7PqkSnEj2kwNvcbR/eEMjuP8AajUd66z9qH9jLxx4c8Bx/GCTRb+68O/EK7vNR0kWVu8lxDaRSH/TJ4wCVhkOWR8bSoznBFV/2OP+CaHjf/go1Lcat4DktdF0/Q7iO31vUdWtJ1023YjcRbSxqTPcBcFrcYYblLPGCCfo8PJygmkfJ1qfJN32P//Z)
85
+ }
86
+ .chat-thread li:nth-child(even):after {
87
+ border-left:15px solid transparent;
88
+ left:-15px
89
+ }
90
+ .chat-window {
91
+ position:fixed;
92
+ bottom:18px
93
+ }
94
+ .chat-window-message {
95
+ width:100%;
96
+ height:48px;
97
+ font:32px/48px "Noto Sans",sans-serif;
98
+ background:0;
99
+ color:#0AD5C1;
100
+ border:0;
101
+ border-bottom:1px solid rgba(25,147,147,0.2);
102
+ outline:0
103
+ }
104
+ @media all and (max-width:767px) {
105
+ .chat-thread {
106
+ width:90%;
107
+ height:90%
108
+ }
109
+ .chat-window {
110
+ left:5%;
111
+ width:90%
112
+ }
113
+ }
114
+ @media all and (min-width:768px) {
115
+ .chat-thread {
116
+ width:50%;
117
+ height:90%
118
+ }
119
+ .chat-window {
120
+ left:25%;
121
+ width:50%
122
+ }
123
+ }
124
+ @keyframes show-chat-even {
125
+ 0% {
126
+ margin-left:-480px
127
+ }
128
+ 100% {
129
+ margin-left:0
130
+ }
131
+ }
132
+ @-moz-keyframes show-chat-even {
133
+ 0% {
134
+ margin-left:-480px
135
+ }
136
+ 100% {
137
+ margin-left:0
138
+ }
139
+ }
140
+ @-webkit-keyframes show-chat-even {
141
+ 0% {
142
+ margin-left:-480px
143
+ }
144
+ 100% {
145
+ margin-left:0
146
+ }
147
+ }
148
+ @keyframes show-chat-odd {
149
+ 0% {
150
+ margin-right:-480px
151
+ }
152
+ 100% {
153
+ margin-right:0
154
+ }
155
+ }
156
+ @-moz-keyframes show-chat-odd {
157
+ 0% {
158
+ margin-right:-480px
159
+ }
160
+ 100% {
161
+ margin-right:0
162
+ }
163
+ }
164
+ @-webkit-keyframes show-chat-odd {
165
+ 0% {
166
+ margin-right:-480px
167
+ }
168
+ 100% {
169
+ margin-right:0
170
+ }
171
+ }
172
+ </style>
173
+ </head>
174
+ <body onload="loadhistory()">
175
+ <ul class="chat-thread">
176
+
177
+ </ul>
178
+ <div class="chat-window">
179
+ <input class="chat-window-message" name="chat-window-message" type="text" autocomplete="off" autofocus="" placeholder="对我说HELP,看看我能干什么~">
180
+ </div>
181
+ <script src="https://cdn.bootcdn.net/ajax/libs/jquery/2.1.3/jquery.min.js"></script>
182
+ <script>
183
+ var chat_window = document.querySelector(".chat-window");
184
+
185
+ chat_window.onkeydown=function(event){
186
+ var e = event || window.event || arguments.callee.caller.arguments[0];
187
+ if (e && e.keyCode == 13 ) {
188
+ send_data();
189
+ }
190
+ }
191
+
192
+ function send_data() {
193
+ var chat_thread = document.querySelector(".chat-thread");
194
+ var chat_window_message = document.querySelector(".chat-window-message");
195
+ chat_window_message.disabled = true;
196
+ var text = chat_window_message.value;
197
+ var new_li_label = document.createElement("li"), new_li_text = document.createTextNode(text);
198
+ new_li_label.appendChild(new_li_text);
199
+ chat_thread.appendChild(new_li_label);
200
+ chat_thread.scrollTop = chat_thread.scrollHeight;
201
+ chat_window_message.value = "";
202
+ document.title = "聊天机器人 ~ 对方正在输入…"
203
+ $.getJSON("/chitchat/chat?text="+text, function(data){
204
+ var new_li_label = document.createElement("li");
205
+ data.forEach(function(item){
206
+ var new_text = document.createTextNode(item);
207
+ var new_span = document.createElement("span");
208
+ new_span.appendChild(new_text);
209
+ var new_br = document.createElement("br");
210
+ new_li_label.appendChild(new_span);
211
+ new_li_label.appendChild(new_br);
212
+ });
213
+ chat_thread.appendChild(new_li_label);
214
+ chat_thread.scrollTop = chat_thread.scrollHeight;
215
+
216
+ document.title = "聊天机器人"
217
+ chat_window_message.disabled = false;
218
+ });
219
+ }
220
+
221
+ function loadhistory() {
222
+ var chat_thread = document.querySelector(".chat-thread");
223
+ var chat_window_message = document.querySelector(".chat-window-message");
224
+ chat_window_message.disabled = true;
225
+ document.title = "聊天机器人 ~ 正在回忆…"
226
+ $.getJSON("/chitchat/history", function(data){
227
+ data.forEach(function(item) {
228
+ var new_li_label = document.createElement("li"),new_li_text = document.createTextNode(item);
229
+ new_li_label.appendChild(new_li_text);
230
+ chat_thread.appendChild(new_li_label);
231
+ });
232
+ chat_thread.scrollTop = chat_thread.scrollHeight;
233
+
234
+ chat_window_message.disabled = false;
235
+ document.title = "聊天机器人"
236
+ });
237
+ }
238
+ </script>
239
+ </body>
240
+ </html>
train.py ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import math
3
+ import time
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import torch.optim as optim
7
+ import logging
8
+ from datetime import datetime
9
+ import os
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from os.path import join, exists
12
+ from torch.nn import CrossEntropyLoss
13
+ from tqdm import tqdm
14
+ from torch.nn import DataParallel
15
+ import transformers
16
+ import pickle
17
+ import sys
18
+ from pytorchtools import EarlyStopping
19
+ from sklearn.model_selection import train_test_split
20
+ from data_parallel import BalancedDataParallel
21
+ from transformers import GPT2TokenizerFast, GPT2LMHeadModel, GPT2Config
22
+ from transformers import BertTokenizerFast
23
+ import pandas as pd
24
+ import torch.nn.utils.rnn as rnn_utils
25
+ import numpy as np
26
+ from dataset import MyDataset
27
+ from chatbot.config import config
28
+
29
+
30
+ def set_args():
31
+ parser = argparse.ArgumentParser()
32
+ parser.add_argument('--device', default='3', type=str, required=False, help='设置使用哪些显卡')
33
+ parser.add_argument('--no_cuda', action='store_true', help='不使用GPU进行训练')
34
+ parser.add_argument('--vocab_path', default='vocab/vocab.txt', type=str, required=False,
35
+ help='词表路径')
36
+ parser.add_argument('--model_config', default='config/config.json', type=str, required=False,
37
+ help='设置模型参数')
38
+ parser.add_argument('--train_path', default='data/train.pkl', type=str, required=False, help='训练集路径')
39
+ parser.add_argument('--max_len', default=150, type=int, required=False, help='训练时,输入数据的最大长度')
40
+
41
+ parser.add_argument('--log_path', default='data/train.log', type=str, required=False, help='训练日志存放位置')
42
+ parser.add_argument('--log', default=True, help="是否记录日志")
43
+ parser.add_argument('--ignore_index', default=-100, type=int, required=False, help='对于ignore_index的label token不计算梯度')
44
+ # parser.add_argument('--input_len', default=200, type=int, required=False, help='输入的长度')
45
+ parser.add_argument('--epochs', default=100, type=int, required=False, help='训练的最大轮次')
46
+ parser.add_argument('--batch_size', default=4, type=int, required=False, help='训练的batch size')
47
+ parser.add_argument('--gpu0_bsz', default=10, type=int, required=False, help='0号卡的batch size')
48
+ parser.add_argument('--lr', default=2.6e-5, type=float, required=False, help='学习率')
49
+ parser.add_argument('--eps', default=1.0e-09, type=float, required=False, help='衰减率')
50
+ parser.add_argument('--log_step', default=1, type=int, required=False, help='多少步汇报一次loss')
51
+ parser.add_argument('--gradient_accumulation_steps', default=4, type=int, required=False, help='梯度积累')
52
+ parser.add_argument('--max_grad_norm', default=2.0, type=float, required=False)
53
+ parser.add_argument('--save_model_path', default='model', type=str, required=False,
54
+ help='模型输出路径')
55
+ parser.add_argument('--pretrained_model', default='', type=str, required=False,
56
+ help='预训练的模型的路径')
57
+ # parser.add_argument('--seed', type=int, default=None, help='设置种子用于生成随机数,以使得训练的结果是确定的')
58
+ parser.add_argument('--num_workers', type=int, default=0, help="dataloader加载数据时使用的线程数量")
59
+ parser.add_argument('--patience', type=int, default=0, help="用于early stopping,设为0时,不进行early stopping.early stop得到的模型的生成效果不一定会更好。")
60
+ parser.add_argument('--warmup_steps', type=int, default=4000, help='warm up步数')
61
+ # parser.add_argument('--label_smoothing', default=True, action='store_true', help='是否进行标签平滑')
62
+ parser.add_argument('--val_num', type=int, default=8000, help='验证集大小')
63
+ args = parser.parse_args()
64
+ return args
65
+
66
+
67
+ def create_logger(args):
68
+ """
69
+ 将日志输出到日志文件和控制台
70
+ """
71
+ logger = logging.getLogger(__name__)
72
+ logger.setLevel(logging.INFO)
73
+
74
+ formatter = logging.Formatter(
75
+ '%(asctime)s - %(levelname)s - %(message)s')
76
+
77
+ # 创建一个handler,用于写入日志文件
78
+ file_handler = logging.FileHandler(
79
+ filename=args.log_path)
80
+ file_handler.setFormatter(formatter)
81
+ file_handler.setLevel(logging.INFO)
82
+ logger.addHandler(file_handler)
83
+
84
+ # 创建一个handler,用于将日志输出到控制台
85
+ console = logging.StreamHandler()
86
+ console.setLevel(logging.DEBUG)
87
+ console.setFormatter(formatter)
88
+ logger.addHandler(console)
89
+
90
+ return logger
91
+
92
+
93
+ def collate_fn(batch):
94
+ input_ids = rnn_utils.pad_sequence(batch, batch_first=True, padding_value=0)
95
+ labels = rnn_utils.pad_sequence(batch, batch_first=True, padding_value=-100)
96
+ return input_ids, labels
97
+
98
+
99
+ # def padding_batch(data_list, pad_id):
100
+ # """
101
+ # 使用pad_id将data_list的每条数据,填充至data_list中最长的长度
102
+ # :param data_list:
103
+ # :param pad_id:
104
+ # :return:
105
+ # """
106
+ # # 统计data_list中的最大长度
107
+ # max_len = 0
108
+ # for data in data_list:
109
+ # max_len = max_len if max_len > len(data) else len(data)
110
+ #
111
+ # # 对数据进行padding
112
+ # new_data_list = []
113
+ # for data in data_list:
114
+ # new_data = data + [pad_id] * (max_len - len(data))
115
+ # new_data_list.append(new_data)
116
+ # return new_data_list
117
+
118
+
119
+ def load_dataset(logger, args):
120
+ """
121
+ 加载训练集和验证集
122
+ """
123
+ logger.info("loading training dataset and validating dataset")
124
+ train_path = args.train_path
125
+
126
+ with open(train_path, "rb") as f:
127
+ input_list = pickle.load(f)
128
+
129
+ # 划分训练集与验证集
130
+ val_num = args.val_num
131
+ input_list_train = input_list[val_num:]
132
+ input_list_val = input_list[:val_num]
133
+ # test
134
+ # input_list_train = input_list_train[:24]
135
+ # input_list_val = input_list_val[:24]
136
+
137
+ train_dataset = MyDataset(input_list_train, args.max_len)
138
+ val_dataset = MyDataset(input_list_val, args.max_len)
139
+
140
+ return train_dataset, val_dataset
141
+
142
+
143
+ def train_epoch(model, train_dataloader, optimizer, scheduler, logger,
144
+ epoch, args):
145
+ model.train()
146
+ device = args.device
147
+ # pad_id = args.pad_id
148
+ # sep_id = args.sep_id
149
+ ignore_index = args.ignore_index
150
+ epoch_start_time = datetime.now()
151
+ total_loss = 0 # 记录下整个epoch的loss的总和
152
+
153
+ # epoch_correct_num:每个epoch中,output预测正确的word的数量
154
+ # epoch_total_num: 每个epoch中,output预测的word的总数量
155
+ epoch_correct_num, epoch_total_num = 0, 0
156
+
157
+ for batch_idx, (input_ids, labels) in enumerate(train_dataloader):
158
+ # 捕获cuda out of memory exception
159
+ try:
160
+ input_ids = input_ids.to(device)
161
+ labels = labels.to(device)
162
+ outputs = model.forward(input_ids, labels=labels)
163
+ logits = outputs.logits
164
+ loss = outputs.loss
165
+ loss = loss.mean()
166
+
167
+ # 统计该batch的预测token的正确数与总数
168
+ batch_correct_num, batch_total_num = calculate_acc(logits, labels, ignore_index=ignore_index)
169
+ # 统计该epoch的预测token的正确数与总数
170
+ epoch_correct_num += batch_correct_num
171
+ epoch_total_num += batch_total_num
172
+ # 计算该batch的accuracy
173
+ batch_acc = batch_correct_num / batch_total_num
174
+
175
+ total_loss += loss.item()
176
+ if args.gradient_accumulation_steps > 1:
177
+ loss = loss / args.gradient_accumulation_steps
178
+
179
+ loss.backward()
180
+ # 梯度裁剪
181
+ torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
182
+
183
+ # 进行一定step的梯度累计之后,更新参数
184
+ if (batch_idx + 1) % args.gradient_accumulation_steps == 0:
185
+ # 更新参数
186
+ optimizer.step()
187
+ # 更新学习率
188
+ scheduler.step()
189
+ # 清空梯度信息
190
+ optimizer.zero_grad()
191
+
192
+ if (batch_idx + 1) % args.log_step == 0:
193
+ logger.info(
194
+ "batch {} of epoch {}, loss {}, batch_acc {}, lr {}".format(
195
+ batch_idx + 1, epoch + 1, loss.item() * args.gradient_accumulation_steps, batch_acc, scheduler.get_lr()))
196
+
197
+ del input_ids, outputs
198
+
199
+ except RuntimeError as exception:
200
+ if "out of memory" in str(exception):
201
+ logger.info("WARNING: ran out of memory")
202
+ if hasattr(torch.cuda, 'empty_cache'):
203
+ torch.cuda.empty_cache()
204
+ else:
205
+ logger.info(str(exception))
206
+ raise exception
207
+
208
+ # 记录当前epoch的平均loss与accuracy
209
+ epoch_mean_loss = total_loss / len(train_dataloader)
210
+ epoch_mean_acc = epoch_correct_num / epoch_total_num
211
+ logger.info(
212
+ "epoch {}: loss {}, predict_acc {}".format(epoch + 1, epoch_mean_loss, epoch_mean_acc))
213
+
214
+ # save model
215
+ logger.info('saving model for epoch {}'.format(epoch + 1))
216
+ model_path = join(args.save_model_path, 'epoch{}'.format(epoch + 1))
217
+ if not os.path.exists(model_path):
218
+ os.mkdir(model_path)
219
+ model_to_save = model.module if hasattr(model, 'module') else model
220
+ model_to_save.save_pretrained(model_path)
221
+ logger.info('epoch {} finished'.format(epoch + 1))
222
+ epoch_finish_time = datetime.now()
223
+ logger.info('time for one epoch: {}'.format(epoch_finish_time - epoch_start_time))
224
+
225
+ return epoch_mean_loss
226
+
227
+
228
+ def validate_epoch(model, validate_dataloader, logger, epoch, args):
229
+ logger.info("start validating")
230
+ model.eval()
231
+ device = args.device
232
+ # pad_id = args.pad_id
233
+ # sep_id = args.sep_id
234
+ ignore_index = args.ignore_index
235
+ epoch_start_time = datetime.now()
236
+ total_loss = 0
237
+ # 捕获cuda out of memory exception
238
+ try:
239
+ with torch.no_grad():
240
+ for batch_idx, (input_ids, labels) in enumerate(validate_dataloader):
241
+ input_ids = input_ids.to(device)
242
+ labels = labels.to(device)
243
+ outputs = model.forward(input_ids, labels=labels)
244
+ logits = outputs.logits
245
+ loss = outputs.loss
246
+ loss = loss.mean()
247
+
248
+ total_loss += loss.item()
249
+ del input_ids, outputs
250
+
251
+ # 记录当前epoch的平均loss
252
+ epoch_mean_loss = total_loss / len(validate_dataloader)
253
+ logger.info(
254
+ "validate epoch {}: loss {}".format(epoch+1, epoch_mean_loss))
255
+ epoch_finish_time = datetime.now()
256
+ logger.info('time for validating one epoch: {}'.format(epoch_finish_time - epoch_start_time))
257
+ return epoch_mean_loss
258
+ except RuntimeError as exception:
259
+ if "out of memory" in str(exception):
260
+ logger.info("WARNING: ran out of memory")
261
+ if hasattr(torch.cuda, 'empty_cache'):
262
+ torch.cuda.empty_cache()
263
+ else:
264
+ logger.info(str(exception))
265
+ raise exception
266
+
267
+
268
+ def train(model, logger, train_dataset, validate_dataset, args):
269
+ train_dataloader = DataLoader(
270
+ train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn,
271
+ drop_last=True
272
+ )
273
+ validate_dataloader = DataLoader(validate_dataset, batch_size=args.batch_size, shuffle=True,
274
+ num_workers=args.num_workers, collate_fn=collate_fn, drop_last=True)
275
+ early_stopping = EarlyStopping(args.patience, verbose=True, save_path=args.save_model_path)
276
+ t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.epochs
277
+ optimizer = transformers.AdamW(model.parameters(), lr=args.lr, eps=args.eps)
278
+ # scheduler = transformers.WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
279
+ scheduler = transformers.get_linear_schedule_with_warmup(
280
+ optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
281
+ )
282
+
283
+ logger.info('starting training')
284
+
285
+ # 用于记录每个epoch训练和验证的loss
286
+ train_losses, validate_losses = [], []
287
+ # 记录验证集的最小loss
288
+ best_val_loss = 10000
289
+ # 开始训练
290
+ for epoch in range(args.epochs):
291
+ # ========== train ========== #
292
+ train_loss = train_epoch(
293
+ model=model, train_dataloader=train_dataloader,
294
+ optimizer=optimizer, scheduler=scheduler,
295
+ logger=logger, epoch=epoch, args=args)
296
+ train_losses.append(train_loss)
297
+
298
+ # ========== validate ========== #
299
+ validate_loss = validate_epoch(
300
+ model=model, validate_dataloader=validate_dataloader,
301
+ logger=logger, epoch=epoch, args=args)
302
+ validate_losses.append(validate_loss)
303
+
304
+ # 保存当前困惑度最低的模型,困惑度低,模型的生成效果不一定会越好
305
+ if validate_loss < best_val_loss:
306
+ best_val_loss = validate_loss
307
+ logger.info('saving current best model for epoch {}'.format(epoch + 1))
308
+ model_path = join(args.save_model_path, 'min_ppl_model'.format(epoch + 1))
309
+ if not os.path.exists(model_path):
310
+ os.mkdir(model_path)
311
+ model_to_save = model.module if hasattr(model, 'module') else model
312
+ model_to_save.save_pretrained(model_path)
313
+
314
+ # 如果patience=0,则不进行early stopping
315
+ if args.patience == 0:
316
+ continue
317
+ early_stopping(validate_loss, model)
318
+ if early_stopping.early_stop:
319
+ logger.info("Early stopping")
320
+ break
321
+ logger.info('training finished')
322
+ logger.info("train_losses:{}".format(train_losses))
323
+ logger.info("validate_losses:{}".format(validate_losses))
324
+
325
+
326
+ def caculate_loss(logit, target, pad_idx, smoothing=True):
327
+ if smoothing:
328
+ logit = logit[..., :-1, :].contiguous().view(-1, logit.size(2))
329
+ target = target[..., 1:].contiguous().view(-1)
330
+
331
+ eps = 0.1
332
+ n_class = logit.size(-1)
333
+
334
+ one_hot = torch.zeros_like(logit).scatter(1, target.view(-1, 1), 1)
335
+ one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
336
+ log_prb = F.log_softmax(logit, dim=1)
337
+
338
+ non_pad_mask = target.ne(pad_idx)
339
+ loss = -(one_hot * log_prb).sum(dim=1)
340
+ loss = loss.masked_select(non_pad_mask).mean() # average later
341
+ else:
342
+ # loss = F.cross_entropy(predict_logit, target, ignore_index=pad_idx)
343
+ logit = logit[..., :-1, :].contiguous().view(-1, logit.size(-1))
344
+ labels = target[..., 1:].contiguous().view(-1)
345
+ loss = F.cross_entropy(logit, labels, ignore_index=pad_idx)
346
+ return loss
347
+
348
+
349
+ def calculate_acc(logit, labels, ignore_index=-100):
350
+ logit = logit[..., :-1, :].contiguous().view(-1, logit.size(-1))
351
+ labels = labels[..., 1:].contiguous().view(-1)
352
+
353
+ _, logit = logit.max(dim=-1) # 对于每条数据,返回最大的index
354
+ # 进行��运算,返回一个tensor,若labels的第i个位置为pad_id,则置为0,否则为1
355
+ non_pad_mask = labels.ne(ignore_index)
356
+ n_correct = logit.eq(labels).masked_select(non_pad_mask).sum().item()
357
+ n_word = non_pad_mask.sum().item()
358
+ return n_correct, n_word
359
+
360
+
361
+ def main():
362
+ # 初始化参数
363
+ args = set_args()
364
+
365
+ # 设置使用哪些显卡进行训练
366
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.device
367
+
368
+ args.cuda = not args.no_cuda
369
+
370
+ if args.batch_size < 2048 and args.warmup_steps <= 4000:
371
+ print('[Warning] The warmup steps may be not enough.\n' \
372
+ '(sz_b, warmup) = (2048, 4000) is the official setting.\n' \
373
+ 'Using smaller batch w/o longer warmup may cause ' \
374
+ 'the warmup stage ends with only little data trained.')
375
+
376
+ # 创建日志对象
377
+ logger = create_logger(args)
378
+ # 当用户使用GPU,并且GPU可用时
379
+ args.cuda = torch.cuda.is_available() and not args.no_cuda
380
+ device = 'cuda:0' if args.cuda else 'cpu'
381
+ args.device = device
382
+ logger.info('using device:{}'.format(device))
383
+
384
+ # 初始化tokenizer
385
+ tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
386
+ special_tokens = []
387
+ for key in config["mask_token"].keys():
388
+ special_tokens.append(key)
389
+ tokenizer.add_special_tokens( {'additional_special_tokens':special_tokens} )
390
+ args.sep_id = tokenizer.sep_token_id
391
+ args.pad_id = tokenizer.pad_token_id
392
+ args.cls_id = tokenizer.cls_token_id
393
+
394
+ # 创建模型的输出目录
395
+ if not os.path.exists(args.save_model_path):
396
+ os.mkdir(args.save_model_path)
397
+
398
+ # 创建模型
399
+ if args.pretrained_model: # 加载预训练模型
400
+ model = GPT2LMHeadModel.from_pretrained(args.pretrained_model)
401
+ else: # 初始化模型
402
+ model_config = GPT2Config.from_json_file(args.model_config)
403
+ model = GPT2LMHeadModel(config=model_config)
404
+ model = model.to(device)
405
+ logger.info('model config:\n{}'.format(model.config.to_json_string()))
406
+ assert model.config.vocab_size == tokenizer.vocab_size
407
+
408
+ # 并行训练模型
409
+ if args.cuda and torch.cuda.device_count() > 1:
410
+ model = DataParallel(model).cuda()
411
+ # model = BalancedDataParallel(args.gpu0_bsz, model, dim=0).cuda()
412
+ logger.info("use GPU {} to train".format(args.device))
413
+
414
+ # 计算模型参数数量
415
+ num_parameters = 0
416
+ parameters = model.parameters()
417
+ for parameter in parameters:
418
+ num_parameters += parameter.numel()
419
+ logger.info('number of model parameters: {}'.format(num_parameters))
420
+
421
+ # 记录参数设置
422
+ logger.info("args:{}".format(args))
423
+
424
+ # 加载训练集和验证集
425
+ # ========= Loading Dataset ========= #
426
+ train_dataset, validate_dataset = load_dataset(logger, args)
427
+
428
+ train(model, logger, train_dataset, validate_dataset, args)
429
+
430
+
431
+ if __name__ == '__main__':
432
+ main()
web.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import random
4
+ import re
5
+ import requests
6
+ import argparse
7
+ import string
8
+
9
+ from datetime import timedelta
10
+ from flask import Flask, session, request, jsonify, render_template
11
+
12
+ from transformers.models.bert.tokenization_bert import BertTokenizer
13
+
14
+ from bot.chatbot import ChatBot
15
+ from bot.config import special_token_list
16
+
17
+ app = Flask(__name__)
18
+ app.config["SECRET_KEY"] = os.urandom(74)
19
+ app.config["PERMANENT_SESSION_LIFETIME"] = timedelta(days=7)
20
+
21
+ tokenizer:BertTokenizer = None
22
+
23
+ history_matrix:dict = {}
24
+
25
+ def move_history_from_session_to_global_memory() -> None:
26
+ global history_matrix
27
+
28
+ if session.get( "session_hash") and session["history"]:
29
+ history_matrix[session["session_hash"]] = session["history"]
30
+
31
+ def move_history_from_global_memory_to_session() -> None:
32
+ global history_matrix
33
+
34
+ if session.get( "session_hash"):
35
+ session["history"] = history_matrix.get( session.get( "session_hash") )
36
+
37
+ def set_args() -> argparse.Namespace:
38
+ parser:argparse.ArgumentParser = argparse.ArgumentParser()
39
+ parser.add_argument("--vocab_path", default=None, type=str, required=False, help="选择词库")
40
+ parser.add_argument("--model_path", default="lewiswu1209/Winnie", type=str, required=False, help="对话模型路径")
41
+
42
+ return parser.parse_args()
43
+
44
+ @app.route("/chitchat/history", methods = ["GET"])
45
+ def get_history_list() -> str:
46
+ global tokenizer
47
+
48
+ move_history_from_global_memory_to_session()
49
+
50
+ history_list:list = session.get("history")
51
+ if history_list is None:
52
+ history_list = []
53
+
54
+ history:list = []
55
+ for history_ids in history_list:
56
+ tokens = tokenizer.convert_ids_to_tokens(history_ids)
57
+ fixed_tokens = []
58
+ for token in tokens:
59
+ if token.startswith("##"):
60
+ token = token[2:]
61
+ fixed_tokens.append(token)
62
+ history.append( "".join( fixed_tokens ) )
63
+
64
+ return jsonify(history)
65
+
66
+ @app.route("/chitchat/chat", methods = ["GET"])
67
+ def talk() -> str:
68
+ global tokenizer
69
+ global history_matrix
70
+
71
+ if request.args.get("hash"):
72
+ session["session_hash"] = request.args.get("hash")
73
+ move_history_from_global_memory_to_session()
74
+
75
+ if session.get("session_hash") is None:
76
+ session["session_hash"] = "".join( random.sample(string.ascii_lowercase + string.digits, 11) )
77
+
78
+ if request.args.get("text"):
79
+ input_text = request.args.get("text")
80
+ history_list = session.get("history")
81
+
82
+ if input_text.upper()=="HELP":
83
+ help_info_list = ["输入任意文字,Winnie会回答你的问题",
84
+ "输入ERASE MEMORY,Winnie会清空记忆",
85
+ "输入\"<TAG>=<VALUE>\",Winnie会记录你的角色信息",
86
+ "例如:<NAME>=Vicky,Winnie会修改自己的名字",
87
+ "可以修改的角色信息有:",
88
+ "<NAME>, <GENDER>, <YEAROFBIRTH>, <MONTHOFBIRTH>, <DAYOFBIRTH>, <ZODIAC>, <AGE>",
89
+ "输入“上联:XXXXXXX”,Winnie会和你对对联",
90
+ "输入“写诗:XXXXXXX”,Winnie会以XXXXXXX为开头写诗"
91
+ ]
92
+ return jsonify(help_info_list)
93
+
94
+ if history_list is None or len(history_list)==0 or input_text == "ERASE MEMORY":
95
+ history_list = []
96
+ output_text = requests.post(
97
+ url='https://hf.space/embed/lewiswu1209/Winnie/+/api/predict/',
98
+ json={"data": ["ERASE MEMORY"], "session_hash": session["session_hash"]}
99
+ ).json()["data"][0]
100
+
101
+ if input_text != "ERASE MEMORY":
102
+ if not re.match( r"^<.+>=.+$", input_text ):
103
+ history_list.append( tokenizer.encode(input_text, add_special_tokens=False) )
104
+ output_text = requests.post(
105
+ url='https://hf.space/embed/lewiswu1209/Winnie/+/api/predict/',
106
+ json={"data": [input_text], "session_hash": session["session_hash"]}
107
+ ).json()["data"][0]
108
+ if not re.match( r"^<.+>=.+$", input_text ):
109
+ history_list.append( tokenizer.encode(output_text, add_special_tokens=False) )
110
+
111
+ session["history"] = history_list
112
+ history_matrix[session["session_hash"]] = history_list
113
+ return jsonify([output_text])
114
+ else:
115
+ return jsonify([""])
116
+
117
+ @app.route("/")
118
+ def index() -> str:
119
+ return "Hello world!"
120
+
121
+ @app.route("/chitchat/hash", methods = ["GET"])
122
+ def get_hash() -> str:
123
+ global history_matrix
124
+
125
+ if request.args.get("hash"):
126
+ session["session_hash"] = request.args.get("hash")
127
+ move_history_from_global_memory_to_session()
128
+ hash = session.get("session_hash")
129
+ if hash:
130
+ return session.get("session_hash")
131
+ else:
132
+ return " "
133
+
134
+ @app.route( "/chitchat", methods = ["GET"] )
135
+ def chitchat() -> str:
136
+ return render_template( "chat_template.html" )
137
+
138
+ def main() -> None:
139
+ global tokenizer
140
+
141
+ args = set_args()
142
+ tokenizer = ChatBot.get_tokenizer(
143
+ args.model_path,
144
+ vocab_path=args.vocab_path,
145
+ special_token_list = special_token_list
146
+ )
147
+
148
+ app.run( host = "127.0.0.1", port = 8080 )
149
+
150
+ if __name__ == "__main__":
151
+ main()