IdlecloudX committed on
Commit
428ecd5
·
verified ·
1 Parent(s): 356be23

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +235 -177
app.py CHANGED
@@ -1,13 +1,13 @@
1
  import os
2
- import json
3
  import gradio as gr
4
  import huggingface_hub
5
  import numpy as np
6
  import onnxruntime as rt
7
  import pandas as pd
8
  from PIL import Image
9
- from huggingface_hub import whoami
10
 
 
11
  from translator import translate_texts
12
 
13
  # ------------------------------------------------------------------
@@ -17,8 +17,14 @@ MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
17
  MODEL_FILENAME = "model.onnx"
18
  LABEL_FILENAME = "selected_tags.csv"
19
 
20
- HF_TOKEN = os.environ.get("HF_TOKEN")
21
- ACCESS_PASSWORD = os.environ.get("ACCESS_PASSWORD")
 
 
 
 
 
 
22
 
23
  # ------------------------------------------------------------------
24
  # Tagger 类 (全局实例化)
@@ -58,40 +64,53 @@ class Tagger:
58
 
59
  # ------------------------- preprocess -------------------------
60
  def _preprocess(self, img: Image.Image) -> np.ndarray:
61
- if img is None: raise ValueError("输入图像不能为空")
62
- if img.mode != "RGB": img = img.convert("RGB")
 
 
63
  size = max(img.size)
64
  canvas = Image.new("RGB", (size, size), (255, 255, 255))
