HtSimple committed on
Commit
394e93a
·
verified ·
1 Parent(s): 347e7c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -161
app.py CHANGED
@@ -8,13 +8,8 @@ import os
8
  import json
9
  from datetime import datetime
10
 
11
- # 配置设备 - 强制使用CPU以适配Hugging Face Spaces免费环境
12
- device = "cpu"
13
-
14
- # 动态获取工作目录
15
- script_dir = os.getcwd()
16
- root_dir = os.path.join(script_dir, 'GroceryStoreDataset')
17
- print(f"数据集根目录: {root_dir}")
18
 
19
  # 加载CLIP模型和处理器
20
  model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
@@ -25,15 +20,12 @@ index = Index(url="https://skilled-duckling-934-us1-vector.upstash.io",
25
  token="ABgFMHNraWxsZWQtZHVja2xpbmctOTM0LXVzMWFkbWluWkRWalpqUTFPV010T0daaU5DMDBORGMwTFdFMVkyUXRaV1JrTVRjNU1EWmpOekZo")
26
 
27
 
28
- # 加载数据集函数 - 改进路径处理和错误日志
29
  def load_dataset(file_path, root_dir):
30
  data = []
31
  print(f"加载数据集文件: {file_path}")
32
-
33
- # 检查文件是否存在
34
  if not os.path.exists(file_path):
35
  raise FileNotFoundError(f"数据集文件不存在: {file_path}")
36
-
37
  with open(file_path, 'r', encoding='utf-8') as f:
38
  lines = f.readlines()
39
  for i, line in enumerate(lines):
@@ -42,53 +34,30 @@ def load_dataset(file_path, root_dir):
42
  if len(parts) != 3:
43
  print(f"第 {i + 1} 行格式错误: {line}")
44
  continue
45
-
46
  image_path, fine_grained_label, coarse_grained_label = parts
47
-
48
- # 确保路径格式正确(使用正斜杠)
49
- image_path = image_path.replace('\\', '/')
50
-
51
- # 构建完整路径(移除多余的'dataset'前缀)
52
- if image_path.startswith('dataset/'):
53
- image_path = image_path[8:] # 移除'dataset/'前缀
54
-
55
  full_image_path = os.path.join(root_dir, 'dataset', image_path)
56
-
57
- # 检查文件是否存在并可读取
58
- if os.path.exists(full_image_path) and os.access(full_image_path, os.R_OK):
59
  data.append((full_image_path, int(fine_grained_label), int(coarse_grained_label)))
60
  else:
61
- print(f"警告: 文件不存在或不可读 - {full_image_path}")
62
-
63
  except Exception as e:
64
  print(f"解析第 {i + 1} 行时出错: {line}")
65
  print(f"错误详情: {e}")
66
-
67
  print(f"成功加载 {len(data)} 个样本")
68
  return data
69
 
70
 
71
- # 特征提取和向量插入函数
72
  def insert_images_to_index(data):
73
  print(f"开始向向量数据库插入 {len(data)} 个图像特征...")
74
  success_count = 0
75
  error_count = 0
76
-
77
  for image_path, fine_label, coarse_label in data:
78
  try:
79
- # 验证图像文件存在
80
- if not os.path.exists(image_path):
81
- print(f"错误: 图像文件不存在 - {image_path}")
82
- error_count += 1
83
- continue
84
-
85
  image = Image.open(image_path)
86
  features = extract_image_features(image)
87
-
88
- # 使用规范化的文件路径作为ID的一部分
89
- file_id = os.path.basename(image_path).replace('.', '_')
90
- vector_id = f"img_{file_id}_{fine_label}"
91
-
92
  vector = Vector(
93
  id=vector_id,
94
  vector=features,
@@ -98,14 +67,11 @@ def insert_images_to_index(data):
98
  "coarse_label": coarse_label
99
  }
100
  )
101
-
102
  index.upsert(vectors=[vector])
103
  success_count += 1
104
-
105
  except Exception as e:
106
  print(f"处理图像 {image_path} 时出错: {e}")
107
  error_count += 1
108
-
109
  print(f"向量插入完成: 成功 {success_count}, 失败 {error_count}")
110
 
111
 
