update 0722
Browse files
app.py
CHANGED
@@ -1,3 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
from huggingface_hub import InferenceClient
|
3 |
import json
|
@@ -7,38 +370,32 @@ from load_data import load_data
|
|
7 |
from openai import OpenAI
|
8 |
from transformers import AutoTokenizer, AutoModel
|
9 |
import weaviate
|
10 |
-
import os
|
11 |
-
import subprocess
|
12 |
import torch
|
13 |
from tqdm import tqdm
|
14 |
import numpy as np
|
15 |
|
16 |
-
|
17 |
-
os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
|
18 |
-
|
19 |
-
os.environ['
|
20 |
-
|
21 |
-
os.makedirs(os.environ['MPLCONFIGDIR'], exist_ok=True)
|
22 |
-
os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
|
23 |
|
24 |
auth_config = weaviate.AuthApiKey(api_key="Y7c8DRmcxZ4nP5IJLwkznIsK84l6EdwfXwcH")
|
25 |
-
|
26 |
URL = "https://39nlafviqvard82k6y8btq.c0.asia-southeast1.gcp.weaviate.cloud"
|
27 |
|
28 |
# Connect to a WCS instance
|
29 |
db_client = weaviate.Client(
|
30 |
-
|
31 |
-
|
32 |
)
|
33 |
|
34 |
-
|
35 |
-
class_name="ad_DB02"
|
36 |
-
|
37 |
device = torch.device(device='cuda' if torch.cuda.is_available() else 'cpu')
|
38 |
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
|
39 |
model = AutoModel.from_pretrained("bert-base-chinese")
|
40 |
|
41 |
-
|
42 |
global_api_key = None
|
43 |
client = None
|
44 |
|
@@ -56,7 +413,6 @@ def get_keywords(message):
|
|
56 |
你应该直接输出关键词,关键词之间用空格分隔。例如:苹果 电脑 裤子 蓝色 裙
|
57 |
# 注意:如果输入文本过短可以重复输出关键词,例如对输入“你好”可以输出:你好 你好 你好 你好 你好
|
58 |
"""
|
59 |
-
|
60 |
messages = [{"role": "system", "content": system_message}]
|
61 |
messages.append({"role": "user", "content": f"从下面的文本中给我提取五个关键词,只输出这五个关键词,以空格分隔{message}"})
|
62 |
|
@@ -71,73 +427,27 @@ def get_keywords(message):
|
|
71 |
keywords = response.choices[0].message.content.split(' ')
|
72 |
return ','.join(keywords)
|
73 |
|
|
|
|
|
|
|
74 |
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
#匹配模块
|
85 |
-
query_keywords= list(query_keywords_dict.keys())
|
86 |
-
|
87 |
-
for i, lst in enumerate(ad_keywords_lists):
|
88 |
-
lst = lst.split(',')
|
89 |
-
matches = sum(
|
90 |
-
any(
|
91 |
-
ad_keyword in keyword and
|
92 |
-
(
|
93 |
-
keyword not in triggered_keywords or
|
94 |
-
triggered_keywords.get(keyword) is None or
|
95 |
-
current_turn - triggered_keywords.get(keyword, 0) > window_size
|
96 |
-
) * query_keywords_dict.get(keyword, 1) #计数乘以权重
|
97 |
-
for keyword in query_keywords
|
98 |
-
)
|
99 |
-
for ad_keyword in lst
|
100 |
-
)
|
101 |
-
if matches > distance:
|
102 |
-
distance = matches
|
103 |
-
most_matching_list = lst
|
104 |
-
index = i
|
105 |
-
|
106 |
-
#更新对distance 有贡献的关键词
|
107 |
-
if distance >= distance_threshold:
|
108 |
-
for keyword in query_keywords:
|
109 |
-
if any(
|
110 |
-
ad_keyword in keyword for ad_keyword in most_matching_list
|
111 |
-
):
|
112 |
-
triggered_keywords[keyword] = current_turn
|
113 |
-
|
114 |
-
return distance, index
|
115 |
|
|
|
116 |
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
print(device)
|
121 |
else:
|
122 |
-
print(
|
123 |
-
|
124 |
-
|
125 |
-
avg_embeddings = []
|
126 |
-
for keywords in tqdm(keywords_list_list):
|
127 |
-
keywords_lst=[]
|
128 |
-
# keywords.split(',')
|
129 |
-
for keyword in keywords:
|
130 |
-
inputs = tokenizer(keyword, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
131 |
-
inputs.to(device)
|
132 |
-
with torch.no_grad():
|
133 |
-
outputs = model(**inputs)
|
134 |
-
embeddings = outputs.last_hidden_state.mean(dim=1)
|
135 |
-
keywords_lst.append(embeddings)
|
136 |
-
avg_embedding = sum(keywords_lst) / len(keywords_lst)
|
137 |
-
avg_embeddings.append(avg_embedding)
|
138 |
-
|
139 |
-
return avg_embeddings
|
140 |
-
|
141 |
|
142 |
def encode_to_avg(keywords_dict, model, tokenizer, device):
|
143 |
if torch.cuda.is_available():
|
@@ -147,8 +457,7 @@ def encode_to_avg(keywords_dict, model, tokenizer, device):
|
|
147 |
print('Using CPU')
|
148 |
print(device)
|
149 |
|
150 |
-
|
151 |
-
keyword_embeddings=[]
|
152 |
for keyword, weight in keywords_dict.items():
|
153 |
inputs = tokenizer(keyword, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
154 |
inputs.to(device)
|
@@ -156,55 +465,19 @@ def encode_to_avg(keywords_dict, model, tokenizer, device):
|
|
156 |
outputs = model(**inputs)
|
157 |
embedding = outputs.last_hidden_state.mean(dim=1)
|
158 |
|
159 |
-
keyword_embedding=embedding * weight
|
160 |
-
|
161 |
-
keyword_embeddings.append(keyword_embedding * weight)
|
162 |
|
163 |
avg_embedding = sum(keyword_embeddings) / sum(keywords_dict.values())
|
164 |
-
|
165 |
return avg_embedding.tolist()
|
166 |
|
167 |
-
|
168 |
-
def fetch_response_from_db(query_keywords_dict,class_name):
|
169 |
-
|
170 |
-
|
171 |
-
avg_vec=np.array(encode_to_avg(query_keywords_dict, model, tokenizer, device))
|
172 |
-
nearVector = {
|
173 |
-
'vector': avg_vec
|
174 |
-
}
|
175 |
-
|
176 |
-
response = (
|
177 |
-
db_client.query
|
178 |
-
.get(class_name, ['keywords', 'summary'])
|
179 |
-
.with_near_vector(nearVector)
|
180 |
-
.with_limit(1)
|
181 |
-
.with_additional(['distance'])
|
182 |
-
.do()
|
183 |
-
)
|
184 |
-
|
185 |
-
print(response)
|
186 |
-
class_name=class_name[0].upper()+class_name[1:]
|
187 |
-
|
188 |
-
if class_name in response['data']['Get']:
|
189 |
-
results = response['data']['Get'][class_name]
|
190 |
-
return results[0]['_additional']['distance'],results[0]['summary'], results[0]['keywords']
|
191 |
-
else:
|
192 |
-
print(f"Class name {class_name} not found in response")
|
193 |
-
return None
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
def wrapper(message, history, max_tokens, temperature, top_p, window_size, distance_threshold, weight_keywords_users, weight_keywords_triggered, api_key, state):
|
198 |
initialize_clients(api_key)
|
199 |
return respond(message, history, max_tokens, temperature, top_p, window_size, distance_threshold, weight_keywords_users, weight_keywords_triggered, state)
|
200 |
|
201 |
-
|
202 |
-
#触发词及触发回合字典
|
203 |
-
|
204 |
-
|
205 |
def respond(
|
206 |
message,
|
207 |
-
history
|
208 |
max_tokens,
|
209 |
temperature,
|
210 |
top_p,
|
@@ -212,106 +485,65 @@ def respond(
|
|
212 |
distance_threshold,
|
213 |
weight_keywords_users,
|
214 |
weight_keywords_triggered,
|
215 |
-
|
216 |
):
|
217 |
-
|
218 |
-
triggered_keywords=triggered_keywords or {}
|
219 |
-
system_message_with_ad = """
|
220 |
-
# 角色
|
221 |
-
你是一个热情的聊天机器人
|
222 |
-
# 指令
|
223 |
-
你的目标是以一种微妙的方式将给定的广告内容嵌入对话中。插入广告时应尽量自然,避免突兀。优先正常回答用户问题,然后逐步以推荐方式引入广告内容。
|
224 |
-
注意要在商品的描述前面加上是来自哪个品牌的广告。
|
225 |
-
注意在推荐中不要脑补用户的身份,只是进行简单推荐。
|
226 |
-
注意要热情但是语气只要适度热情
|
227 |
-
# 输入格式
|
228 |
-
用户查询后跟随广告品牌,用<sep>分隔,广告品牌后跟随广告描述,再用<sep>分隔。
|
229 |
-
例如:我想买一条阔腿裤 <sep> 腾讯 <sep> 宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。毕竟好穿时尚,谁都能穿出腿长2米的效果宽松的裤腿,当然是遮肉小能手啊。上身随性自然不拘束,面料亲肤舒适贴身体验感棒棒哒。系带部分增加设计看点,还让单品的设计感更强。腿部线条若隐若现的,性感撩人。颜色敲温柔的,与裤子本身所呈现的风格有点反差萌。
|
230 |
-
注意: 当没有<sep>时,正常回复用户,不插入广告。
|
231 |
-
# 输出格式
|
232 |
-
始终使用中文,只输出聊天内容,不输出任何自我分析的信息
|
233 |
-
"""
|
234 |
-
|
235 |
-
system_message_without_ad = """
|
236 |
-
你是一个热情的聊天机器人
|
237 |
-
"""
|
238 |
-
print(f"triggered_keywords{triggered_keywords}")
|
239 |
-
# 更新当前轮次
|
240 |
current_turn = len(history) + 1
|
241 |
-
|
242 |
-
# 检查历史记录的长度
|
243 |
if len(history) >= window_size:
|
244 |
combined_message_user = " ".join([h[0] for h in history[-window_size:] if h[0]] + [message])
|
245 |
-
combined_message_assistant=" ".join(h[1] for h in history[-window_size:] if h[1])
|
246 |
else:
|
247 |
combined_message_user = message
|
248 |
combined_message_assistant = ""
|
249 |
|
250 |
-
key_words_users=get_keywords(combined_message_user).split(',')
|
251 |
-
key_words_assistant=get_keywords(combined_message_assistant).split(',')
|
252 |
-
|
253 |
-
print(f"Initial keywords_users: {key_words_users}")
|
254 |
-
print(f"Initial keywords_assistant: {key_words_assistant}")
|
255 |
|
256 |
-
keywords_dict={}
|
257 |
-
for
|
258 |
-
if
|
259 |
-
keywords_dict[
|
260 |
else:
|
261 |
-
keywords_dict[
|
262 |
-
for
|
263 |
-
if
|
264 |
-
keywords_dict[
|
265 |
else:
|
266 |
-
keywords_dict[
|
267 |
|
268 |
-
#窗口内触发过的关键词权重下调为0.5
|
269 |
for keyword in list(keywords_dict.keys()):
|
270 |
if keyword in triggered_keywords:
|
271 |
if current_turn - triggered_keywords[keyword] < window_size:
|
272 |
keywords_dict[keyword] = weight_keywords_triggered
|
273 |
-
|
274 |
-
query_keywords = list(keywords_dict.keys())
|
275 |
-
print(keywords_dict)
|
276 |
-
|
277 |
-
distance,top_keywords_list,top_summary = fetch_response_from_db(keywords_dict,class_name)
|
278 |
|
279 |
-
|
280 |
-
|
281 |
-
if distance<distance_threshold:
|
282 |
-
ad =top_summary
|
283 |
|
|
|
|
|
284 |
messages = [{"role": "system", "content": system_message_with_ad}]
|
285 |
-
|
286 |
for val in history:
|
287 |
if val[0]:
|
288 |
messages.append({"role": "user", "content": val[0]})
|
289 |
-
if val[1]:
|
290 |
messages.append({"role": "assistant", "content": val[1]})
|
291 |
-
|
292 |
brands = ['腾讯', '百度', '京东', '华为', '小米', '苹果', '微软', '谷歌', '亚马逊']
|
293 |
brand = random.choice(brands)
|
294 |
messages.append({"role": "user", "content": f"{message} <sep>{brand}的 <sep> {ad}"})
|
295 |
|
296 |
-
#更新触发词
|
297 |
for keyword in query_keywords:
|
298 |
-
if any(
|
299 |
-
ad_keyword in keyword for ad_keyword in top_keywords_list
|
300 |
-
):
|
301 |
triggered_keywords[keyword] = current_turn
|
302 |
-
|
303 |
else:
|
304 |
messages = [{"role": "system", "content": system_message_without_ad}]
|
305 |
-
|
306 |
for val in history:
|
307 |
if val[0]:
|
308 |
messages.append({"role": "user", "content": val[0]})
|
309 |
if val[1]:
|
310 |
messages.append({"role": "assistant", "content": val[1]})
|
311 |
-
|
312 |
messages.append({"role": "user", "content": message})
|
313 |
|
314 |
-
|
315 |
response = client.chat.completions.create(
|
316 |
model="gpt-3.5-turbo",
|
317 |
messages=messages,
|
@@ -319,42 +551,22 @@ def respond(
|
|
319 |
temperature=temperature,
|
320 |
top_p=top_p,
|
321 |
)
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
# def chat_interface(message, history, max_tokens, temperature, top_p, window_size, distance_threshold):
|
327 |
-
# global triggered_keywords
|
328 |
-
# response, triggered_keywords = respond(
|
329 |
-
# message,
|
330 |
-
# history,
|
331 |
-
# max_tokens,
|
332 |
-
# temperature,
|
333 |
-
# top_p,
|
334 |
-
# window_size,
|
335 |
-
# distance_threshold,
|
336 |
-
# triggered_keywords
|
337 |
-
# )
|
338 |
-
# return response, history + [(message, response)]
|
339 |
|
340 |
demo = gr.ChatInterface(
|
341 |
wrapper,
|
342 |
additional_inputs=[
|
343 |
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
344 |
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
|
345 |
-
gr.Slider(
|
346 |
-
minimum=0.1,
|
347 |
-
maximum=1.0,
|
348 |
-
value=0.95,
|
349 |
-
step=0.05,
|
350 |
-
label="Top-p (nucleus sampling)",
|
351 |
-
),
|
352 |
gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Window size"),
|
353 |
gr.Slider(minimum=0.01, maximum=0.20, value=0.08, step=0.01, label="Distance threshold"),
|
354 |
gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Weight of keywords from users"),
|
355 |
gr.Slider(minimum=0, maximum=2, value=0.5, step=0.5, label="Weight of triggered keywords"),
|
356 |
gr.Textbox(label="api_key"),
|
357 |
-
|
358 |
],
|
359 |
)
|
360 |
|
|
|
1 |
+
# import gradio as gr
|
2 |
+
# from huggingface_hub import InferenceClient
|
3 |
+
# import json
|
4 |
+
# import random
|
5 |
+
# import re
|
6 |
+
# from load_data import load_data
|
7 |
+
# from openai import OpenAI
|
8 |
+
# from transformers import AutoTokenizer, AutoModel
|
9 |
+
# import weaviate
|
10 |
+
# import os
|
11 |
+
# import subprocess
|
12 |
+
# import torch
|
13 |
+
# from tqdm import tqdm
|
14 |
+
# import numpy as np
|
15 |
+
|
16 |
+
# # 设置 Matplotlib 的缓存目录
|
17 |
+
# os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
|
18 |
+
# # 设置 Hugging Face Transformers 的缓存目录
|
19 |
+
# os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'
|
20 |
+
# # 确保这些目录存在
|
21 |
+
# os.makedirs(os.environ['MPLCONFIGDIR'], exist_ok=True)
|
22 |
+
# os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
|
23 |
+
|
24 |
+
# auth_config = weaviate.AuthApiKey(api_key="Y7c8DRmcxZ4nP5IJLwkznIsK84l6EdwfXwcH")
|
25 |
+
|
26 |
+
# URL = "https://39nlafviqvard82k6y8btq.c0.asia-southeast1.gcp.weaviate.cloud"
|
27 |
+
|
28 |
+
# # Connect to a WCS instance
|
29 |
+
# db_client = weaviate.Client(
|
30 |
+
# url=URL,
|
31 |
+
# auth_client_secret=auth_config
|
32 |
+
# )
|
33 |
+
|
34 |
+
|
35 |
+
# class_name="ad_DB02"
|
36 |
+
|
37 |
+
# device = torch.device(device='cuda' if torch.cuda.is_available() else 'cpu')
|
38 |
+
# tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
|
39 |
+
# model = AutoModel.from_pretrained("bert-base-chinese")
|
40 |
+
|
41 |
+
|
42 |
+
# global_api_key = None
|
43 |
+
# client = None
|
44 |
+
|
45 |
+
# def initialize_clients(api_key):
|
46 |
+
# global client
|
47 |
+
# client = OpenAI(api_key=api_key)
|
48 |
+
|
49 |
+
# def get_keywords(message):
|
50 |
+
# system_message = """
|
51 |
+
# # 角色
|
52 |
+
# 你是一个关键词提取机器人
|
53 |
+
# # 指令
|
54 |
+
# 你的目标是从用户的输入中提取关键词,这些关键词应该尽可能是购买意图相关的。关键词中应该尽可能注意那些名词和形容词
|
55 |
+
# # 输出格式
|
56 |
+
# 你应该直接输出关键词,关键词之间用空格分隔。例如:苹果 电脑 裤子 蓝色 裙
|
57 |
+
# # 注意:如果输入文本过短可以重复输出关键词,例如对输入“你好”可以输出:你好 你好 你好 你好 你好
|
58 |
+
# """
|
59 |
+
|
60 |
+
# messages = [{"role": "system", "content": system_message}]
|
61 |
+
# messages.append({"role": "user", "content": f"从下面的文本中给我提取五个关键词,只输出这五个关键词,以空格分隔{message}"})
|
62 |
+
|
63 |
+
# response = client.chat.completions.create(
|
64 |
+
# model="gpt-3.5-turbo",
|
65 |
+
# messages=messages,
|
66 |
+
# max_tokens=100,
|
67 |
+
# temperature=0.7,
|
68 |
+
# top_p=0.9,
|
69 |
+
# )
|
70 |
+
|
71 |
+
# keywords = response.choices[0].message.content.split(' ')
|
72 |
+
# return ','.join(keywords)
|
73 |
+
|
74 |
+
|
75 |
+
# #字符串匹配模块
|
76 |
+
# def keyword_match(query_keywords_dict, ad_keywords_lists, triggered_keywords, current_turn, window_size,distance_threshold):
|
77 |
+
# distance = 0
|
78 |
+
# most_matching_list = None
|
79 |
+
# index = 0
|
80 |
+
|
81 |
+
# # query_keywords = query_keywords.split(',')
|
82 |
+
# # query_keywords = [keyword for keyword in query_keywords if keyword]
|
83 |
+
|
84 |
+
# #匹配模块
|
85 |
+
# query_keywords= list(query_keywords_dict.keys())
|
86 |
+
|
87 |
+
# for i, lst in enumerate(ad_keywords_lists):
|
88 |
+
# lst = lst.split(',')
|
89 |
+
# matches = sum(
|
90 |
+
# any(
|
91 |
+
# ad_keyword in keyword and
|
92 |
+
# (
|
93 |
+
# keyword not in triggered_keywords or
|
94 |
+
# triggered_keywords.get(keyword) is None or
|
95 |
+
# current_turn - triggered_keywords.get(keyword, 0) > window_size
|
96 |
+
# ) * query_keywords_dict.get(keyword, 1) #计数乘以权重
|
97 |
+
# for keyword in query_keywords
|
98 |
+
# )
|
99 |
+
# for ad_keyword in lst
|
100 |
+
# )
|
101 |
+
# if matches > distance:
|
102 |
+
# distance = matches
|
103 |
+
# most_matching_list = lst
|
104 |
+
# index = i
|
105 |
+
|
106 |
+
# #更新对distance 有贡献的关键词
|
107 |
+
# if distance >= distance_threshold:
|
108 |
+
# for keyword in query_keywords:
|
109 |
+
# if any(
|
110 |
+
# ad_keyword in keyword for ad_keyword in most_matching_list
|
111 |
+
# ):
|
112 |
+
# triggered_keywords[keyword] = current_turn
|
113 |
+
|
114 |
+
# return distance, index
|
115 |
+
|
116 |
+
|
117 |
+
# def encode_list_to_avg(keywords_list_list, model, tokenizer, device):
|
118 |
+
# if torch.cuda.is_available():
|
119 |
+
# print('Using GPU')
|
120 |
+
# print(device)
|
121 |
+
# else:
|
122 |
+
# print('Using CPU')
|
123 |
+
# print(device)
|
124 |
+
|
125 |
+
# avg_embeddings = []
|
126 |
+
# for keywords in tqdm(keywords_list_list):
|
127 |
+
# keywords_lst=[]
|
128 |
+
# # keywords.split(',')
|
129 |
+
# for keyword in keywords:
|
130 |
+
# inputs = tokenizer(keyword, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
131 |
+
# inputs.to(device)
|
132 |
+
# with torch.no_grad():
|
133 |
+
# outputs = model(**inputs)
|
134 |
+
# embeddings = outputs.last_hidden_state.mean(dim=1)
|
135 |
+
# keywords_lst.append(embeddings)
|
136 |
+
# avg_embedding = sum(keywords_lst) / len(keywords_lst)
|
137 |
+
# avg_embeddings.append(avg_embedding)
|
138 |
+
|
139 |
+
# return avg_embeddings
|
140 |
+
|
141 |
+
|
142 |
+
# def encode_to_avg(keywords_dict, model, tokenizer, device):
|
143 |
+
# if torch.cuda.is_available():
|
144 |
+
# print('Using GPU')
|
145 |
+
# print(device)
|
146 |
+
# else:
|
147 |
+
# print('Using CPU')
|
148 |
+
# print(device)
|
149 |
+
|
150 |
+
|
151 |
+
# keyword_embeddings=[]
|
152 |
+
# for keyword, weight in keywords_dict.items():
|
153 |
+
# inputs = tokenizer(keyword, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
154 |
+
# inputs.to(device)
|
155 |
+
# with torch.no_grad():
|
156 |
+
# outputs = model(**inputs)
|
157 |
+
# embedding = outputs.last_hidden_state.mean(dim=1)
|
158 |
+
|
159 |
+
# keyword_embedding=embedding * weight
|
160 |
+
|
161 |
+
# keyword_embeddings.append(keyword_embedding * weight)
|
162 |
+
|
163 |
+
# avg_embedding = sum(keyword_embeddings) / sum(keywords_dict.values())
|
164 |
+
|
165 |
+
# return avg_embedding.tolist()
|
166 |
+
|
167 |
+
|
168 |
+
# def fetch_response_from_db(query_keywords_dict,class_name):
|
169 |
+
|
170 |
+
|
171 |
+
# avg_vec=np.array(encode_to_avg(query_keywords_dict, model, tokenizer, device))
|
172 |
+
# nearVector = {
|
173 |
+
# 'vector': avg_vec
|
174 |
+
# }
|
175 |
+
|
176 |
+
# response = (
|
177 |
+
# db_client.query
|
178 |
+
# .get(class_name, ['keywords', 'summary'])
|
179 |
+
# .with_near_vector(nearVector)
|
180 |
+
# .with_limit(1)
|
181 |
+
# .with_additional(['distance'])
|
182 |
+
# .do()
|
183 |
+
# )
|
184 |
+
|
185 |
+
# print(response)
|
186 |
+
# class_name=class_name[0].upper()+class_name[1:]
|
187 |
+
|
188 |
+
# if class_name in response['data']['Get']:
|
189 |
+
# results = response['data']['Get'][class_name]
|
190 |
+
# return results[0]['_additional']['distance'],results[0]['summary'], results[0]['keywords']
|
191 |
+
# else:
|
192 |
+
# print(f"Class name {class_name} not found in response")
|
193 |
+
# return None
|
194 |
+
|
195 |
+
|
196 |
+
|
197 |
+
# def wrapper(message, history, max_tokens, temperature, top_p, window_size, distance_threshold, weight_keywords_users, weight_keywords_triggered, api_key, state):
|
198 |
+
# initialize_clients(api_key)
|
199 |
+
# return respond(message, history, max_tokens, temperature, top_p, window_size, distance_threshold, weight_keywords_users, weight_keywords_triggered, state)
|
200 |
+
|
201 |
+
|
202 |
+
# #触发词及触发回合字典
|
203 |
+
|
204 |
+
|
205 |
+
# def respond(
|
206 |
+
# message,
|
207 |
+
# history: list[tuple[str, str]],
|
208 |
+
# max_tokens,
|
209 |
+
# temperature,
|
210 |
+
# top_p,
|
211 |
+
# window_size,
|
212 |
+
# distance_threshold,
|
213 |
+
# weight_keywords_users,
|
214 |
+
# weight_keywords_triggered,
|
215 |
+
# triggered_keywords
|
216 |
+
# ):
|
217 |
+
|
218 |
+
# triggered_keywords=triggered_keywords or {}
|
219 |
+
# system_message_with_ad = """
|
220 |
+
# # 角色
|
221 |
+
# 你是一个热情的聊天机器人
|
222 |
+
# # 指令
|
223 |
+
# 你的目标是以一种微妙的方式将给定的广告内容嵌入对话中。插入广告时应尽量自然,避免突兀。优先正常回答用户问题,然后逐步以推荐方式引入广告内容。
|
224 |
+
# 注意要在商品的描述前面加上是来自哪个品牌的广告。
|
225 |
+
# 注意在推荐中不要脑补用户的身份,只是进行简单推荐。
|
226 |
+
# 注意要热情但是语气只要适度热情
|
227 |
+
# # 输入格式
|
228 |
+
# 用户查询后跟随广告品牌,用<sep>分隔,广告品牌后跟随广告描述,再用<sep>分隔。
|
229 |
+
# 例如:我想买一条阔腿裤 <sep> 腾讯 <sep> 宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。毕竟好穿时尚,谁都能穿出腿长2米的效果宽松的裤腿,当然是遮肉小能手啊。上身随性自然不拘束,面料亲肤舒适贴身体验感棒棒哒。系带部分增加设计看点,还让单品的设计感更强。腿部线条若隐若现的,性感撩人。颜色敲温柔的,与裤子本身所呈现的风格有点反差萌。
|
230 |
+
# 注意: 当没有<sep>时,正常回复用户,不插入广告。
|
231 |
+
# # 输出格式
|
232 |
+
# 始终使用中文,只输出聊天内容,不输出任何自我分析的信息
|
233 |
+
# """
|
234 |
+
|
235 |
+
# system_message_without_ad = """
|
236 |
+
# 你是一个热情的聊天机器人
|
237 |
+
# """
|
238 |
+
# print(f"triggered_keywords{triggered_keywords}")
|
239 |
+
# # 更新当前轮次
|
240 |
+
# current_turn = len(history) + 1
|
241 |
+
# print(f"current_turn: {current_turn}")
|
242 |
+
# # 检查历史记录的长度
|
243 |
+
# if len(history) >= window_size:
|
244 |
+
# combined_message_user = " ".join([h[0] for h in history[-window_size:] if h[0]] + [message])
|
245 |
+
# combined_message_assistant=" ".join(h[1] for h in history[-window_size:] if h[1])
|
246 |
+
# else:
|
247 |
+
# combined_message_user = message
|
248 |
+
# combined_message_assistant = ""
|
249 |
+
|
250 |
+
# key_words_users=get_keywords(combined_message_user).split(',')
|
251 |
+
# key_words_assistant=get_keywords(combined_message_assistant).split(',')
|
252 |
+
|
253 |
+
# print(f"Initial keywords_users: {key_words_users}")
|
254 |
+
# print(f"Initial keywords_assistant: {key_words_assistant}")
|
255 |
+
|
256 |
+
# keywords_dict={}
|
257 |
+
# for keywords in key_words_users:
|
258 |
+
# if keywords in keywords_dict:
|
259 |
+
# keywords_dict[keywords]+=weight_keywords_users
|
260 |
+
# else:
|
261 |
+
# keywords_dict[keywords]=weight_keywords_users
|
262 |
+
# for keywords in key_words_assistant:
|
263 |
+
# if keywords in keywords_dict:
|
264 |
+
# keywords_dict[keywords]+=1
|
265 |
+
# else:
|
266 |
+
# keywords_dict[keywords]=1
|
267 |
+
|
268 |
+
# #窗口内触发过的关键词权重下调为0.5
|
269 |
+
# for keyword in list(keywords_dict.keys()):
|
270 |
+
# if keyword in triggered_keywords:
|
271 |
+
# if current_turn - triggered_keywords[keyword] < window_size:
|
272 |
+
# keywords_dict[keyword] = weight_keywords_triggered
|
273 |
+
|
274 |
+
# query_keywords = list(keywords_dict.keys())
|
275 |
+
# print(keywords_dict)
|
276 |
+
|
277 |
+
# distance,top_keywords_list,top_summary = fetch_response_from_db(keywords_dict,class_name)
|
278 |
+
|
279 |
+
# print(f"distance: {distance}")
|
280 |
+
|
281 |
+
# if distance<distance_threshold:
|
282 |
+
# ad =top_summary
|
283 |
+
|
284 |
+
# messages = [{"role": "system", "content": system_message_with_ad}]
|
285 |
+
|
286 |
+
# for val in history:
|
287 |
+
# if val[0]:
|
288 |
+
# messages.append({"role": "user", "content": val[0]})
|
289 |
+
# if val[1]:
|
290 |
+
# messages.append({"role": "assistant", "content": val[1]})
|
291 |
+
|
292 |
+
# brands = ['腾讯', '百度', '京东', '华为', '小米', '苹果', '微软', '谷歌', '亚马逊']
|
293 |
+
# brand = random.choice(brands)
|
294 |
+
# messages.append({"role": "user", "content": f"{message} <sep>{brand}的 <sep> {ad}"})
|
295 |
+
|
296 |
+
# #更新触发词
|
297 |
+
# for keyword in query_keywords:
|
298 |
+
# if any(
|
299 |
+
# ad_keyword in keyword for ad_keyword in top_keywords_list
|
300 |
+
# ):
|
301 |
+
# triggered_keywords[keyword] = current_turn
|
302 |
+
|
303 |
+
# else:
|
304 |
+
# messages = [{"role": "system", "content": system_message_without_ad}]
|
305 |
+
|
306 |
+
# for val in history:
|
307 |
+
# if val[0]:
|
308 |
+
# messages.append({"role": "user", "content": val[0]})
|
309 |
+
# if val[1]:
|
310 |
+
# messages.append({"role": "assistant", "content": val[1]})
|
311 |
+
|
312 |
+
# messages.append({"role": "user", "content": message})
|
313 |
+
|
314 |
+
|
315 |
+
# response = client.chat.completions.create(
|
316 |
+
# model="gpt-3.5-turbo",
|
317 |
+
# messages=messages,
|
318 |
+
# max_tokens=max_tokens,
|
319 |
+
# temperature=temperature,
|
320 |
+
# top_p=top_p,
|
321 |
+
# )
|
322 |
+
|
323 |
+
# return response.choices[0].message.content , triggered_keywords
|
324 |
+
|
325 |
+
|
326 |
+
# # def chat_interface(message, history, max_tokens, temperature, top_p, window_size, distance_threshold):
|
327 |
+
# # global triggered_keywords
|
328 |
+
# # response, triggered_keywords = respond(
|
329 |
+
# # message,
|
330 |
+
# # history,
|
331 |
+
# # max_tokens,
|
332 |
+
# # temperature,
|
333 |
+
# # top_p,
|
334 |
+
# # window_size,
|
335 |
+
# # distance_threshold,
|
336 |
+
# # triggered_keywords
|
337 |
+
# # )
|
338 |
+
# # return response, history + [(message, response)]
|
339 |
+
|
340 |
+
# demo = gr.ChatInterface(
|
341 |
+
# wrapper,
|
342 |
+
# additional_inputs=[
|
343 |
+
# gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
344 |
+
# gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
|
345 |
+
# gr.Slider(
|
346 |
+
# minimum=0.1,
|
347 |
+
# maximum=1.0,
|
348 |
+
# value=0.95,
|
349 |
+
# step=0.05,
|
350 |
+
# label="Top-p (nucleus sampling)",
|
351 |
+
# ),
|
352 |
+
# gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Window size"),
|
353 |
+
# gr.Slider(minimum=0.01, maximum=0.20, value=0.08, step=0.01, label="Distance threshold"),
|
354 |
+
# gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Weight of keywords from users"),
|
355 |
+
# gr.Slider(minimum=0, maximum=2, value=0.5, step=0.5, label="Weight of triggered keywords"),
|
356 |
+
# gr.Textbox(label="api_key"),
|
357 |
+
# 'state'
|
358 |
+
# ],
|
359 |
+
# )
|
360 |
+
|
361 |
+
# if __name__ == "__main__":
|
362 |
+
# demo.launch(share=True)
|
363 |
+
|
364 |
import gradio as gr
|
365 |
from huggingface_hub import InferenceClient
|
366 |
import json
|
|
|
370 |
from openai import OpenAI
|
371 |
from transformers import AutoTokenizer, AutoModel
|
372 |
import weaviate
|
373 |
+
import os
|
374 |
+
import subprocess
|
375 |
import torch
|
376 |
from tqdm import tqdm
|
377 |
import numpy as np
|
378 |
|
379 |
+
# 设置 Matplotlib 和 Hugging Face Transformers 的缓存目录
|
380 |
+
os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
|
381 |
+
os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'
|
382 |
+
os.makedirs(os.environ['MPLCONFIGDIR'], exist_ok=True)
|
383 |
+
os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
|
|
|
|
|
384 |
|
385 |
auth_config = weaviate.AuthApiKey(api_key="Y7c8DRmcxZ4nP5IJLwkznIsK84l6EdwfXwcH")
|
|
|
386 |
URL = "https://39nlafviqvard82k6y8btq.c0.asia-southeast1.gcp.weaviate.cloud"
|
387 |
|
388 |
# Connect to a WCS instance
|
389 |
db_client = weaviate.Client(
|
390 |
+
url=URL,
|
391 |
+
auth_client_secret=auth_config
|
392 |
)
|
393 |
|
394 |
+
class_name = "ad_DB02"
|
|
|
|
|
395 |
device = torch.device(device='cuda' if torch.cuda.is_available() else 'cpu')
|
396 |
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
|
397 |
model = AutoModel.from_pretrained("bert-base-chinese")
|
398 |
|
|
|
399 |
global_api_key = None
|
400 |
client = None
|
401 |
|
|
|
413 |
你应该直接输出关键词,关键词之间用空格分隔。例如:苹果 电脑 裤子 蓝色 裙
|
414 |
# 注意:如果输入文本过短可以重复输出关键词,例如对输入“你好”可以输出:你好 你好 你好 你好 你好
|
415 |
"""
|
|
|
416 |
messages = [{"role": "system", "content": system_message}]
|
417 |
messages.append({"role": "user", "content": f"从下面的文本中给我提取五个关键词,只输出这五个关键词,以空格分隔{message}"})
|
418 |
|
|
|
427 |
keywords = response.choices[0].message.content.split(' ')
|
428 |
return ','.join(keywords)
|
429 |
|
430 |
+
def fetch_response_from_db(query_keywords_dict, class_name):
|
431 |
+
avg_vec = np.array(encode_to_avg(query_keywords_dict, model, tokenizer, device))
|
432 |
+
nearVector = {'vector': avg_vec}
|
433 |
|
434 |
+
response = (
|
435 |
+
db_client.query
|
436 |
+
.get(class_name, ['keywords', 'summary'])
|
437 |
+
.with_near_vector(nearVector)
|
438 |
+
.with_limit(1)
|
439 |
+
.with_additional(['distance'])
|
440 |
+
.do()
|
441 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
442 |
|
443 |
+
class_name = class_name[0].upper() + class_name[1:]
|
444 |
|
445 |
+
if class_name in response['data']['Get']:
|
446 |
+
results = response['data']['Get'][class_name]
|
447 |
+
return results[0]['_additional']['distance'], results[0]['summary'], results[0]['keywords']
|
|
|
448 |
else:
|
449 |
+
print(f"Class name {class_name} not found in response")
|
450 |
+
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
451 |
|
452 |
def encode_to_avg(keywords_dict, model, tokenizer, device):
|
453 |
if torch.cuda.is_available():
|
|
|
457 |
print('Using CPU')
|
458 |
print(device)
|
459 |
|
460 |
+
keyword_embeddings = []
|
|
|
461 |
for keyword, weight in keywords_dict.items():
|
462 |
inputs = tokenizer(keyword, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
463 |
inputs.to(device)
|
|
|
465 |
outputs = model(**inputs)
|
466 |
embedding = outputs.last_hidden_state.mean(dim=1)
|
467 |
|
468 |
+
keyword_embedding = embedding * weight
|
469 |
+
keyword_embeddings.append(keyword_embedding)
|
|
|
470 |
|
471 |
avg_embedding = sum(keyword_embeddings) / sum(keywords_dict.values())
|
|
|
472 |
return avg_embedding.tolist()
|
473 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
474 |
def wrapper(message, history, max_tokens, temperature, top_p, window_size, distance_threshold, weight_keywords_users, weight_keywords_triggered, api_key, state):
|
475 |
initialize_clients(api_key)
|
476 |
return respond(message, history, max_tokens, temperature, top_p, window_size, distance_threshold, weight_keywords_users, weight_keywords_triggered, state)
|
477 |
|
|
|
|
|
|
|
|
|
478 |
def respond(
|
479 |
message,
|
480 |
+
history,
|
481 |
max_tokens,
|
482 |
temperature,
|
483 |
top_p,
|
|
|
485 |
distance_threshold,
|
486 |
weight_keywords_users,
|
487 |
weight_keywords_triggered,
|
488 |
+
state
|
489 |
):
|
490 |
+
triggered_keywords = state.get('triggered_keywords', {})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
491 |
current_turn = len(history) + 1
|
492 |
+
|
|
|
493 |
if len(history) >= window_size:
|
494 |
combined_message_user = " ".join([h[0] for h in history[-window_size:] if h[0]] + [message])
|
495 |
+
combined_message_assistant = " ".join(h[1] for h in history[-window_size:] if h[1])
|
496 |
else:
|
497 |
combined_message_user = message
|
498 |
combined_message_assistant = ""
|
499 |
|
500 |
+
key_words_users = get_keywords(combined_message_user).split(',')
|
501 |
+
key_words_assistant = get_keywords(combined_message_assistant).split(',')
|
|
|
|
|
|
|
502 |
|
503 |
+
keywords_dict = {}
|
504 |
+
for keyword in key_words_users:
|
505 |
+
if keyword in keywords_dict:
|
506 |
+
keywords_dict[keyword] += weight_keywords_users
|
507 |
else:
|
508 |
+
keywords_dict[keyword] = weight_keywords_users
|
509 |
+
for keyword in key_words_assistant:
|
510 |
+
if keyword in keywords_dict:
|
511 |
+
keywords_dict[keyword] += 1
|
512 |
else:
|
513 |
+
keywords_dict[keyword] = 1
|
514 |
|
|
|
515 |
for keyword in list(keywords_dict.keys()):
|
516 |
if keyword in triggered_keywords:
|
517 |
if current_turn - triggered_keywords[keyword] < window_size:
|
518 |
keywords_dict[keyword] = weight_keywords_triggered
|
|
|
|
|
|
|
|
|
|
|
519 |
|
520 |
+
query_keywords = list(keywords_dict.keys())
|
521 |
+
distance, top_keywords_list, top_summary = fetch_response_from_db(keywords_dict, class_name)
|
|
|
|
|
522 |
|
523 |
+
if distance < distance_threshold:
|
524 |
+
ad = top_summary
|
525 |
messages = [{"role": "system", "content": system_message_with_ad}]
|
|
|
526 |
for val in history:
|
527 |
if val[0]:
|
528 |
messages.append({"role": "user", "content": val[0]})
|
529 |
+
if val[1]:
|
530 |
messages.append({"role": "assistant", "content": val[1]})
|
|
|
531 |
brands = ['腾讯', '百度', '京东', '华为', '小米', '苹果', '微软', '谷歌', '亚马逊']
|
532 |
brand = random.choice(brands)
|
533 |
messages.append({"role": "user", "content": f"{message} <sep>{brand}的 <sep> {ad}"})
|
534 |
|
|
|
535 |
for keyword in query_keywords:
|
536 |
+
if any(ad_keyword in keyword for ad_keyword in top_keywords_list):
|
|
|
|
|
537 |
triggered_keywords[keyword] = current_turn
|
|
|
538 |
else:
|
539 |
messages = [{"role": "system", "content": system_message_without_ad}]
|
|
|
540 |
for val in history:
|
541 |
if val[0]:
|
542 |
messages.append({"role": "user", "content": val[0]})
|
543 |
if val[1]:
|
544 |
messages.append({"role": "assistant", "content": val[1]})
|
|
|
545 |
messages.append({"role": "user", "content": message})
|
546 |
|
|
|
547 |
response = client.chat.completions.create(
|
548 |
model="gpt-3.5-turbo",
|
549 |
messages=messages,
|
|
|
551 |
temperature=temperature,
|
552 |
top_p=top_p,
|
553 |
)
|
554 |
+
|
555 |
+
state['triggered_keywords'] = triggered_keywords
|
556 |
+
return response.choices[0].message.content, state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
557 |
|
558 |
demo = gr.ChatInterface(
|
559 |
wrapper,
|
560 |
additional_inputs=[
|
561 |
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
562 |
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
|
563 |
+
gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
|
|
|
|
|
|
|
|
|
|
|
|
|
564 |
gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Window size"),
|
565 |
gr.Slider(minimum=0.01, maximum=0.20, value=0.08, step=0.01, label="Distance threshold"),
|
566 |
gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Weight of keywords from users"),
|
567 |
gr.Slider(minimum=0, maximum=2, value=0.5, step=0.5, label="Weight of triggered keywords"),
|
568 |
gr.Textbox(label="api_key"),
|
569 |
+
gr.State(label="state")
|
570 |
],
|
571 |
)
|
572 |
|