65
  canvas.paste(img, ((size - img.width) // 2, (size - img.height) // 2))
66
  if size != self.input_size:
67
  canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC)
68
- return np.array(canvas)[:, :, ::-1].astype(np.float32)
69
 
70
  # --------------------------- predict --------------------------
71
  def predict(self, img: Image.Image, gen_th: float = 0.35, char_th: float = 0.85):
72
- if self.model is None: raise RuntimeError("模型未成功加载,无法进行预测。")
 
73
  inp_name = self.model.get_inputs()[0].name
74
  outputs = self.model.run(None, {inp_name: self._preprocess(img)[None, ...]})[0][0]
75
 
76
  res = {"ratings": {}, "general": {}, "characters": {}}
77
  tag_categories_for_translation = {"ratings": [], "general": [], "characters": []}
78
 
79
- for cat_key, cat_indices in self.categories.items():
80
- sub_res = {}
81
- if cat_key == "rating":
82
- for idx in cat_indices:
83
- tag_name = self.tag_names[idx].replace("_", " ")
84
- sub_res[tag_name] = float(outputs[idx])
85
- else:
86
- threshold = char_th if cat_key == "character" else gen_th
87
- for idx in cat_indices:
88
- if outputs[idx] > threshold:
89
- tag_name = self.tag_names[idx].replace("_", " ")
90
- sub_res[tag_name] = float(outputs[idx])
91
-
92
- res_key = "characters" if cat_key == "character" else cat_key
93
- res[res_key] = dict(sorted(sub_res.items(), key=lambda kv: kv[1], reverse=True))
94
- tag_categories_for_translation[res_key] = list(res[res_key].keys())
 
 
 
 
 
 
 
 
 
 
95
 
96
  return res, tag_categories_for_translation
97
 
@@ -100,7 +119,7 @@ try:
100
  tagger_instance = Tagger()
101
  except RuntimeError as e:
102
  print(f"应用启动时Tagger初始化失败: {e}")
103
- tagger_instance = None
104
 
105
  # ------------------------------------------------------------------
106
  # Gradio UI
@@ -123,7 +142,8 @@ function copyToClipboard(text) {
123
  }
124
  navigator.clipboard.writeText(text).then(() => {
125
  const feedback = document.createElement('div');
126
- let displayText = String(text).substring(0, 30) + (String(text).length > 30 ? '...' : '');
 
127
  feedback.textContent = '已复制: ' + displayText;
128
  Object.assign(feedback.style, {
129
  position: 'fixed', bottom: '20px', left: '50%', transform: 'translateX(-50%)',
@@ -136,7 +156,7 @@ function copyToClipboard(text) {
136
  setTimeout(() => { if (document.body.contains(feedback)) document.body.removeChild(feedback); }, 500);
137
  }, 1500);
138
  }).catch(err => {
139
- console.error('Failed to copy tag. Error:', err, 'Attempted to copy text:', text);
140
  });
141
  }
142
  """
@@ -144,203 +164,241 @@ function copyToClipboard(text) {
144
  with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=custom_css, js=_js_functions) as demo:
145
  gr.Markdown("# 🖼️ AI 图像标签分析器")
146
  gr.Markdown("上传图片自动识别标签,支持中英文显示和一键复制。[NovelAI在线绘画](https://nai.idlecloud.cc/)")
147
-
148
- # 统一的状态和登录/登出控制区域
149
- with gr.Row():
150
- user_status_html = gr.HTML("<p>ℹ️ 正在检查登录状态...</p>")
151
- with gr.Row():
152
- login_button = gr.LoginButton(value="🤗 通过 Hugging Face 登录", visible=True)
153
- logout_button = gr.LogoutButton(value="退出登录", visible=False)
154
 
155
  state_res = gr.State({})
156
  state_translations_dict = gr.State({})
 
 
157
 
158
- with gr.Row(visible=False) as main_interface:
159
  with gr.Column(scale=1):
160
  img_in = gr.Image(type="pil", label="上传图片", height=300)
 
161
  btn = gr.Button("🚀 开始分析", variant="primary", elem_classes=["btn-analyze-container"])
162
 
163
  with gr.Accordion("⚙️ 高级设置", open=False):
164
- gen_slider = gr.Slider(0, 1, value=0.35, step=0.01, label="通用标签阈值")
165
- char_slider = gr.Slider(0, 1, value=0.85, step=0.01, label="角色标签阈值")
166
  show_tag_scores = gr.Checkbox(True, label="在列表中显示标签置信度")
167
 
168
- with gr.Accordion("🔑 翻译密钥设置", open=True):
169
- gr.Markdown("输入访问密码可使用空间配置的密钥,否则请提供您自己的密钥。")
170
- access_password_in = gr.Textbox(label="访问密码 (可选)", type="password", lines=1)
171
- tencent_id_in = gr.Textbox(label="腾讯云 Secret ID", lines=1)
172
- tencent_key_in = gr.Textbox(label="腾讯云 Secret Key", lines=1, type="password")
173
- baidu_json_in = gr.Textbox(label="百度翻译凭证 (JSON 格式)", lines=3, placeholder='[{"app_id": "...", "secret_key": "..."}]')
 
 
 
 
 
 
174
 
175
  with gr.Accordion("📊 标签汇总设置", open=True):
176
- sum_cats = gr.CheckboxGroup(["通用标签", "角色标签", "评分标签"], value=["通用标签", "角色标签"], label="汇总类别")
177
- sum_sep = gr.Dropdown(["逗号", "换行", "空格"], value="逗号", label="标签分隔符")
 
 
 
 
178
  sum_show_zh = gr.Checkbox(False, label="在汇总中显示中文翻译")
179
 
180
  processing_info = gr.Markdown("", visible=False)
181
 
182
  with gr.Column(scale=2):
183
  with gr.Tabs():
184
- with gr.TabItem("🏷️ 通用标签"): out_general = gr.HTML(label="General Tags")
185
- with gr.TabItem("👤 角色标签"): out_char = gr.HTML(label="Character Tags")
186
- with gr.TabItem(" 评分标签"): out_rating = gr.HTML(label="Rating Tags")
187
- gr.Markdown("### 标签汇总结果")
188
- out_summary = gr.Textbox(label="标签汇总", lines=5, show_copy_button=True)
 
 
189
 
190
- # ----------------- 辅助函数 -----------------
191
def get_token_from_request(request: gr.Request) -> str | None:
    """Return the bearer token from the request's ``authorization`` header.

    Returns ``None`` when the header is absent or does not start with
    ``"Bearer "``; otherwise returns the text between the first and second
    space of the header value.
    """
    header_value = request.headers.get("authorization")
    if not header_value or not header_value.startswith("Bearer "):
        return None
    # Field 1 of a space-split is the token right after the "Bearer" scheme.
    return header_value.split(" ", 2)[1]
196
-
197
def check_user_status(request: gr.Request):
    """Work out the page-load UI state from the caller's login token.

    Returns four updates, in order: status HTML, login button visibility,
    logout button visibility, main interface visibility. The main interface
    and logout button are shown only when ``whoami`` accepts the token.
    """

    def _ui_state(message, logged_in):
        # Login button is the inverse of the logout button / main interface.
        return (
            gr.update(value=message),
            gr.update(visible=not logged_in),
            gr.update(visible=logged_in),
            gr.update(visible=logged_in),
        )

    token = get_token_from_request(request)
    if not token:
        # Not logged in.
        return _ui_state("<p style='color:#d46b08;'>🚫 您需要登录才能使用此应用。</p>", False)

    try:
        user_info = whoami(token=token)
        display_name = user_info.get('fullname', user_info.get('name'))
        greeting = f"<p style='color:green;font-weight:bold;'>✅ 您好, {display_name}!欢迎使用。</p>"
        return _ui_state(greeting, True)
    except Exception as e:
        # Any failure while validating the token is treated as "invalid token".
        print(f"Token 无效或已过期: {e}")
        return _ui_state("<p style='color:red;'>🚫 登录令牌无效或已过期,请重新登录。</p>", False)
229
 
230
def format_tags_html(tags_dict, translations_list, show_scores):
    """Render tags as an HTML list whose English labels copy on click.

    Args:
        tags_dict: Mapping of tag text to numeric score, in display order.
        translations_list: Translations aligned by position with
            ``tags_dict``; may be shorter, empty or ``None``.
        show_scores: When true, append each score with three decimals.

    Returns:
        An HTML string, or a "no tags" paragraph when ``tags_dict`` is empty.
    """
    import json
    from html import escape

    if not tags_dict:
        return "<p>暂无标签</p>"

    translations_list = translations_list or []
    pieces = ['<div class="label-container">']
    for i, (tag, score) in enumerate(tags_dict.items()):
        # Tag text comes from the label file and translations from a remote
        # service, so neither may reach the markup raw. The onclick argument
        # is a JSON string literal (valid JS), then attribute-escaped.
        js_arg = escape(json.dumps(tag, ensure_ascii=False), quote=True)
        label = (
            f'<span class="tag-en" onclick="copyToClipboard({js_arg})">'
            f'{escape(tag)}</span>'
        )
        if i < len(translations_list) and translations_list[i]:
            label += f'<span class="tag-zh">({escape(str(translations_list[i]))})</span>'

        pieces.append(f'<div class="tag-item"><div>{label}</div>')
        if show_scores:
            pieces.append(f'<span class="tag-score">{score:.3f}</span>')
        pieces.append('</div>')
    pieces.append('</div>')
    return "".join(pieces)
243
-
244
def generate_summary_text_content(current_res, translations, sum_cats, sep_type, show_zh):
    """Build the plain-text tag summary for the selected categories.

    Args:
        current_res: Mapping of category key (``"general"``, ``"characters"``,
            ``"ratings"``) to a ``{tag: score}`` dict.
        translations: Mapping of category key to translations aligned by
            position with that category's tags; may be ``None``.
        sum_cats: Selected category labels as shown in the UI.
        sep_type: Separator label; unknown labels fall back to ``", "``.
        show_zh: When true, append ``(translation)`` to tags that have one.

    Returns:
        One line per selected non-empty category, or a hint message when
        there is nothing to summarise.
    """
    if not current_res:
        return "请先分析图像。"

    sep = {"逗号": ", ", "换行": "\n", "空格": " "}.get(sep_type, ", ")
    cat_map = {"通用标签": "general", "角色标签": "characters", "评分标签": "ratings"}
    translations = translations or {}

    parts = []
    for cat_name in sum_cats or []:
        cat_key = cat_map.get(cat_name)
        if not cat_key or not current_res.get(cat_key):
            continue
        trans = translations.get(cat_key) or []
        labels = []
        for i, en in enumerate(current_res[cat_key]):
            # Look the translation up by position; the earlier code formatted
            # an undefined name here and raised NameError.
            zh = trans[i] if i < len(trans) else None
            labels.append(f"{en}({zh})" if show_zh and zh else en)
        parts.append(sep.join(labels))

    return "\n".join(parts) if parts else "选定的类别中没有找到标签。"
255
-
256
- # ----------------- 主要处理回调 -----------------
257
def process_image_and_generate_outputs(
    img, g_th, c_th, s_scores,
    access_pwd, user_tencent_id, user_tencent_key, user_baidu_json,
    sum_cats, s_sep, s_zh_in_sum,
    request: gr.Request
):
    """Tag an image, translate the tags and stream UI updates.

    This is a generator. Each yielded tuple has eight values, in order:
    analyse button, status markdown, general HTML, character HTML, rating
    HTML, summary text, result state, translation state.

    Translation credentials come from the environment when ``access_pwd``
    equals ``ACCESS_PASSWORD``; otherwise the user-supplied ones are used.

    Raises:
        gr.Error: When the request has no bearer token, no image was given,
            the tagger is unavailable, or processing fails.
    """
    if get_token_from_request(request) is None:
        raise gr.Error("错误:您的登录会话已失效,请刷新页面后重试。")
    if img is None:
        raise gr.Error("请先上传图片。")
    if tagger_instance is None:
        raise gr.Error("分析器未成功初始化,请检查后台错误。")

    yield gr.update(interactive=False, value="🔄 处理中..."), gr.update(visible=True, value="🔄 正在分析..."), *["<p>分析中...</p>"]*3, "分析中...", {}, {}

    use_space_keys = bool(ACCESS_PASSWORD and access_pwd == ACCESS_PASSWORD)

    if use_space_keys:
        final_tencent_id = os.environ.get("TENCENT_SECRET_ID")
        final_tencent_key = os.environ.get("TENCENT_SECRET_KEY")
        baidu_json_str = os.environ.get("BAIDU_CREDENTIALS_JSON", "[]")
    else:
        final_tencent_id, final_tencent_key, baidu_json_str = (
            user_tencent_id, user_tencent_key, user_baidu_json
        )

    # Only a JSON list is accepted; anything else leaves the list empty.
    final_baidu_creds_list = []
    if baidu_json_str and baidu_json_str.strip():
        try:
            parsed_data = json.loads(baidu_json_str)
            if isinstance(parsed_data, list):
                final_baidu_creds_list = parsed_data
        except json.JSONDecodeError:
            print("提供的百度凭证JSON无效。")

    try:
        res, tag_cats_original = tagger_instance.predict(img, g_th, c_th)
        all_tags = [tag for cat in tag_cats_original.values() for tag in cat]

        translations_flat = translate_texts(
            all_tags,
            tencent_secret_id=final_tencent_id,
            tencent_secret_key=final_tencent_key,
            baidu_credentials_list=final_baidu_creds_list
        ) if all_tags else []

        # Cut the flat translation list back into per-category slices, in the
        # same order the tags were flattened above.
        translations, offset = {}, 0
        for cat_key, tags in tag_cats_original.items():
            translations[cat_key] = translations_flat[offset : offset + len(tags)]
            offset += len(tags)

        outputs_html = {k: format_tags_html(res.get(k, {}), translations.get(k, []), s_scores) for k in ["general", "characters", "ratings"]}
        summary = generate_summary_text_content(res, translations, sum_cats, s_sep, s_zh_in_sum)

        yield gr.update(interactive=True, value="🚀 开始分析"), gr.update(visible=True, value="✅ 分析完成! " + ("(使用空间密钥)" if use_space_keys else "(使用自定义密钥)")), outputs_html["general"], outputs_html["characters"], outputs_html["ratings"], summary, res, translations

    except Exception as e:
        import traceback
        traceback.print_exc()
        # Restore the UI before raising: the "processing" yield above disabled
        # the button, and raising alone would leave it disabled with the
        # panes stuck on the "analysing" placeholder.
        yield (
            gr.update(interactive=True, value="🚀 开始分析"),
            gr.update(visible=True, value=f"❌ 处理时发生错误: {e}"),
            "<p>处理出错</p>", "<p>处理出错</p>", "<p>处理出错</p>",
            "", {}, {},
        )
        raise gr.Error(f"处理时发生错误: {e}") from e
311
-
312
- # ----------------- 绑定事件 -----------------
313
- demo.load(
314
- fn=check_user_status,
315
- inputs=None,
316
- outputs=[user_status_html, login_button, logout_button, main_interface],
317
- queue=False
318
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
 
320
  btn.click(
321
  process_image_and_generate_outputs,
322
- inputs=[
323
- img_in, gen_slider, char_slider, show_tag_scores,
324
- access_password_in, tencent_id_in, tencent_key_in, baidu_json_in,
325
- sum_cats, sum_sep, sum_show_zh
326
- ],
327
- outputs=[
328
- btn, processing_info,
329
- out_general, out_char, out_rating,
330
- out_summary,
331
- state_res, state_translations_dict
332
- ],
333
  )
334
 
335
- summary_controls = [sum_cats, sum_sep, sum_show_zh]
336
  for ctrl in summary_controls:
337
  ctrl.change(
338
- fn=lambda r, t, c, s, z: generate_summary_text_content(r, t, c, s, z),
339
- inputs=[state_res, state_translations_dict] + summary_controls,
340
- outputs=[out_summary],
341
  )
342
-
343
  if __name__ == "__main__":
344
  if tagger_instance is None:
345
- print("CRITICAL: Tagger 未能初始化,应用功能将受限。")
346
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import os
 
2
  import gradio as gr
3
  import huggingface_hub
4
  import numpy as np
5
  import onnxruntime as rt
6
  import pandas as pd
7
  from PIL import Image
8
+ from huggingface_hub import login
9
 
10
+ # 导入修改后的翻译函数
11
  from translator import translate_texts
12
 
13
  # ------------------------------------------------------------------
 
17
  MODEL_FILENAME = "model.onnx"
18
  LABEL_FILENAME = "selected_tags.csv"
19
 
20
+ HF_TOKEN = os.environ.get("HF_TOKEN", "")
21
+ if HF_TOKEN:
22
+ try:
23
+ login(token=HF_TOKEN)
24
+ except Exception as e:
25
+ print(f"Hugging Face登录失败: {e}")
26
+ else:
27
+ print("⚠️ 未检测到 HF_TOKEN,私有模型可能下载失败")
28
 
29
  # ------------------------------------------------------------------
30
  # Tagger 类 (全局实例化)
 
64
 
65
  # ------------------------- preprocess -------------------------
66
  def _preprocess(self, img: Image.Image) -> np.ndarray:
67
+ if img is None:
68
+ raise ValueError("输入图像不能为空")
69
+ if img.mode != "RGB":
70
+ img = img.convert("RGB")
71
  size = max(img.size)
72
  canvas = Image.new("RGB", (size, size), (255, 255, 255))
73
  canvas.paste(img, ((size - img.width) // 2, (size - img.height) // 2))
74
  if size != self.input_size:
75
  canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC)
76
+ return np.array(canvas)[:, :, ::-1].astype(np.float32) # to BGR
77
 
78
  # --------------------------- predict --------------------------
79
  def predict(self, img: Image.Image, gen_th: float = 0.35, char_th: float = 0.85):
80
+ if self.model is None:
81
+ raise RuntimeError("模型未成功加载,无法进行预测。")
82
  inp_name = self.model.get_inputs()[0].name
83
  outputs = self.model.run(None, {inp_name: self._preprocess(img)[None, ...]})[0][0]
84
 
85
  res = {"ratings": {}, "general": {}, "characters": {}}
86
  tag_categories_for_translation = {"ratings": [], "general": [], "characters": []}
87
 
88
+ for idx in self.categories["rating"]:
89
+ tag_name = self.tag_names[idx].replace("_", " ")
90
+ res["ratings"][tag_name] = float(outputs[idx])
91
+ tag_categories_for_translation["ratings"].append(tag_name)
92
+
93
+ for idx in self.categories["general"]:
94
+ if outputs[idx] > gen_th:
95
+ tag_name = self.tag_names[idx].replace("_", " ")
96
+ res["general"][tag_name] = float(outputs[idx])
97
+ tag_categories_for_translation["general"].append(tag_name)
98
+
99
+ for idx in self.categories["character"]:
100
+ if outputs[idx] > char_th:
101
+ tag_name = self.tag_names[idx].replace("_", " ")
102
+ res["characters"][tag_name] = float(outputs[idx])
103
+ tag_categories_for_translation["characters"].append(tag_name)
104
+
105
+
106
+ res["general"] = dict(sorted(res["general"].items(), key=lambda kv: kv[1], reverse=True))
107
+ res["characters"] = dict(sorted(res["characters"].items(), key=lambda kv: kv[1], reverse=True))
108
+ res["ratings"] = dict(sorted(res["ratings"].items(), key=lambda kv: kv[1], reverse=True))
109
+
110
+
111
+ tag_categories_for_translation["general"] = list(res["general"].keys())
112
+ tag_categories_for_translation["characters"] = list(res["characters"].keys())
113
+ tag_categories_for_translation["ratings"] = list(res["ratings"].keys())
114
 
115
  return res, tag_categories_for_translation
116
 
 
119
  tagger_instance = Tagger()
120
  except RuntimeError as e:
121
  print(f"应用启动时Tagger初始化失败: {e}")
122
+ tagger_instance = None # 允许应用启动,但在处理时会失败
123
 
124
  # ------------------------------------------------------------------
125
  # Gradio UI
 
142
  }
143
  navigator.clipboard.writeText(text).then(() => {
144
  const feedback = document.createElement('div');
145
+ let displayText = String(text);
146
+ displayText = displayText.substring(0, 30) + (displayText.length > 30 ? '...' : '');
147
  feedback.textContent = '已复制: ' + displayText;
148
  Object.assign(feedback.style, {
149
  position: 'fixed', bottom: '20px', left: '50%', transform: 'translateX(-50%)',
 
156
  setTimeout(() => { if (document.body.contains(feedback)) document.body.removeChild(feedback); }, 500);
157
  }, 1500);
158
  }).catch(err => {
159
+ console.error('Failed to copy tag. Error:', err, 'Text:', text);
160
  });
161
  }
162
  """
 
164
  with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=custom_css, js=_js_functions) as demo:
165
  gr.Markdown("# 🖼️ AI 图像标签分析器")
166
  gr.Markdown("上传图片自动识别标签,支持中英文显示和一键复制。[NovelAI在线绘画](https://nai.idlecloud.cc/)")
 
 
 
 
 
 
 
167
 
168
  state_res = gr.State({})
169
  state_translations_dict = gr.State({})
170
+ state_tag_categories_for_translation = gr.State({})
171
+
172
 
173
+ with gr.Row():
174
  with gr.Column(scale=1):
175
  img_in = gr.Image(type="pil", label="上传图片", height=300)
176
+
177
  btn = gr.Button("🚀 开始分析", variant="primary", elem_classes=["btn-analyze-container"])
178
 
179
  with gr.Accordion("⚙️ 高级设置", open=False):
180
+ gen_slider = gr.Slider(0, 1, value=0.35, step=0.01, label="通用标签阈值", info="越高 → 标签更少更准")
181
+ char_slider = gr.Slider(0, 1, value=0.85, step=0.01, label="角色标签阈值", info="推荐保持较高阈值")
182
  show_tag_scores = gr.Checkbox(True, label="在列表中显示标签置信度")
183
 
184
+ with gr.Accordion("🔑 翻译服务配置", open=False):
185
+ enable_translation_cb = gr.Checkbox(label="启用翻译", value=True, info="取消勾选则不进行翻译")
186
+ gr.Markdown("提供 **系统访问密钥** 或 **自定义API密钥** 来启用翻译功能。如果两者均未提供或不正确,将不进行翻译。")
187
+
188
+ with gr.Tabs():
189
+ with gr.TabItem("使用系统密钥"):
190
+ system_key_input = gr.Textbox(label="系统访问密钥", type="password", placeholder="输入管理员提供的密钥")
191
+ with gr.TabItem("使用自定义API"):
192
+ gr.Markdown("在此处填入你自己的翻译API密钥。")
193
+ tencent_id_input = gr.Textbox(label="腾讯云 SecretId", type="password")
194
+ tencent_key_input = gr.Textbox(label="腾讯云 SecretKey", type="password")
195
+ baidu_json_input = gr.Textbox(label="百度翻译凭证 (JSON格式)", type="password", placeholder='[{"app_id":"...", "secret_key":"..."}]')
196
 