@@ -113,226 +79,161 @@ def extract_image_features(image):
113
  try:
114
  if isinstance(image, np.ndarray):
115
  image = Image.fromarray(image)
116
-
117
  inputs = processor(images=image, return_tensors="pt").to(device)
118
-
119
  with torch.no_grad():
120
  image_features = model.get_image_features(**inputs)
121
-
122
  image_features = image_features / image_features.norm(dim=-1, keepdim=True)
123
  return image_features.cpu().numpy().flatten().tolist()
124
-
125
  except Exception as e:
126
  print(f"特征提取错误: {e}")
127
  return [0.0] * 512
128
 
129
 
130
- # 搜索函数 - 改进图像加载和错误处理
131
  def text_search(query_text, top_k=9, min_similarity=0.0):
132
  try:
133
  if not query_text.strip():
134
  return [(Image.new("RGB", (400, 200), "white"), "请输入搜索文字")]
135
-
136
  text_inputs = processor(text=query_text, return_tensors="pt", padding=True).to(device)
137
-
138
  with torch.no_grad():
139
  text_features = model.get_text_features(**text_inputs)
140
  text_features = text_features / text_features.norm(dim=-1, keepdim=True)
141
-
142
  results = index.query(
143
  vector=text_features.cpu().numpy().flatten().tolist(),
144
  top_k=top_k,
145
  include_vectors=True,
146
  include_metadata=True
147
  )
148
-
149
  filtered_results = [item for item in results if item.score >= min_similarity]
150
-
151
  if not filtered_results:
152
  return [(Image.new("RGB", (400, 200), "white"), "无匹配结果")]
153
-
154
  gallery_items = []
155
-
156
  for item in filtered_results[:top_k]:
157
  metadata = item.metadata
158
  image_path = metadata["image_path"]
159
-
160
- # 打印路径用于调试
161
- print(f"搜索结果图像路径: {image_path}")
162
-
163
  try:
164
- # 验证路径有效性
165
- if not image_path or not os.path.exists(image_path):
166
- raise FileNotFoundError(f"路径不存在: {image_path}")
167
-
168
  img = Image.open(image_path).convert("RGB")
169
-
170
- except FileNotFoundError as e:
171
- print(f"错误: 找不到图像 - {image_path}")
172
  img = Image.new("RGB", (200, 200), "white")
173
-
174
- except Exception as e:
175
- print(f"加载图像失败: {image_path}, 错误: {e}")
176
- img = Image.new("RGB", (200, 200), "white")
177
-
178
  caption = f"相似度: {item.score:.4f}"
179
  gallery_items.append((img, caption))
180
-
181
  return gallery_items
182
-
183
  except Exception as e:
184
  print(f"文字搜索错误: {e}")
185
  return [(Image.new("RGB", (400, 200), "white"), f"错误: {str(e)}")]
186
 
187
 
188
- # 图像搜索函数 - 改进错误处理
 
189
  def image_search(query_image, top_k=9, min_similarity=0.0):
190
  try:
191
  if query_image is None:
192
  return [(Image.new("RGB", (400, 200), "white"), "请上传搜索图像")]
193
-
194
  # 提取图像特征
195
  image_features = extract_image_features(query_image)
196
-
197
- # 确保特征向格式正确
198
- if not isinstance(image_features, list):
199
- image_features = image_features.tolist()
200
-
 
 
201
  # 使用正确的特征向量进行查询
202
  results = index.query(
203
- vector=image_features,
204
  top_k=top_k,
205
  include_vectors=True,
206
  include_metadata=True
207
  )
208
-
209
  filtered_results = []
210
-
211
  for item in results:
212
  metadata = item.metadata
213
  image_path = metadata["image_path"]
214
-
215
  # 相似度过滤
216
  if item.score < min_similarity:
217
  continue
218
-
219
  filtered_results.append(item)
220
-
221
  # 处理空结果
222
  if not filtered_results:
223
  return [(Image.new("RGB", (400, 200), "white"), "无匹配结果")]
224
-
225
  # 构建Gallery所需的元组列表
226
  gallery_items = []
227
-
228
  for item in filtered_results[:top_k]:
229
  metadata = item.metadata
230
  image_path = metadata["image_path"]
