thefish1 committed

Commit 86e0603 · 1 Parent(s): 692310d

update0729

Files changed (1):
  1. app.py +195 -9
app.py CHANGED
@@ -108,11 +108,12 @@ def get_response_from_db(keywords_dict, class_name):
    else:
        return None, None, None

- def chatbot_response(message, history, max_tokens, temperature, top_p, window_size, threshold, user_weight, triggered_weight, api_key, state):
+ def chatbot_response(message, max_tokens, temperature, top_p, window_size, threshold, user_weight, triggered_weight, api_key, state):
    initialize_openai_client(api_key)

-     current_turn = len(history) + 1
+     history = state.get('history', [])
    triggered_keywords = state.get('triggered_keywords', {})
+     current_turn = len(history) + 1

    combined_user_message = " ".join([h[0] for h in history[-window_size:]] + [message])
    combined_assistant_message = " ".join([h[1] for h in history[-window_size:]])
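Note on the hunk above: `history` is no longer a separate Gradio input; it now lives in the shared `state` dict alongside `triggered_keywords`, and `current_turn` is derived from it. A minimal sketch of the expected state shape and turn bookkeeping (illustrative only, not code from this commit; `record_turn` is a hypothetical helper):

```python
# Hypothetical illustration of the state shape the new chatbot_response expects.
state = {'history': [], 'triggered_keywords': {}}

def record_turn(state, user_message, assistant_message):
    """Append one (user, assistant) pair, mirroring what chatbot_response now does."""
    history = state.get('history', [])
    history.append((user_message, assistant_message))
    state['history'] = history  # the write-back matters on a first turn, when .get() returned a fresh list
    return state

state = record_turn(state, "你好", "你好!有什么可以帮你?")
current_turn = len(state['history']) + 1  # next turn index, as computed in the new code
print(current_turn)  # -> 2
```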
@@ -133,7 +134,7 @@ def chatbot_response(message, history, max_tokens, temperature, top_p, window_si
    if distance and distance < threshold:
        ad_message = f"{message} <sep>品牌<sep>{ad_summary}"
        messages = [{"role": "system", "content": "你是一个热情的聊天机器人,应微妙地嵌入广告内容。"}]
-         messages.extend([{"role": "user", "content": msg[0]}, {"role": "assistant", "content": msg[1]}] for msg in history)
+         messages += [{"role": "user", "content": msg[0]}, {"role": "assistant", "content": msg[1]}] for msg in history
        messages.append({"role": "user", "content": ad_message})

        for keyword in keywords_dict.keys():
@@ -141,7 +142,7 @@ def chatbot_response(message, history, max_tokens, temperature, top_p, window_si
                triggered_keywords[keyword] = current_turn
    else:
        messages = [{"role": "system", "content": "你是一个热情的聊天机器人。"}]
-         messages.extend([{"role": "user", "content": msg[0]}, {"role": "assistant", "content": msg[1]}] for msg in history)
+         messages += [{"role": "user", "content": msg[0]}, {"role": "assistant", "content": msg[1]}] for msg in history
        messages.append({"role": "user", "content": message})

    response = openai_client.chat.completions.create(
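The history-flattening line changed in the two hunks above looks suspect on both sides: the old `messages.extend(... for msg in history)` iterates a generator of two-element lists, so it appends lists rather than message dicts, and the new `messages += [...] for msg in history` is not valid Python as rendered here (the enclosing brackets may have been clipped by the diff view). A hedged sketch of one way to flatten `(user, assistant)` pairs into role-tagged messages; `history_to_messages` is a hypothetical helper, not part of the commit:

```python
# Sketch only: flatten (user_text, assistant_text) tuples into chat messages.
def history_to_messages(history):
    messages = []
    for user_text, assistant_text in history:
        messages.append({"role": "user", "content": user_text})
        messages.append({"role": "assistant", "content": assistant_text})
    return messages

messages = [{"role": "system", "content": "你是一个热情的聊天机器人。"}]
messages.extend(history_to_messages([("苹果好吃吗", "很好吃!")]))
messages.append({"role": "user", "content": "推荐一个品牌"})
```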
@@ -152,16 +153,17 @@ def chatbot_response(message, history, max_tokens, temperature, top_p, window_si
        top_p=top_p,
    )

+     history.append((message, response.choices[0].message.content))
+     state['history'] = history
    state['triggered_keywords'] = triggered_keywords
-     print(f"triggered_keywords: {triggered_keywords}")
+
    return response.choices[0].message.content, state

# Gradio UI
demo = gr.Interface(
-     chatbot_response,
+     fn=chatbot_response,
    inputs=[
        gr.Textbox(label="Message"),
-         gr.State(), # History
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
@@ -170,7 +172,7 @@ demo = gr.Interface(
        gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Weight of keywords from users"),
        gr.Slider(minimum=0, maximum=2, value=0.5, step=0.5, label="Weight of triggered keywords"),
        gr.Textbox(label="API Key"),
-         gr.State(value={}) # Triggered keywords state
+         gr.State(value={'history': [], 'triggered_keywords': {}}) # Combined state
    ],
    outputs=[
        gr.Textbox(label="Response"),
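With `gr.Interface`, inputs are passed to `fn` positionally in the order they are listed, so removing the `gr.State(), # History` input has to be matched by removing `history` from the function signature, which this commit does; the single remaining `gr.State` now carries both `history` and `triggered_keywords`. A trimmed-down, hypothetical illustration of that wiring (not the actual app):

```python
import gradio as gr

# Hypothetical, trimmed-down wiring: inputs map positionally onto the function
# parameters, and the gr.State value is round-tripped through the last output.
def echo(message, api_key, state):
    state = dict(state or {'history': [], 'triggered_keywords': {}})
    state.setdefault('history', []).append((message, f"echo: {message}"))
    return f"echo: {message}", state

demo = gr.Interface(
    fn=echo,
    inputs=[gr.Textbox(label="Message"), gr.Textbox(label="API Key"), gr.State(value={})],
    outputs=[gr.Textbox(label="Response"), gr.State()],
)
```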
@@ -180,7 +182,191 @@ demo = gr.Interface(

if __name__ == "__main__":
    demo.launch(share=True)
-     print("cnm")
+
+
+ # import gradio as gr
+ # from huggingface_hub import InferenceClient
+ # import json
+ # import random
+ # import re
+ # from load_data import load_data
+ # from openai import OpenAI
+ # from transformers import AutoTokenizer, AutoModel
+ # import weaviate
+ # import os
+ # import torch
+ # from tqdm import tqdm
+ # import numpy as np
+ # import time
+
+ # # 设置缓存目录
+ # os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
+ # os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'
+ # os.makedirs(os.environ['MPLCONFIGDIR'], exist_ok=True)
+ # os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
+
+ # # Weaviate 连接配置
+ # WEAVIATE_API_KEY = "Y7c8DRmcxZ4nP5IJLwkznIsK84l6EdwfXwcH"
+ # WEAVIATE_URL = "https://39nlafviqvard82k6y8btq.c0.asia-southeast1.gcp.weaviate.cloud"
+ # weaviate_auth_config = weaviate.AuthApiKey(api_key=WEAVIATE_API_KEY)
+ # weaviate_client = weaviate.Client(url=WEAVIATE_URL, auth_client_secret=weaviate_auth_config)
+
+ # # 预训练模型配置
+ # MODEL_NAME = "bert-base-chinese"
+ # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ # tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+ # model = AutoModel.from_pretrained(MODEL_NAME)
+
+ # # OpenAI 客户端
+ # openai_client = None
+
+ # def initialize_openai_client(api_key):
+ # global openai_client
+ # openai_client = OpenAI(api_key=api_key)
+
+ # def extract_keywords(text):
+ # prompt = """
+ # 你是一个关键词提取机器人。提取用户输入中的关键词,特别是名词和形容词,关键词之间用空格分隔。例如:苹果 电脑 裤子 蓝色 裙。
+ # """
+ # messages = [
+ # {"role": "system", "content": prompt},
+ # {"role": "user", "content": f"从下面的文本中提取五个关键词,以空格分隔:{text}"}
+ # ]
+
+ # response = openai_client.chat.completions.create(
+ # model="gpt-3.5-turbo",
+ # messages=messages,
+ # max_tokens=100,
+ # temperature=0.7,
+ # top_p=0.9,
+ # )
+
+ # keywords = response.choices[0].message.content.split(' ')
+ # return ','.join(keywords)
+
+ # def match_keywords(query_keywords, ad_keywords_list, triggered_keywords, current_turn, window_size, threshold):
+ # best_match_distance = 0
+ # best_match_index = -1
+
+ # for i, ad_keywords in enumerate(ad_keywords_list):
+ # match_count = sum(
+ # any(
+ # ad_keyword in keyword and
+ # (keyword not in triggered_keywords or current_turn - triggered_keywords[keyword] > window_size)
+ # ) for keyword in query_keywords
+ # )
+ # if match_count > best_match_distance:
+ # best_match_distance = match_count
+ # best_match_index = i
+
+ # if best_match_distance >= threshold:
+ # for keyword in query_keywords:
+ # if any(ad_keyword in keyword for ad_keyword in ad_keywords_list[best_match_index]):
+ # triggered_keywords[keyword] = current_turn
+
+ # return best_match_distance, best_match_index
+
+ # def encode_keywords_to_avg(keywords, model, tokenizer, device):
+ # embeddings = []
+ # for keyword in tqdm(keywords):
+ # inputs = tokenizer(keyword, return_tensors='pt', padding=True, truncation=True, max_length=512)
+ # inputs.to(device)
+ # with torch.no_grad():
+ # outputs = model(**inputs)
+ # embeddings.append(outputs.last_hidden_state.mean(dim=1))
+ # avg_embedding = sum(embeddings) / len(embeddings)
+ # return avg_embedding
+
+ # def get_response_from_db(keywords_dict, class_name):
+ # avg_vec = encode_keywords_to_avg(keywords_dict.keys(), model, tokenizer, device).numpy()
+ # response = (
+ # weaviate_client.query
+ # .get(class_name, ['keywords', 'summary'])
+ # .with_near_vector({'vector': avg_vec})
+ # .with_limit(1)
+ # .with_additional(['distance'])
+ # .do()
+ # )
+
+ # if class_name.capitalize() in response['data']['Get']:
+ # result = response['data']['Get'][class_name.capitalize()][0]
+ # return result['_additional']['distance'], result['summary'], result['keywords']
+ # else:
+ # return None, None, None
+
+ # def chatbot_response(message, history, max_tokens, temperature, top_p, window_size, threshold, user_weight, triggered_weight, api_key, state):
+ # initialize_openai_client(api_key)
+
+ # current_turn = len(history) + 1
+ # triggered_keywords = state.get('triggered_keywords', {})
+
+ # combined_user_message = " ".join([h[0] for h in history[-window_size:]] + [message])
+ # combined_assistant_message = " ".join([h[1] for h in history[-window_size:]])
+
+ # user_keywords = extract_keywords(combined_user_message).split(',')
+ # assistant_keywords = extract_keywords(combined_assistant_message).split(',')
+
+ # keywords_dict = {keyword: user_weight for keyword in user_keywords}
+ # for keyword in assistant_keywords:
+ # keywords_dict[keyword] = keywords_dict.get(keyword, 0) + 1
+
+ # for keyword in list(keywords_dict.keys()):
+ # if keyword in triggered_keywords and current_turn - triggered_keywords[keyword] < window_size:
+ # keywords_dict[keyword] = triggered_weight
+
+ # distance, ad_summary, ad_keywords = get_response_from_db(keywords_dict, class_name="ad_DB02")
+
+ # if distance and distance < threshold:
+ # ad_message = f"{message} <sep>品牌<sep>{ad_summary}"
+ # messages = [{"role": "system", "content": "你是一个热情的聊天机器人,应微妙地嵌入广告内容。"}]
+ # messages.extend([{"role": "user", "content": msg[0]}, {"role": "assistant", "content": msg[1]}] for msg in history)
+ # messages.append({"role": "user", "content": ad_message})
+
+ # for keyword in keywords_dict.keys():
+ # if any(ad_keyword in keyword for ad_keyword in ad_keywords.split(',')):
+ # triggered_keywords[keyword] = current_turn
+ # else:
+ # messages = [{"role": "system", "content": "你是一个热情的聊天机器人。"}]
+ # messages.extend([{"role": "user", "content": msg[0]}, {"role": "assistant", "content": msg[1]}] for msg in history)
+ # messages.append({"role": "user", "content": message})
+
+ # response = openai_client.chat.completions.create(
+ # model="gpt-3.5-turbo",
+ # messages=messages,
+ # max_tokens=max_tokens,
+ # temperature=temperature,
+ # top_p=top_p,
+ # )
+
+ # state['triggered_keywords'] = triggered_keywords
+ # print(f"triggered_keywords: {triggered_keywords}")
+ # return response.choices[0].message.content, state
+
+ # # Gradio UI
+ # demo = gr.Interface(
+ # chatbot_response,
+ # inputs=[
+ # gr.Textbox(label="Message"),
+ # gr.State(), # History
+ # gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+ # gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+ # gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
+ # gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Window size"),
+ # gr.Slider(minimum=0.01, maximum=0.20, value=0.08, step=0.01, label="Distance threshold"),
+ # gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Weight of keywords from users"),
+ # gr.Slider(minimum=0, maximum=2, value=0.5, step=0.5, label="Weight of triggered keywords"),
+ # gr.Textbox(label="API Key"),
+ # gr.State(value={}) # Triggered keywords state
+ # ],
+ # outputs=[
+ # gr.Textbox(label="Response"),
+ # gr.State() # Return the updated state
+ # ]
+ # )
+
+ # if __name__ == "__main__":
+ # demo.launch(share=True)
+ # print("cnm")
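For reference, the keyword-weighting scheme visible in the (now commented-out) helper code above: user keywords start at `user_weight`, assistant keywords add 1, and any keyword that triggered an ad within the last `window_size` turns is damped to `triggered_weight`. A small worked sketch using the UI's default slider values (illustrative numbers only):

```python
# Illustrative values only, taken from the slider defaults in the diff above.
user_weight, triggered_weight, window_size, current_turn = 2, 0.5, 2, 5
triggered_keywords = {"苹果": 4}          # "苹果" triggered an ad on turn 4

user_keywords = ["苹果", "电脑"]
assistant_keywords = ["电脑", "裙子"]

# User keywords start at user_weight; assistant keywords add 1.
keywords_dict = {kw: user_weight for kw in user_keywords}
for kw in assistant_keywords:
    keywords_dict[kw] = keywords_dict.get(kw, 0) + 1

# Recently triggered keywords are damped to triggered_weight.
for kw in list(keywords_dict):
    if kw in triggered_keywords and current_turn - triggered_keywords[kw] < window_size:
        keywords_dict[kw] = triggered_weight

print(keywords_dict)  # -> {'苹果': 0.5, '电脑': 3, '裙子': 1}
```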