197
  with gr.Accordion("📊 标签汇总设置", open=True):
198
+ gr.Markdown("选择要包含在下方汇总文本框中的标签类别:")
199
+ with gr.Row():
200
+ sum_general = gr.Checkbox(True, label="通用标签", min_width=50)
201
+ sum_char = gr.Checkbox(True, label="角色标签", min_width=50)
202
+ sum_rating = gr.Checkbox(False, label="评分标签", min_width=50)
203
+ sum_sep = gr.Dropdown(["逗号", "换行", "空格"], value="逗号", label="标签之间的分隔符")
204
  sum_show_zh = gr.Checkbox(False, label="在汇总中显示中文翻译")
205
 
206
  processing_info = gr.Markdown("", visible=False)
207
 
208
  with gr.Column(scale=2):
209
  with gr.Tabs():
210
+ with gr.TabItem("🏷️ 通用标签"):
211
+ out_general = gr.HTML(label="General Tags")
212
+ with gr.TabItem("👤 角色标签"):
213
+ gr.Markdown("<p style='color:gray; font-size:small;'>提示:角色标签推测基于截至2024年2月的数据。</p>")
214
+ out_char = gr.HTML(label="Character Tags")
215
+ with gr.TabItem("⭐ 评分标签"):
216
+ out_rating = gr.HTML(label="Rating Tags")
217
 