231
-
232
- # 打印路径用于调试
233
- print(f"图像搜索结果路径: {image_path}")
234
-
235
- try:
236
- # 验证路径有效性
237
- if not image_path or not os.path.exists(image_path):
238
- raise FileNotFoundError(f"路径不存在: {image_path}")
239
-
240
- img = Image.open(image_path).convert("RGB")
241
-
242
- except FileNotFoundError as e:
243
- print(f"错误: 找不到图像 - {image_path}")
244
- img = Image.new("RGB", (200, 200), "white")
245
-
246
- except Exception as e:
247
- print(f"加载图像失败: {image_path}, 错误: {e}")
248
- img = Image.new("RGB", (200, 200), "white")
249
-
250
- # 组合分数和标签作为标题
251
- caption = f"相似度: {item.score:.4f}"
252
- gallery_items.append((img, caption))
253
-
254
  return gallery_items
255
-
256
  except Exception as e:
257
  print(f"图像搜索错误: {e}")
258
  return [(Image.new("RGB", (400, 200), "red"), f"错误: {str(e)}")]
259
 
260
 
261
- # 初始化向量数据库 - 改进路径验证
262
  def initialize_vector_db():
 
 
263
  flag_file = os.path.join(root_dir, 'dataset', '.vectors_inserted')
264
-
265
- # 检查标志文件
266
  if os.path.exists(flag_file):
267
  print("发现标志文件,跳过向量数据库检查")
268
  return
269
-
270
  try:
271
- # 测试向量数据库连接
272
  results = index.query(vector=[0.0] * 512, top_k=1, include_metadata=False)
273
-
274
- if results and len(results) > 0:
275
  print("向量数据库已有数据,跳过插入")
276
  os.makedirs(os.path.dirname(flag_file), exist_ok=True)
277
-
278
  with open(flag_file, 'w') as f:
279
  f.write("Vectors already exist")
280
-
281
  return
282
-
283
- # 验证数据集文件
284
  train_file = os.path.join(root_dir, 'dataset', 'train.txt')
285
  val_file = os.path.join(root_dir, 'dataset', 'val.txt')
286
  test_file = os.path.join(root_dir, 'dataset', 'test.txt')
287
-
288
  for file_path in [train_file, val_file, test_file]:
289
  if not os.path.exists(file_path):
290
  print(f"警告: 数据集文件不存在 - {file_path}")
291
  return
292
-
293
- # 加载数据集
294
  train_data = load_dataset(train_file, root_dir)
295
  val_data = load_dataset(val_file, root_dir)
296
  test_data = load_dataset(test_file, root_dir)
297
-
298
- # 插入向量
299
  insert_images_to_index(train_data + val_data + test_data)
300
-
301
- # 创建标志文件
302
  os.makedirs(os.path.dirname(flag_file), exist_ok=True)
303
-
304
  with open(flag_file, 'w') as f:
305
  f.write("Vectors inserted successfully")
306
-
307
  except Exception as e:
308
  print(f"查询向量数据库失败: {e}")
309
-
310
  if os.path.exists(flag_file):
311
  print("但发现标志文件,推测数据已插入,跳过插入")
312
  return
313
-
314
  print("没有标志文件,尝试加载数据并插入(有重复风险)")
315
-
316
- # 尝试恢复数据加载
317
- if 'train_data' not in locals():
318
  train_file = os.path.join(root_dir, 'dataset', 'train.txt')
319
  val_file = os.path.join(root_dir, 'dataset', 'val.txt')
320
  test_file = os.path.join(root_dir, 'dataset', 'test.txt')
321
-
322
  for file_path in [train_file, val_file, test_file]:
323
  if not os.path.exists(file_path):
324
  print(f"警告: 数据集文件不存在 - {file_path}")
325
  return
326
-
327
  train_data = load_dataset(train_file, root_dir)
328
  val_data = load_dataset(val_file, root_dir)
329
  test_data = load_dataset(test_file, root_dir)
330
-
331
  insert_images_to_index(train_data + val_data + test_data)
332
-
333
- # 创建标志文件
334
  os.makedirs(os.path.dirname(flag_file), exist_ok=True)
335
-
336
  with open(flag_file, 'w') as f:
