Upload interact_mmi.py

#9
by chenmingxuan - opened
Files changed (1)
  1. interact_mmi.py +238 -0
interact_mmi.py ADDED
@@ -0,0 +1,238 @@
+ import torch
+ import torch.nn.functional as F
+ import os
+ import argparse
+ import logging
+ import copy
+ from datetime import datetime
+ # on transformers >= 4.0 the old transformers.modeling_gpt2 path is gone;
+ # GPT2LMHeadModel is importable from the top-level package on all versions
+ from transformers import BertTokenizer, GPT2LMHeadModel
+
+ PAD = '[PAD]'
+ pad_id = 0
+
+
+ def set_interact_args():
+     """
+     Sets up the arguments for interactive generation.
+     """
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--device', default='1', type=str, required=False, help='device used for generation')
+     parser.add_argument('--temperature', default=1, type=float, required=False, help='generation temperature')
+     parser.add_argument('--topk', default=8, type=int, required=False, help='sample only from the k highest-probability tokens')
+     parser.add_argument('--topp', default=0, type=float, required=False, help='cumulative-probability threshold for nucleus sampling')
+     parser.add_argument('--model_config', default='config/model_config_dialogue_small.json', type=str, required=False,
+                         help='model config file')
+     parser.add_argument('--log_path', default='data/interacting_mmi.log', type=str, required=False,
+                         help='where the interact_mmi log is stored')
+     parser.add_argument('--voca_path', default='vocabulary/vocab_small.txt', type=str, required=False, help='vocabulary file')
+     parser.add_argument('--dialogue_model_path', default='dialogue_model/', type=str, required=False,
+                         help='path to the dialogue model')
+     parser.add_argument('--mmi_model_path', default='mmi_model/', type=str, required=False,
+                         help='path to the mutual-information (MMI) model')
+     parser.add_argument('--save_samples_path', default="sample/", type=str, required=False, help="directory in which chat transcripts are saved")
+     parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False,
+                         help="repetition penalty; increase it if generated replies are too repetitive")
+     parser.add_argument('--seed', type=int, default=None, help='random seed, for reproducible results')
+     parser.add_argument('--max_len', type=int, default=25, help='maximum length of each utterance; longer ones are truncated')
+     parser.add_argument('--max_history_len', type=int, default=5, help="maximum number of utterances kept as dialogue history")
+     parser.add_argument('--no_cuda', action='store_true', help='do not use the GPU for inference')
+     parser.add_argument('--batch_size', type=int, default=5, help='number of responses generated per turn, then filtered by the MMI model')
+     parser.add_argument('--debug', action='store_true', help='print all generated candidate responses and their losses')
+     return parser.parse_args()
+
+
+ def create_logger(args):
+     """
+     Writes log output both to a log file and to the console.
+     """
+     logger = logging.getLogger(__name__)
+     logger.setLevel(logging.INFO)
+
+     formatter = logging.Formatter(
+         '%(asctime)s - %(levelname)s - %(message)s')
+
+     # handler that writes to the log file
+     file_handler = logging.FileHandler(
+         filename=args.log_path)
+     file_handler.setFormatter(formatter)
+     file_handler.setLevel(logging.INFO)
+     logger.addHandler(file_handler)
+
+     # handler that writes to the console
+     console = logging.StreamHandler()
+     console.setLevel(logging.DEBUG)
+     console.setFormatter(formatter)
+     logger.addHandler(console)
+
+     return logger
+
+
+ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
+     """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+         Args:
+             logits: logits distribution of shape (batch size, vocabulary size)
+             top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+             top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+                 Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+     """
+     assert logits.dim() == 2
+     top_k = min(top_k, logits[0].size(-1))  # safety check
+     if top_k > 0:
+         # Remove all tokens with a probability less than the last token of the top-k.
+         # torch.topk() returns the top_k largest elements of the last dimension as
+         # (values, indices); `...` lets PyTorch infer the remaining dimensions.
+         for logit in logits:
+             indices_to_remove = logit < torch.topk(logit, top_k)[0][..., -1, None]
+             logit[indices_to_remove] = filter_value  # set logits outside the top-k to -inf
+
+     if top_p > 0.0:
+         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)  # sort logits in descending order
+         cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+         # Remove tokens with cumulative probability above the threshold
+         sorted_indices_to_remove = cumulative_probs > top_p
+         # Shift the indices to the right to keep also the first token above the threshold
+         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+         sorted_indices_to_remove[..., 0] = 0
+         for index, logit in enumerate(logits):
+             indices_to_remove = sorted_indices[index][sorted_indices_to_remove[index]]
+             logit[indices_to_remove] = filter_value
+     return logits
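+
+ # Minimal usage sketch (illustrative only; not called anywhere in this script):
+ # keep the 3 most likely tokens in each row of a (batch_size=2, vocab_size=10)
+ # logits tensor, then sample one token id per row from the filtered distribution.
+ #
+ #   logits = torch.randn(2, 10)
+ #   filtered = top_k_top_p_filtering(logits, top_k=3)
+ #   next_ids = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)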
+
+
+ def main():
+     args = set_interact_args()
+     logger = create_logger(args)
+     # CUDA_VISIBLE_DEVICES must be set before the first CUDA query,
+     # otherwise it has no effect on which GPU torch selects
+     os.environ["CUDA_VISIBLE_DEVICES"] = args.device
+     # use the GPU only when the user asks for it and one is available
+     args.cuda = torch.cuda.is_available() and not args.no_cuda
+     device = 'cuda' if args.cuda else 'cpu'
+     logger.info('using device:{}'.format(device))
+
+     tokenizer = BertTokenizer(vocab_file=args.voca_path)
+     # dialogue model
+     dialogue_model = GPT2LMHeadModel.from_pretrained(args.dialogue_model_path)
+     dialogue_model.to(device)
+     dialogue_model.eval()
+     # mutual-information (MMI) model
+     mmi_model = GPT2LMHeadModel.from_pretrained(args.mmi_model_path)
+     mmi_model.to(device)
+     mmi_model.eval()
+     if args.save_samples_path:
+         if not os.path.exists(args.save_samples_path):
+             os.makedirs(args.save_samples_path)
+         samples_file = open(args.save_samples_path + '/mmi_samples.txt', 'a', encoding='utf8')
+         samples_file.write("chat transcript {}:\n".format(datetime.now()))
+     # chat history; each utterance is stored as a list of token ids
+     history = []
+     print('Start chatting with the chatbot; press Ctrl+C to exit')
+
+     while True:
+         try:
+             text = input("user:")
+             if args.save_samples_path:
+                 samples_file.write("user:{}\n".format(text))
+             # encode without special tokens; [CLS]/[SEP] are added by hand below
+             history.append(tokenizer.encode(text, add_special_tokens=False))
+             input_ids = [tokenizer.cls_token_id]  # every input starts with [CLS]
+             for history_id, history_utr in enumerate(history[-args.max_history_len:]):
+                 input_ids.extend(history_utr)
+                 input_ids.append(tokenizer.sep_token_id)
+             # duplicate the context for batched generation, shape (batch_size, token_len)
+             input_ids = [copy.deepcopy(input_ids) for _ in range(args.batch_size)]
+
+             curr_input_tensors = torch.tensor(input_ids).long().to(device)
+             # 2-D list of shape (response length, batch_size); generated[i][j] is
+             # the id of the i-th token of the j-th response
+             generated = []
+             # indices of responses that have already produced [SEP], i.e. finished
+             finish_set = set()
+             # generate at most max_len tokens
+             for _ in range(args.max_len):
+                 outputs = dialogue_model(input_ids=curr_input_tensors)
+                 next_token_logits = outputs[0][:, -1, :]
+                 # apply a repetition penalty to every token already generated,
+                 # lowering its probability of being generated again
+                 for index in range(args.batch_size):
+                     for token_id in set([token_ids[index] for token_ids in generated]):
+                         next_token_logits[index][token_id] /= args.repetition_penalty
+                 next_token_logits = next_token_logits / args.temperature
+                 # set the logit of [UNK] to -inf so the model can never predict it
+                 for next_token_logit in next_token_logits:
+                     next_token_logit[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
+                 filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=args.topk, top_p=args.topp)
+                 # torch.multinomial draws num_samples elements without replacement,
+                 # weighted by probability, and returns their indices
+                 next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
+                 # mark every response that has just generated [SEP] as finished
+                 for index, token_id in enumerate(next_token[:, 0]):
+                     if token_id == tokenizer.sep_token_id:
+                         finish_set.add(index)
+                 # stop early once every response in the batch has produced [SEP]
+                 finish_flag = True
+                 for index in range(args.batch_size):
+                     if index not in finish_set:  # at least one response is unfinished
+                         finish_flag = False
+                         break
+                 if finish_flag:
+                     break
+                 generated.append([token.item() for token in next_token[:, 0]])
+                 # append the newly generated tokens to the running input
+                 curr_input_tensors = torch.cat((curr_input_tensors, next_token), dim=-1)
+             candidate_responses = []  # all candidate responses
+             for batch_index in range(args.batch_size):
+                 response = []
+                 for token_index in range(len(generated)):
+                     if generated[token_index][batch_index] != tokenizer.sep_token_id:
+                         response.append(generated[token_index][batch_index])
+                     else:
+                         break
+                 candidate_responses.append(response)
+
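+             # Re-ranking idea behind the block below: each candidate is scored by
+             # the MMI model's LM loss on "[CLS] response [SEP] reversed history";
+             # a lower loss approximates a higher P(history | response), i.e. the
+             # maximum-mutual-information criterion of picking argmax_r P(history | r).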
+             # input to the MMI model
+             if args.debug:
+                 print("candidate response:")
+                 if args.save_samples_path:
+                     samples_file.write("candidate response:\n")
+             min_loss = float('Inf')
+             best_response = ""
+             for response in candidate_responses:
+                 mmi_input_id = [tokenizer.cls_token_id]  # every input starts with [CLS]
+                 mmi_input_id.extend(response)
+                 mmi_input_id.append(tokenizer.sep_token_id)
+                 for history_utr in reversed(history[-args.max_history_len:]):
+                     mmi_input_id.extend(history_utr)
+                     mmi_input_id.append(tokenizer.sep_token_id)
+                 mmi_input_tensor = torch.tensor(mmi_input_id).long().to(device)
+                 out = mmi_model(input_ids=mmi_input_tensor, labels=mmi_input_tensor)
+                 loss = out[0].item()
+                 if args.debug:
+                     text = tokenizer.convert_ids_to_tokens(response)
+                     print("{} loss:{}".format("".join(text), loss))
+                     if args.save_samples_path:
+                         samples_file.write("{} loss:{}\n".format("".join(text), loss))
+                 if loss < min_loss:
+                     best_response = response
+                     min_loss = loss
+             history.append(best_response)
+             text = tokenizer.convert_ids_to_tokens(best_response)
+             print("chatbot:" + "".join(text))
+             if args.save_samples_path:
+                 samples_file.write("chatbot:{}\n".format("".join(text)))
+         except KeyboardInterrupt:
+             if args.save_samples_path:
+                 samples_file.close()
+             break
+
+
+ if __name__ == '__main__':
+     main()
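+
+ # Example invocation (hypothetical flag values; model paths default to the
+ # trained checkpoints expected in dialogue_model/ and mmi_model/):
+ #   python interact_mmi.py --device 0 --topk 8 --batch_size 5 --debug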