218
+ gr.Markdown("### 标签汇总结果")
219
+ out_summary = gr.Textbox(
220
+ label="标签汇总",
221
+ placeholder="分析完成后,此处将显示汇总的英文标签...",
222
+ lines=5,
223
+ show_copy_button=True
224
+ )
225
+
226
def format_tags_html(tags_dict, translations_list, show_scores=True, show_translation_in_list=True):
    """Render tags as an HTML list whose English labels copy on click.

    Args:
        tags_dict: Mapping of tag text to numeric score, in display order.
        translations_list: Translations aligned by position with
            ``tags_dict``; may be shorter, empty or ``None``.
        show_scores: When true, append each score with three decimals.
        show_translation_in_list: When true, show a translation next to a
            tag if one exists and differs from the tag itself.

    Returns:
        An HTML string, or a "no tags" paragraph when ``tags_dict`` is empty.
    """
    import json
    from html import escape

    if not tags_dict:
        return "<p>暂无标签</p>"

    translations_list = translations_list or []
    pieces = ['<div class="label-container">']
    for i, (tag, score) in enumerate(tags_dict.items()):
        # Tag text comes from the label file and translations from a remote
        # service, so neither may reach the markup raw. The onclick argument
        # is a JSON string literal (valid JS), then attribute-escaped.
        js_arg = escape(json.dumps(tag, ensure_ascii=False), quote=True)
        label = (
            f'<span class="tag-en" onclick="copyToClipboard({js_arg})">'
            f'{escape(tag)}</span>'
        )

        translation_text = translations_list[i] if i < len(translations_list) else None
        # Show the translation only if it exists and differs from the tag.
        if show_translation_in_list and translation_text and translation_text != tag:
            label += f'<span class="tag-zh">({escape(str(translation_text))})</span>'

        pieces.append(f'<div class="tag-item"><div>{label}</div>')
        if show_scores:
            pieces.append(f'<span class="tag-score">{score:.3f}</span>')
        pieces.append('</div>')
    pieces.append('</div>')
    return "".join(pieces)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