337
  f.write("Vectors inserted with error handling")
338
 
@@ -340,11 +241,11 @@ def initialize_vector_db():
340
  # 主应用界面
341
  def create_app():
342
  initialize_vector_db()
343
-
344
  with gr.Blocks(title="CLIP图像搜索系统", theme=gr.themes.Soft()) as app:
345
  gr.Markdown("# CLIP图像搜索系统")
346
  gr.Markdown("使用文字或图像搜索相似的商品图片")
347
-
348
  with gr.Tabs():
349
  # 文字搜索标签页
350
  with gr.Tab("文字搜索"):
@@ -355,14 +256,13 @@ def create_app():
355
  placeholder="点击下方标签自动填充",
356
  interactive=True
357
  )
358
-
359
- # 可选标签
360
  gr.Markdown("### 可选标签")
361
  with gr.Row():
362
  # 示例标签,可根据实际数据扩展
363
  labels = ["apple", "banana", "orange", "vegetables", "fruit"]
364
  label_btns = []
365
-
366
  for label in labels:
367
  btn = gr.Button(
368
  label,
@@ -370,55 +270,52 @@ def create_app():
370
  elem_classes="tag-btn"
371
  )
372
  label_btns.append(btn)
373
-
374
  # 点击标签时触发的函数
375
  btn.click(
376
- fn=lambda txt, lbl: lbl if txt != lbl else "",
377
  inputs=[text_query, gr.Textbox(value=label, visible=False)],
378
  outputs=text_query
379
  )
380
-
 
381
  # 控制区
382
  with gr.Group():
383
  gr.Markdown("### 搜索参数")
384
  text_top_k = gr.Slider(minimum=1, maximum=21, step=1, value=9, label="最多显示图片数")
385
  text_min_sim = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, value=0.0,
386
  label="最低相似度阈值")
387
-
388
  text_search_btn = gr.Button("搜索", variant="primary")
389
-
390
  text_output_images = gr.Gallery(label="搜索结果", show_label=True, columns=3, rows=7)
391
-
392
- # 图像搜索标签页
393
  with gr.Tab("图像搜索"):
394
  with gr.Row():
395
  with gr.Column(scale=2):
396
  image_query = gr.Image(label="上传搜索图像", type="pil")
397
-
398
  with gr.Group():
399
  gr.Markdown("### 搜索参数")
400
  image_top_k = gr.Slider(minimum=1, maximum=21, step=1, value=9, label="最多显示图片数")
401
  image_min_sim = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, value=0.0,
402
  label="最低相似度阈值")
403
-
404
  image_search_btn = gr.Button("搜索", variant="primary")
405
-
406
  image_output_images = gr.Gallery(label="搜索结果", show_label=True, columns=3, rows=7)
407
-
408
  # 文字搜索按钮事件绑定
409
  text_search_btn.click(
410
  fn=text_search,
411
  inputs=[text_query, text_top_k, text_min_sim],
412
  outputs=text_output_images
413
  )
414
-
415
  # 图像搜索按钮事件绑定
416
  image_search_btn.click(
417
  fn=image_search,
418
  inputs=[image_query, image_top_k, image_min_sim],
419
  outputs=image_output_images
420
  )
421
-
422
  # 全局样式:标签按钮样式
423
  gr.Markdown("""
424
  <style>
@@ -439,7 +336,7 @@ def create_app():
439
  }
440
  </style>
441
  """)
442
-
443
  return app
444
 
445
 
 
8
  import json
9
  from datetime import datetime
10
 
11
+ # 配置设备
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
13
 
14
  # 加载CLIP模型和处理器
15
  model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
 
20
  token="ABgFMHNraWxsZWQtZHVja2xpbmctOTM0LXVzMWFkbWluWkRWalpqUTFPV010T0daaU5DMDBORGMwTFdFMVkyUXRaV1JrTVRjNU1EWmpOekZo")
21
 
22
 
23
+ # 加载数据集函数(保持不变)
24
  def load_dataset(file_path, root_dir):
25
  data = []
26
  print(f"加载数据集文件: {file_path}")
 
 
27
  if not os.path.exists(file_path):
28
  raise FileNotFoundError(f"数据集文件不存在: {file_path}")
 
