thefish1 committed on
Commit
28bb1f7
·
1 Parent(s): 5e4b62e

update 0722

Browse files
Files changed (1) hide show
  1. app.py +16 -50
app.py CHANGED
@@ -39,13 +39,6 @@ tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
39
  model = AutoModel.from_pretrained("bert-base-chinese")
40
 
41
 
42
-
43
- # 本地加载数据
44
- dataset = load_data(file_path='train_300.json', num_samples=300)
45
- keyword_lists = [item['content'] for item in dataset if 'content' in item]
46
- summary_lists = [item['summary'] for item in dataset if 'summary' in item]
47
-
48
-
49
  global_api_key = None
50
  client = None
51
 
@@ -53,10 +46,6 @@ def initialize_clients(api_key):
53
  global client
54
  client = OpenAI(api_key=api_key)
55
 
56
-
57
- for item in keyword_lists:
58
- item = item.split(',')
59
-
60
  def get_keywords(message):
61
  system_message = """
62
  # 角色
@@ -83,8 +72,7 @@ def get_keywords(message):
83
  return ','.join(keywords)
84
 
85
 
86
-
87
-
88
  def keyword_match(query_keywords_dict, ad_keywords_lists, triggered_keywords, current_turn, window_size,distance_threshold):
89
  distance = 0
90
  most_matching_list = None
@@ -93,7 +81,6 @@ def keyword_match(query_keywords_dict, ad_keywords_lists, triggered_keywords, cu
93
  # query_keywords = query_keywords.split(',')
94
  # query_keywords = [keyword for keyword in query_keywords if keyword]
95
 
96
-
97
  #匹配模块
98
  query_keywords= list(query_keywords_dict.keys())
99
 
@@ -126,6 +113,7 @@ def keyword_match(query_keywords_dict, ad_keywords_lists, triggered_keywords, cu
126
 
127
  return distance, index
128
 
 
129
  def encode_list_to_avg(keywords_list_list, model, tokenizer, device):
130
  if torch.cuda.is_available():
131
  print('Using GPU')
@@ -150,6 +138,7 @@ def encode_list_to_avg(keywords_list_list, model, tokenizer, device):
150
 
151
  return avg_embeddings
152
 
 
153
  def encode_to_avg(keywords_dict, model, tokenizer, device):
154
  if torch.cuda.is_available():
155
  print('Using GPU')
@@ -176,33 +165,6 @@ def encode_to_avg(keywords_dict, model, tokenizer, device):
176
  return avg_embedding.tolist()
177
 
178
 
179
- # def fetch_response_from_db(query_keywords,class_name):
180
-
181
- # avg_vec=np.array(encode_list_to_avg([query_keywords], model, tokenizer, device)[0])
182
- # nearVector = {
183
- # 'vector': avg_vec
184
- # }
185
-
186
- # response = (
187
- # db_client.query
188
- # .get(class_name, ['keywords', 'summary'])
189
- # .with_near_vector(nearVector)
190
- # .with_limit(1)
191
- # .with_additional(['distance'])
192
- # .do()
193
- # )
194
-
195
- # print(response)
196
- # class_name=class_name[0].upper()+class_name[1:]
197
-
198
- # if class_name in response['data']['Get']:
199
- # results = response['data']['Get'][class_name]
200
- # return results[0]['_additional']['distance'],results[0]['summary'], results[0]['keywords']
201
- # else:
202
- # print(f"Class name {class_name} not found in response")
203
- # return None
204
-
205
-
206
  def fetch_response_from_db(query_keywords_dict,class_name):
207
 
208
 
@@ -230,13 +192,15 @@ def fetch_response_from_db(query_keywords_dict,class_name):
230
  print(f"Class name {class_name} not found in response")
231
  return None
232
 
233
- def wrapper(message, history, max_tokens, temperature, top_p, window_size, distance_threshold, weight_keywords_users, weight_keywords_triggered, api_key):
 
 
234
  initialize_clients(api_key)
235
- return respond(message, history, max_tokens, temperature, top_p, window_size, distance_threshold, weight_keywords_users, weight_keywords_triggered)
236
 
237
 
238
  #触发词及触发回合字典
239
- triggered_keywords = {}
240
 
241
  def respond(
242
  message,
@@ -247,10 +211,11 @@ def respond(
247
  window_size,
248
  distance_threshold,
249
  weight_keywords_users,
250
- weight_keywords_triggered
 
251
  ):
252
 
253
-
254
  system_message_with_ad = """
255
  # 角色
256
  你是一个热情的聊天机器人