252
def generate_summary_text_content(
    current_res, current_translations_dict,
    s_gen, s_char, s_rat, s_sep_type, s_show_zh
):
    """Compose the summary text for the ticked tag categories.

    Categories are emitted in the fixed order general, characters, ratings.
    Tags within a category are joined with the chosen separator; categories
    are separated by a blank line, or by a single newline when the tag
    separator is itself a newline. With ``s_show_zh`` set, a tag is followed
    by ``(translation)`` when a translation exists and differs from the tag.
    A hint message is returned when there is nothing to show.
    """
    if not current_res:
        return "请先分析图像或选择要汇总的标签类别。"

    tag_sep = {"逗号": ", ", "换行": "\n", "空格": " "}.get(s_sep_type, ", ")

    switches = (("general", s_gen), ("characters", s_char), ("ratings", s_rat))
    chosen = [key for key, enabled in switches if enabled]
    if not chosen:
        return "请至少选择一个标签类别进行汇总。"

    def _render(position, tag, zh_list):
        # Append the translation only when requested, present and different.
        zh = zh_list[position] if position < len(zh_list) else None
        if s_show_zh and zh and zh != tag:
            return f"{tag}({zh})"
        return tag

    blocks = []
    for key in chosen:
        if not current_res.get(key):
            continue
        zh_list = current_translations_dict.get(key, [])
        rendered = [_render(pos, tag, zh_list) for pos, tag in enumerate(current_res[key])]
        if rendered:
            blocks.append(tag_sep.join(rendered))

    block_sep = "\n" if tag_sep == "\n" else "\n\n"
    return block_sep.join(blocks) or "选定的类别中没有找到标签。"