29
  with open(file_path, 'r', encoding='utf-8') as f:
30
  lines = f.readlines()
31
  for i, line in enumerate(lines):
 
34
  if len(parts) != 3:
35
  print(f"第 {i + 1} 行格式错误: {line}")
36
  continue
 
37
  image_path, fine_grained_label, coarse_grained_label = parts
38
+ image_path = image_path.replace('/', os.sep)
 
 
 
 
 
 
 
39
  full_image_path = os.path.join(root_dir, 'dataset', image_path)
40
+ if os.path.exists(full_image_path):
 
 
41
  data.append((full_image_path, int(fine_grained_label), int(coarse_grained_label)))
42
  else:
43
+ print(f"警告: 文件不存在 - {full_image_path}")
 
44
  except Exception as e:
45
  print(f"解析第 {i + 1} 行时出错: {line}")
46
  print(f"错误详情: {e}")
 
47
  print(f"成功加载 {len(data)} 个样本")
48
  return data
49
 
50
 
51
+ # 特征提取和向量插入函数(保持不变)
52
  def insert_images_to_index(data):
53
  print(f"开始向向量数据库插入 {len(data)} 个图像特征...")
54
  success_count = 0
55
  error_count = 0
 
56
  for image_path, fine_label, coarse_label in data:
57
  try:
 
 
 
 
 
 
58
  image = Image.open(image_path)
59
  features = extract_image_features(image)
60
+ vector_id = f"img_{os.path.basename(image_path)}_{fine_label}"
 
 
 
 
61
  vector = Vector(
62
  id=vector_id,
63
  vector=features,
 
67
  "coarse_label": coarse_label
68
  }
69
  )
 
70
  index.upsert(vectors=[vector])
71
  success_count += 1
 
72
  except Exception as e:
73
  print(f"处理图像 {image_path} 时出错: {e}")
74
  error_count += 1
 
75
  print(f"向量插入完成: 成功 {success_count}, 失败 {error_count}")
76
 
77
 
 
79
  try:
80
  if isinstance(image, np.ndarray):
81
  image = Image.fromarray(image)
 
82
  inputs = processor(images=image, return_tensors="pt").to(device)
 
83
  with torch.no_grad():
84
  image_features = model.get_image_features(**inputs)
 
85
  image_features = image_features / image_features.norm(dim=-1, keepdim=True)
86
  return image_features.cpu().numpy().flatten().tolist()
 
87
  except Exception as e:
88
  print(f"特征提取错误: {e}")
89
  return [0.0] * 512
90
 
91
 
92
+ # 搜索函数(保持不变)
93
  def text_search(query_text, top_k=9, min_similarity=0.0):
94
  try:
95
  if not query_text.strip():
96
  return [(Image.new("RGB", (400, 200), "white"), "请输入搜索文字")]
 
97
  text_inputs = processor(text=query_text, return_tensors="pt", padding=True).to(device)
 
98
  with torch.no_grad():
99
  text_features = model.get_text_features(**text_inputs)
100
  text_features = text_features / text_features.norm(dim=-1, keepdim=True)
 
101
  results = index.query(
102
  vector=text_features.cpu().numpy().flatten().tolist(),
103
  top_k=top_k,
104
  include_vectors=True,
105
  include_metadata=True
106
  )
 
107
  filtered_results = [item for item in results if item.score >= min_similarity]
 
108
  if not filtered_results:
109
  return [(Image.new("RGB", (400, 200), "white"), "无匹配结果")]
 
110
  gallery_items = []
 
111
  for item in filtered_results[:top_k]:
112
  metadata = item.metadata
113
  image_path = metadata["image_path"]
 
 
 
 
114
  try:
 
 
 
 
115
  img = Image.open(image_path).convert("RGB")
116
+ except:
 
 
117
  img = Image.new("RGB", (200, 200), "white")
 
 
 
 
 
118
  caption = f"相似度: {item.score:.4f}"
119
  gallery_items.append((img, caption))
 
120
  return gallery_items
 
121
  except Exception as e:
122
  print(f"文字搜索错误: {e}")
123
  return [(Image.new("RGB", (400, 200), "white"), f"错误: {str(e)}")]
124
 
125
 