@@ -291,9 +256,9 @@ def respond(
291
  keywords_dict={}
292
  for keywords in key_words_users:
293
  if keywords in keywords_dict:
294
- keywords_dict[keywords]+=2
295
  else:
296
- keywords_dict[keywords]=2
297
  for keywords in key_words_assistant:
298
  if keywords in keywords_dict:
299
  keywords_dict[keywords]+=1
@@ -355,7 +320,7 @@ def respond(
355
  top_p=top_p,
356
  )
357
 
358
- return response.choices[0].message.content
359
 
360
 
361
  # def chat_interface(message, history, max_tokens, temperature, top_p, window_size, distance_threshold):
@@ -388,7 +353,8 @@ demo = gr.ChatInterface(
388
  gr.Slider(minimum=0.01, maximum=0.20, value=0.08, step=0.01, label="Distance threshold"),
389
  gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Weight of keywords from users"),
390
  gr.Slider(minimum=0, maximum=2, value=0.5, step=0.5, label="Weight of triggered keywords"),
391
- gr.Textbox(label="api_key")
 
392
  ],
393
  )
394
 
 
39
  model = AutoModel.from_pretrained("bert-base-chinese")
40
 
41
 
 
 
 
 
 
 
 
42
  global_api_key = None
43
  client = None
44
 
 
46
  global client
47
  client = OpenAI(api_key=api_key)
48
 
 
 
 
 
49
  def get_keywords(message):
50
  system_message = """
51
  # 角色
 
72
  return ','.join(keywords)
73
 
74
 
75
+ #字符串匹配模块
 
76
  def keyword_match(query_keywords_dict, ad_keywords_lists, triggered_keywords, current_turn, window_size,distance_threshold):
77
  distance = 0
78
  most_matching_list = None
 
81
  # query_keywords = query_keywords.split(',')
82
  # query_keywords = [keyword for keyword in query_keywords if keyword]
83
 
 
84
  #匹配模块
85
  query_keywords= list(query_keywords_dict.keys())
86
 
 
113
 
114
  return distance, index
115
 
116
+
117
  def encode_list_to_avg(keywords_list_list, model, tokenizer, device):
118
  if torch.cuda.is_available():
119
  print('Using GPU')
 
138
 
139
  return avg_embeddings
140
 
141
+
142
  def encode_to_avg(keywords_dict, model, tokenizer, device):
143
  if torch.cuda.is_available():
144
  print('Using GPU')
 
165
  return avg_embedding.tolist()
166
 
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  def fetch_response_from_db(query_keywords_dict,class_name):
169
 
170
 
 
192
  print(f"Class name {class_name} not found in response")
193
  return None
194
 
195
+
196
+
197
+ def wrapper(message, history, max_tokens, temperature, top_p, window_size, distance_threshold, weight_keywords_users, weight_keywords_triggered, api_key, state):
198
  initialize_clients(api_key)
199
+ return respond(message, history, max_tokens, temperature, top_p, window_size, distance_threshold, weight_keywords_users, weight_keywords_triggered, state)
200
 
201
 
202
  #触发词及触发回合字典
203
+
204
 
205
  def respond(
206
  message,
 
211
  window_size,
212
  distance_threshold,
213
  weight_keywords_users,
214
+ weight_keywords_triggered,
215
+ triggered_keywords
216
  ):
217
 
218
+ triggered_keywords=triggered_keywords or {}
219
  system_message_with_ad = """
220
  # 角色
221
  你是一个热情的聊天机器人
 
256
  keywords_dict={}
257
  for keywords in key_words_users:
258
  if keywords in keywords_dict:
259
+ keywords_dict[keywords]+=weight_keywords_users
260
  else:
261
+ keywords_dict[keywords]=weight_keywords_users
262
  for keywords in key_words_assistant:
263
  if keywords in keywords_dict:
264
  keywords_dict[keywords]+=1
 
320
  top_p=top_p,
321
  )
322
 
323
+ return response.choices[0].message.content , triggered_keywords
324
 
325
 
326
  # def chat_interface(message, history, max_tokens, temperature, top_p, window_size, distance_threshold):
 
353
  gr.Slider(minimum=0.01, maximum=0.20, value=0.08, step=0.01, label="Distance threshold"),
354
  gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Weight of keywords from users"),
355
  gr.Slider(minimum=0, maximum=2, value=0.5, step=0.5, label="Weight of triggered keywords"),
356
+ gr.Textbox(label="api_key"),
357
+ 'state'
358
  ],
359
  )
360