287
 
288
def process_image_and_generate_outputs(
    img, g_th, c_th, s_scores,  # Main inputs
    s_gen, s_char, s_rat, s_sep, s_zh_in_sum,  # Summary controls
    # Translation controls
    enable_translation, sys_key, tc_id, tc_key, baidu_json
):
    """Tag an image, optionally translate the tags, and stream UI updates.

    This is a generator. Every yielded tuple has exactly nine values, in the
    order of the click handler's outputs: analyse button, status markdown,
    general HTML, character HTML, rating HTML, summary textbox, result state,
    translation state, tag-category state.

    Failures are reported through the status markdown rather than raised, so
    the button is always re-enabled.
    """

    def _failure(status_text, pane_html="", summary_update=None):
        # One 9-value tuple for every failure path, so the slot order cannot
        # drift. Earlier early-exit yields had 10 values and put the status
        # message in the button's slot; the exception yield swapped them too.
        if summary_update is None:
            summary_update = gr.update(placeholder="分析失败...")
        return (
            gr.update(interactive=True, value="🚀 开始分析"),  # btn
            gr.update(visible=True, value=status_text),       # processing_info
            pane_html, pane_html, pane_html,                  # three HTML panes
            summary_update,                                   # out_summary
            {}, {}, {},                                       # states
        )

    if img is None:
        yield _failure("❌ 请先上传图片。")
        return

    if tagger_instance is None:
        yield _failure("❌ 分析器未成功初始化,请检查控制台错误。")
        return

    yield (
        gr.update(interactive=False, value="🔄 处理中..."),
        gr.update(visible=True, value="🔄 正在分析图像,请稍候..."),
        "<p>分析中...</p>", "<p>分析中...</p>", "<p>分析中...</p>",
        gr.update(value="分析中,请稍候..."), {}, {}, {}
    )

    category_order = ["general", "characters", "ratings"]
    try:
        res, tag_categories_original_order = tagger_instance.predict(img, g_th, c_th)

        current_translations_dict = {cat_key: [] for cat_key in category_order}
        if enable_translation:
            all_tags_to_translate = []
            for cat_key in category_order:
                all_tags_to_translate.extend(tag_categories_original_order.get(cat_key, []))

            all_translations_flat = []
            if all_tags_to_translate:
                # "or []" guards against the translator returning None.
                all_translations_flat = translate_texts(
                    texts=all_tags_to_translate,
                    system_key_input=sys_key,
                    tencent_id=tc_id,
                    tencent_key=tc_key,
                    baidu_creds_json_str=baidu_json
                ) or []

            # Slice the flat list back per category, in flattening order.
            offset = 0
            for cat_key in category_order:
                num_tags_in_cat = len(tag_categories_original_order.get(cat_key, []))
                current_translations_dict[cat_key] = all_translations_flat[offset : offset + num_tags_in_cat]
                offset += num_tags_in_cat

        general_html, char_html, rating_html = (
            format_tags_html(
                res.get(cat_key, {}),
                current_translations_dict.get(cat_key, []),
                s_scores,
                enable_translation,
            )
            for cat_key in category_order
        )

        summary_text = generate_summary_text_content(
            res, current_translations_dict, s_gen, s_char, s_rat, s_sep, s_zh_in_sum
        )

        yield (
            gr.update(interactive=True, value="🚀 开始分析"),
            gr.update(visible=True, value="✅ 分析完成!"),
            general_html, char_html, rating_html,
            gr.update(value=summary_text), res, current_translations_dict, tag_categories_original_order
        )

    except Exception as e:
        import traceback
        tb_str = traceback.format_exc()
        print(f"处理时发生错误: {e}\n{tb_str}")
        yield _failure(
            f"❌ 处理失败: {str(e)}",
            pane_html="<p>处理出错</p>",
            summary_update=gr.update(value=f"错误: {str(e)}", placeholder="分析失败..."),
        )