126
+
127
+ # 图像搜索函数
128
  def image_search(query_image, top_k=9, min_similarity=0.0):
129
  try:
130
  if query_image is None:
131
  return [(Image.new("RGB", (400, 200), "white"), "请上传搜索图像")]
132
+
133
  # 提取图像特征
134
  image_features = extract_image_features(query_image)
135
+
136
+ # 将列表转换为 PyTorch 张
137
+ image_features = torch.tensor(image_features)
138
+
139
+ # 归一化处理
140
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
141
+
142
  # 使用正确的特征向量进行查询
143
  results = index.query(
144
+ vector=image_features.cpu().numpy().flatten().tolist(),
145
  top_k=top_k,
146
  include_vectors=True,
147
  include_metadata=True
148
  )
149
+
150
  filtered_results = []
 
151
  for item in results:
152
  metadata = item.metadata
153
  image_path = metadata["image_path"]
154
+
155
  # 相似度过滤
156
  if item.score < min_similarity:
157
  continue
158
+
159
  filtered_results.append(item)
160
+
161
  # 处理空结果
162
  if not filtered_results:
163
  return [(Image.new("RGB", (400, 200), "white"), "无匹配结果")]
164
+
165
  # 构建Gallery所需的元组列表
166
  gallery_items = []
 
167
  for item in filtered_results[:top_k]:
168
  metadata = item.metadata
169
  image_path = metadata["image_path"]
170
+ if image_path:
171
+ try:
172
+ img = Image.open(image_path).convert("RGB")
173
+ except Exception as e:
174
+ print(f"加载图片失败: {image_path}, 错误: {e}")
175
+ img = Image.new("RGB", (200, 200), "white")
176
+
177
+ # 组合分数和标签作为标题
178
+ caption = f"相似度: {item.score:.4f}"
179
+ gallery_items.append((img, caption))
180
+
 
 
 
 
 
 
 
 
 
 
 
 
181
  return gallery_items
182
+
183
  except Exception as e:
184
  print(f"图像搜索错误: {e}")
185
  return [(Image.new("RGB", (400, 200), "red"), f"错误: {str(e)}")]
186
 
187
 
188
+ # 初始化向量数据库(保持不变)
189
  def initialize_vector_db():
190
+ script_dir = os.path.dirname(os.path.abspath(__file__))
191
+ root_dir = os.path.join(script_dir, 'GroceryStoreDataset')
192
  flag_file = os.path.join(root_dir, 'dataset', '.vectors_inserted')
 
 
193
  if os.path.exists(flag_file):
194
  print("发现标志文件,跳过向量数据库检查")
195
  return
 
196
  try:
 
197
  results = index.query(vector=[0.0] * 512, top_k=1, include_metadata=False)
198
+ if results and len(results.get("results", [])) > 0:
 
199
  print("向量数据库已有数据,跳过插入")
200
  os.makedirs(os.path.dirname(flag_file), exist_ok=True)
 
201
  with open(flag_file, 'w') as f:
202
  f.write("Vectors already exist")
 
203
  return
 
 
204
  train_file = os.path.join(root_dir, 'dataset', 'train.txt')
205
  val_file = os.path.join(root_dir, 'dataset', 'val.txt')
206
  test_file = os.path.join(root_dir, 'dataset', 'test.txt')
 
207
  for file_path in [train_file, val_file, test_file]:
208
  if not os.path.exists(file_path):
209
  print(f"警告: 数据集文件不存在 - {file_path}")
210
  return
 
 
211
  train_data = load_dataset(train_file, root_dir)
212
  val_data = load_dataset(val_file, root_dir)
213
  test_data = load_dataset(test_file, root_dir)
 
 
214
  insert_images_to_index(train_data + val_data + test_data)
 
 
215
  os.makedirs(os.path.dirname(flag_file), exist_ok=True)
 
216
  with open(flag_file, 'w') as f:
217
  f.write("Vectors inserted successfully")
 
218
  except Exception as e:
219
  print(f"查询向量数据库失败: {e}")
 
220
  if os.path.exists(flag_file):
221
  print("但发现标志文件,推测数据已插入,跳过插入")
222
  return
 
223
  print("没有标志文件,尝试加载数据并插入(有重复风险)")
