update0729
Browse files
app.py
CHANGED
@@ -108,11 +108,12 @@ def get_response_from_db(keywords_dict, class_name):
|
|
108 |
else:
|
109 |
return None, None, None
|
110 |
|
111 |
-
def chatbot_response(message,
|
112 |
initialize_openai_client(api_key)
|
113 |
|
114 |
-
|
115 |
triggered_keywords = state.get('triggered_keywords', {})
|
|
|
116 |
|
117 |
combined_user_message = " ".join([h[0] for h in history[-window_size:]] + [message])
|
118 |
combined_assistant_message = " ".join([h[1] for h in history[-window_size:]])
|
@@ -133,7 +134,7 @@ def chatbot_response(message, history, max_tokens, temperature, top_p, window_si
|
|
133 |
if distance and distance < threshold:
|
134 |
ad_message = f"{message} <sep>品牌<sep>{ad_summary}"
|
135 |
messages = [{"role": "system", "content": "你是一个热情的聊天机器人,应微妙地嵌入广告内容。"}]
|
136 |
-
messages
|
137 |
messages.append({"role": "user", "content": ad_message})
|
138 |
|
139 |
for keyword in keywords_dict.keys():
|
@@ -141,7 +142,7 @@ def chatbot_response(message, history, max_tokens, temperature, top_p, window_si
|
|
141 |
triggered_keywords[keyword] = current_turn
|
142 |
else:
|
143 |
messages = [{"role": "system", "content": "你是一个热情的聊天机器人。"}]
|
144 |
-
messages
|
145 |
messages.append({"role": "user", "content": message})
|
146 |
|
147 |
response = openai_client.chat.completions.create(
|
@@ -152,16 +153,17 @@ def chatbot_response(message, history, max_tokens, temperature, top_p, window_si
|
|
152 |
top_p=top_p,
|
153 |
)
|
154 |
|
|
|
|
|
155 |
state['triggered_keywords'] = triggered_keywords
|
156 |
-
|
157 |
return response.choices[0].message.content, state
|
158 |
|
159 |
# Gradio UI
|
160 |
demo = gr.Interface(
|
161 |
-
chatbot_response,
|
162 |
inputs=[
|
163 |
gr.Textbox(label="Message"),
|
164 |
-
gr.State(), # History
|
165 |
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
166 |
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
|
167 |
gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
|
@@ -170,7 +172,7 @@ demo = gr.Interface(
|
|
170 |
gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Weight of keywords from users"),
|
171 |
gr.Slider(minimum=0, maximum=2, value=0.5, step=0.5, label="Weight of triggered keywords"),
|
172 |
gr.Textbox(label="API Key"),
|
173 |
-
gr.State(value={}) #
|
174 |
],
|
175 |
outputs=[
|
176 |
gr.Textbox(label="Response"),
|
@@ -180,7 +182,191 @@ demo = gr.Interface(
|
|
180 |
|
181 |
if __name__ == "__main__":
|
182 |
demo.launch(share=True)
|
183 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
|
185 |
|
186 |
|
|
|
108 |
else:
|
109 |
return None, None, None
|
110 |
|
111 |
+
def chatbot_response(message, max_tokens, temperature, top_p, window_size, threshold, user_weight, triggered_weight, api_key, state):
|
112 |
initialize_openai_client(api_key)
|
113 |
|
114 |
+
history = state.get('history', [])
|
115 |
triggered_keywords = state.get('triggered_keywords', {})
|
116 |
+
current_turn = len(history) + 1
|
117 |
|
118 |
combined_user_message = " ".join([h[0] for h in history[-window_size:]] + [message])
|
119 |
combined_assistant_message = " ".join([h[1] for h in history[-window_size:]])
|
|
|
134 |
if distance and distance < threshold:
|
135 |
ad_message = f"{message} <sep>品牌<sep>{ad_summary}"
|
136 |
messages = [{"role": "system", "content": "你是一个热情的聊天机器人,应微妙地嵌入广告内容。"}]
|
137 |
+
messages += [{"role": "user", "content": msg[0]}, {"role": "assistant", "content": msg[1]}] for msg in history
|
138 |
messages.append({"role": "user", "content": ad_message})
|
139 |
|
140 |
for keyword in keywords_dict.keys():
|
|
|
142 |
triggered_keywords[keyword] = current_turn
|
143 |
else:
|
144 |
messages = [{"role": "system", "content": "你是一个热情的聊天机器人。"}]
|
145 |
+
messages += [{"role": "user", "content": msg[0]}, {"role": "assistant", "content": msg[1]}] for msg in history
|
146 |
messages.append({"role": "user", "content": message})
|
147 |
|
148 |
response = openai_client.chat.completions.create(
|
|
|
153 |
top_p=top_p,
|
154 |
)
|
155 |
|
156 |
+
history.append((message, response.choices[0].message.content))
|
157 |
+
state['history'] = history
|
158 |
state['triggered_keywords'] = triggered_keywords
|
159 |
+
|
160 |
return response.choices[0].message.content, state
|
161 |
|
162 |
# Gradio UI
|
163 |
demo = gr.Interface(
|
164 |
+
fn=chatbot_response,
|
165 |
inputs=[
|
166 |
gr.Textbox(label="Message"),
|
|
|
167 |
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
168 |
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
|
169 |
gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
|
|
|
172 |
gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Weight of keywords from users"),
|
173 |
gr.Slider(minimum=0, maximum=2, value=0.5, step=0.5, label="Weight of triggered keywords"),
|
174 |
gr.Textbox(label="API Key"),
|
175 |
+
gr.State(value={'history': [], 'triggered_keywords': {}}) # Combined state
|
176 |
],
|
177 |
outputs=[
|
178 |
gr.Textbox(label="Response"),
|
|
|
182 |
|
183 |
if __name__ == "__main__":
|
184 |
demo.launch(share=True)
|
185 |
+
|
186 |
+
|
187 |
+
# import gradio as gr
|
188 |
+
# from huggingface_hub import InferenceClient
|
189 |
+
# import json
|
190 |
+
# import random
|
191 |
+
# import re
|
192 |
+
# from load_data import load_data
|
193 |
+
# from openai import OpenAI
|
194 |
+
# from transformers import AutoTokenizer, AutoModel
|
195 |
+
# import weaviate
|
196 |
+
# import os
|
197 |
+
# import torch
|
198 |
+
# from tqdm import tqdm
|
199 |
+
# import numpy as np
|
200 |
+
# import time
|
201 |
+
|
202 |
+
# # 设置缓存目录
|
203 |
+
# os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
|
204 |
+
# os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'
|
205 |
+
# os.makedirs(os.environ['MPLCONFIGDIR'], exist_ok=True)
|
206 |
+
# os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
|
207 |
+
|
208 |
+
# # Weaviate 连接配置
|
209 |
+
# WEAVIATE_API_KEY = "Y7c8DRmcxZ4nP5IJLwkznIsK84l6EdwfXwcH"
|
210 |
+
# WEAVIATE_URL = "https://39nlafviqvard82k6y8btq.c0.asia-southeast1.gcp.weaviate.cloud"
|
211 |
+
# weaviate_auth_config = weaviate.AuthApiKey(api_key=WEAVIATE_API_KEY)
|
212 |
+
# weaviate_client = weaviate.Client(url=WEAVIATE_URL, auth_client_secret=weaviate_auth_config)
|
213 |
+
|
214 |
+
# # 预训练模型配置
|
215 |
+
# MODEL_NAME = "bert-base-chinese"
|
216 |
+
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
217 |
+
# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
218 |
+
# model = AutoModel.from_pretrained(MODEL_NAME)
|
219 |
+
|
220 |
+
# # OpenAI 客户端
|
221 |
+
# openai_client = None
|
222 |
+
|
223 |
+
# def initialize_openai_client(api_key):
|
224 |
+
# global openai_client
|
225 |
+
# openai_client = OpenAI(api_key=api_key)
|
226 |
+
|
227 |
+
# def extract_keywords(text):
|
228 |
+
# prompt = """
|
229 |
+
# 你是一个关键词提取机器人。提取用户输入中的关键词,特别是名词和形容词,关键词之间用��格分隔。例如:苹果 电脑 裤子 蓝色 裙。
|
230 |
+
# """
|
231 |
+
# messages = [
|
232 |
+
# {"role": "system", "content": prompt},
|
233 |
+
# {"role": "user", "content": f"从下面的文本中提取五个关键词,以空格分隔:{text}"}
|
234 |
+
# ]
|
235 |
+
|
236 |
+
# response = openai_client.chat.completions.create(
|
237 |
+
# model="gpt-3.5-turbo",
|
238 |
+
# messages=messages,
|
239 |
+
# max_tokens=100,
|
240 |
+
# temperature=0.7,
|
241 |
+
# top_p=0.9,
|
242 |
+
# )
|
243 |
+
|
244 |
+
# keywords = response.choices[0].message.content.split(' ')
|
245 |
+
# return ','.join(keywords)
|
246 |
+
|
247 |
+
# def match_keywords(query_keywords, ad_keywords_list, triggered_keywords, current_turn, window_size, threshold):
|
248 |
+
# best_match_distance = 0
|
249 |
+
# best_match_index = -1
|
250 |
+
|
251 |
+
# for i, ad_keywords in enumerate(ad_keywords_list):
|
252 |
+
# match_count = sum(
|
253 |
+
# any(
|
254 |
+
# ad_keyword in keyword and
|
255 |
+
# (keyword not in triggered_keywords or current_turn - triggered_keywords[keyword] > window_size)
|
256 |
+
# ) for keyword in query_keywords
|
257 |
+
# )
|
258 |
+
# if match_count > best_match_distance:
|
259 |
+
# best_match_distance = match_count
|
260 |
+
# best_match_index = i
|
261 |
+
|
262 |
+
# if best_match_distance >= threshold:
|
263 |
+
# for keyword in query_keywords:
|
264 |
+
# if any(ad_keyword in keyword for ad_keyword in ad_keywords_list[best_match_index]):
|
265 |
+
# triggered_keywords[keyword] = current_turn
|
266 |
+
|
267 |
+
# return best_match_distance, best_match_index
|
268 |
+
|
269 |
+
# def encode_keywords_to_avg(keywords, model, tokenizer, device):
|
270 |
+
# embeddings = []
|
271 |
+
# for keyword in tqdm(keywords):
|
272 |
+
# inputs = tokenizer(keyword, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
273 |
+
# inputs.to(device)
|
274 |
+
# with torch.no_grad():
|
275 |
+
# outputs = model(**inputs)
|
276 |
+
# embeddings.append(outputs.last_hidden_state.mean(dim=1))
|
277 |
+
# avg_embedding = sum(embeddings) / len(embeddings)
|
278 |
+
# return avg_embedding
|
279 |
+
|
280 |
+
# def get_response_from_db(keywords_dict, class_name):
|
281 |
+
# avg_vec = encode_keywords_to_avg(keywords_dict.keys(), model, tokenizer, device).numpy()
|
282 |
+
# response = (
|
283 |
+
# weaviate_client.query
|
284 |
+
# .get(class_name, ['keywords', 'summary'])
|
285 |
+
# .with_near_vector({'vector': avg_vec})
|
286 |
+
# .with_limit(1)
|
287 |
+
# .with_additional(['distance'])
|
288 |
+
# .do()
|
289 |
+
# )
|
290 |
+
|
291 |
+
# if class_name.capitalize() in response['data']['Get']:
|
292 |
+
# result = response['data']['Get'][class_name.capitalize()][0]
|
293 |
+
# return result['_additional']['distance'], result['summary'], result['keywords']
|
294 |
+
# else:
|
295 |
+
# return None, None, None
|
296 |
+
|
297 |
+
# def chatbot_response(message, history, max_tokens, temperature, top_p, window_size, threshold, user_weight, triggered_weight, api_key, state):
|
298 |
+
# initialize_openai_client(api_key)
|
299 |
+
|
300 |
+
# current_turn = len(history) + 1
|
301 |
+
# triggered_keywords = state.get('triggered_keywords', {})
|
302 |
+
|
303 |
+
# combined_user_message = " ".join([h[0] for h in history[-window_size:]] + [message])
|
304 |
+
# combined_assistant_message = " ".join([h[1] for h in history[-window_size:]])
|
305 |
+
|
306 |
+
# user_keywords = extract_keywords(combined_user_message).split(',')
|
307 |
+
# assistant_keywords = extract_keywords(combined_assistant_message).split(',')
|
308 |
+
|
309 |
+
# keywords_dict = {keyword: user_weight for keyword in user_keywords}
|
310 |
+
# for keyword in assistant_keywords:
|
311 |
+
# keywords_dict[keyword] = keywords_dict.get(keyword, 0) + 1
|
312 |
+
|
313 |
+
# for keyword in list(keywords_dict.keys()):
|
314 |
+
# if keyword in triggered_keywords and current_turn - triggered_keywords[keyword] < window_size:
|
315 |
+
# keywords_dict[keyword] = triggered_weight
|
316 |
+
|
317 |
+
# distance, ad_summary, ad_keywords = get_response_from_db(keywords_dict, class_name="ad_DB02")
|
318 |
+
|
319 |
+
# if distance and distance < threshold:
|
320 |
+
# ad_message = f"{message} <sep>品牌<sep>{ad_summary}"
|
321 |
+
# messages = [{"role": "system", "content": "你是一个热情的聊天机器人,应微妙地嵌入广告内容。"}]
|
322 |
+
# messages.extend([{"role": "user", "content": msg[0]}, {"role": "assistant", "content": msg[1]}] for msg in history)
|
323 |
+
# messages.append({"role": "user", "content": ad_message})
|
324 |
+
|
325 |
+
# for keyword in keywords_dict.keys():
|
326 |
+
# if any(ad_keyword in keyword for ad_keyword in ad_keywords.split(',')):
|
327 |
+
# triggered_keywords[keyword] = current_turn
|
328 |
+
# else:
|
329 |
+
# messages = [{"role": "system", "content": "你是一个热情的聊天机器人。"}]
|
330 |
+
# messages.extend([{"role": "user", "content": msg[0]}, {"role": "assistant", "content": msg[1]}] for msg in history)
|
331 |
+
# messages.append({"role": "user", "content": message})
|
332 |
+
|
333 |
+
# response = openai_client.chat.completions.create(
|
334 |
+
# model="gpt-3.5-turbo",
|
335 |
+
# messages=messages,
|
336 |
+
# max_tokens=max_tokens,
|
337 |
+
# temperature=temperature,
|
338 |
+
# top_p=top_p,
|
339 |
+
# )
|
340 |
+
|
341 |
+
# state['triggered_keywords'] = triggered_keywords
|
342 |
+
# print(f"triggered_keywords: {triggered_keywords}")
|
343 |
+
# return response.choices[0].message.content, state
|
344 |
+
|
345 |
+
# # Gradio UI
|
346 |
+
# demo = gr.Interface(
|
347 |
+
# chatbot_response,
|
348 |
+
# inputs=[
|
349 |
+
# gr.Textbox(label="Message"),
|
350 |
+
# gr.State(), # History
|
351 |
+
# gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
352 |
+
# gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
|
353 |
+
# gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
|
354 |
+
# gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Window size"),
|
355 |
+
# gr.Slider(minimum=0.01, maximum=0.20, value=0.08, step=0.01, label="Distance threshold"),
|
356 |
+
# gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Weight of keywords from users"),
|
357 |
+
# gr.Slider(minimum=0, maximum=2, value=0.5, step=0.5, label="Weight of triggered keywords"),
|
358 |
+
# gr.Textbox(label="API Key"),
|
359 |
+
# gr.State(value={}) # Triggered keywords state
|
360 |
+
# ],
|
361 |
+
# outputs=[
|
362 |
+
# gr.Textbox(label="Response"),
|
363 |
+
# gr.State() # Return the updated state
|
364 |
+
# ]
|
365 |
+
# )
|
366 |
+
|
367 |
+
# if __name__ == "__main__":
|
368 |
+
# demo.launch(share=True)
|
369 |
+
# print("cnm")
|
370 |
|
371 |
|
372 |
|