367
+
368
def update_summary_display(
    s_gen, s_char, s_rat, s_sep, s_zh_in_sum,
    current_res_from_state, current_translations_from_state
):
    """Rebuild the summary textbox after a summary control changes.

    With no stored analysis result, the textbox is cleared and a placeholder
    prompts for an analysis. Otherwise the summary is regenerated from the
    stored result and translations using the current control values.
    """
    if current_res_from_state:
        refreshed = generate_summary_text_content(
            current_res_from_state, current_translations_from_state,
            s_gen, s_char, s_rat, s_sep, s_zh_in_sum
        )
        return gr.update(value=refreshed)

    return gr.update(placeholder="请先完成一次图像分析以生成汇总。", value="")
380
+
381
+
382
+ translation_inputs = [enable_translation_cb, system_key_input, tencent_id_input, tencent_key_input, baidu_json_input]
383
 
384
  btn.click(
385
  process_image_and_generate_outputs,
386
+ inputs=[img_in, gen_slider, char_slider, show_tag_scores,
387
+ sum_general, sum_char, sum_rating, sum_sep, sum_show_zh] + translation_inputs,
388
+ outputs=[btn, processing_info,
389
+ out_general, out_char, out_rating, out_summary,
390
+ state_res, state_translations_dict, state_tag_categories_for_translation]
 
 
 
 
 
 
391
  )
392
 
393
+ summary_controls = [sum_general, sum_char, sum_rating, sum_sep, sum_show_zh]
394
  for ctrl in summary_controls:
395
  ctrl.change(
396
+ fn=update_summary_display,
397
+ inputs=summary_controls + [state_res, state_translations_dict],
398
+ outputs=[out_summary]
399
  )
400
+
401
  if __name__ == "__main__":
402
  if tagger_instance is None:
403
+ print("CRITICAL: Tagger 未能初始化,应用功能将受限。请检查之前的错误信息。")
404
  demo.launch(server_name="0.0.0.0", server_port=7860)