224
+ if train_data is None:
 
 
225
  train_file = os.path.join(root_dir, 'dataset', 'train.txt')
226
  val_file = os.path.join(root_dir, 'dataset', 'val.txt')
227
  test_file = os.path.join(root_dir, 'dataset', 'test.txt')
 
228
  for file_path in [train_file, val_file, test_file]:
229
  if not os.path.exists(file_path):
230
  print(f"警告: 数据集文件不存在 - {file_path}")
231
  return
 
232
  train_data = load_dataset(train_file, root_dir)
233
  val_data = load_dataset(val_file, root_dir)
234
  test_data = load_dataset(test_file, root_dir)
 
235
  insert_images_to_index(train_data + val_data + test_data)
 
 
236
  os.makedirs(os.path.dirname(flag_file), exist_ok=True)
 
237
  with open(flag_file, 'w') as f:
238
  f.write("Vectors inserted with error handling")
239
 
 
241
  # 主应用界面
242
  def create_app():
243
  initialize_vector_db()
244
+
245
  with gr.Blocks(title="CLIP图像搜索系统", theme=gr.themes.Soft()) as app:
246
  gr.Markdown("# CLIP图像搜索系统")
247
  gr.Markdown("使用文字或图像搜索相似的商品图片")
248
+
249
  with gr.Tabs():
250
  # 文字搜索标签页
251
  with gr.Tab("文字搜索"):
 
256
  placeholder="点击下方标签自动填充",
257
  interactive=True
258
  )
259
+
260
+ # 可选标签(使用HTML按钮实现可取消选择)
261
  gr.Markdown("### 可选标签")
262
  with gr.Row():
263
  # 示例标签,可根据实际数据扩展
264
  labels = ["apple", "banana", "orange", "vegetables", "fruit"]
265
  label_btns = []
 
266
  for label in labels:
267
  btn = gr.Button(
268
  label,
 
270
  elem_classes="tag-btn"
271
  )
272
  label_btns.append(btn)
 
273
  # 点击标签时触发的函数
274
  btn.click(
275
+ fn=lambda txt, lbl: lbl if txt != lbl else "", # 点击已选标签则清空
276
  inputs=[text_query, gr.Textbox(value=label, visible=False)],
277
  outputs=text_query
278
  )
279
+
280
+
281
  # 控制区
282
  with gr.Group():
283
  gr.Markdown("### 搜索参数")
284
  text_top_k = gr.Slider(minimum=1, maximum=21, step=1, value=9, label="最多显示图片数")
285
  text_min_sim = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, value=0.0,
286
  label="最低相似度阈值")
287
+
288
  text_search_btn = gr.Button("搜索", variant="primary")
289
+
290
  text_output_images = gr.Gallery(label="搜索结果", show_label=True, columns=3, rows=7)
291
+
292
+ # 图像搜索标签页(保持不变)
293
  with gr.Tab("图像搜索"):
294
  with gr.Row():
295
  with gr.Column(scale=2):
296
  image_query = gr.Image(label="上传搜索图像", type="pil")
 
297
  with gr.Group():
298
  gr.Markdown("### 搜索参数")
299
  image_top_k = gr.Slider(minimum=1, maximum=21, step=1, value=9, label="最多显示图片数")
300
  image_min_sim = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, value=0.0,
301
  label="最低相似度阈值")
 
302
  image_search_btn = gr.Button("搜索", variant="primary")
 
303
  image_output_images = gr.Gallery(label="搜索结果", show_label=True, columns=3, rows=7)
304
+
305
  # 文字搜索按钮事件绑定
306
  text_search_btn.click(
307
  fn=text_search,
308
  inputs=[text_query, text_top_k, text_min_sim],
309
  outputs=text_output_images
310
  )
311
+
312
  # 图像搜索按钮事件绑定
313
  image_search_btn.click(
314
  fn=image_search,
315
  inputs=[image_query, image_top_k, image_min_sim],
316
  outputs=image_output_images
317
  )
318
+
319
  # 全局样式:标签按钮样式
320
  gr.Markdown("""
321
  <style>
 
336
  }
337
  </style>
338
  """)
339
+
340
  return app
341
 
342