chanhua commited on
Commit
17871e5
1 Parent(s): 2da22fc

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +20 -9
  2. image_feature.py +29 -3
app.py CHANGED
@@ -3,7 +3,8 @@ import image_feature as func
3
 
4
 
5
  def work11(image1, image2):
6
- return func.infer1(image1, image2)
 
7
 
8
 
9
  # with gr.Blocks() as demo:
@@ -15,11 +16,21 @@ def work11(image1, image2):
15
  # demo.launch()
16
 
17
  # 定义你的界面
18
- with gr.Interface(fn=work11,
19
- inputs=[gr.Textbox(label='图片1', lines=1), gr.Textbox(label='图片2', lines=1)], # 两个文本输入框
20
- outputs=[gr.Textbox(lines=3, label="推理结果")], # 输出为文本
21
- title="图片相似度推理", # 界面标题
22
- description="输入两张图片链接进行相似度推理", # 界面描述
23
- examples=[["https://example.com", "https://google.com"], # 示例输入
24
- ["https://github.com", "https://twitter.com"]]) as demo: # 更多示例输入
25
- demo.launch() # 启动界面
 
 
 
 
 
 
 
 
 
 
 
3
 
4
 
5
  def work11(image1, image2):
6
+ # return func.infer1(image1, image2)
7
+ return func.infer3(image1, image2)
8
 
9
 
10
  # with gr.Blocks() as demo:
 
16
  # demo.launch()
17
 
18
  # 定义你的界面
19
+ # with gr.Interface(fn=work11,
20
+ # inputs=[gr.Textbox(label='图片1', lines=1), gr.Textbox(label='图片2', lines=1)], # 两个文本输入框
21
+ # outputs=[gr.Textbox(lines=3, label="推理结果")], # 输出为文本
22
+ # title="图片相似度推理", # 界面标题
23
+ # description="输入两张图片链接进行相似度推理", # 界面描述
24
+ # examples=[["https://example.com", "https://google.com"], # 示例输入
25
+ # ["https://github.com", "https://twitter.com"]]) as demo: # 更多示例输入
26
+ # demo.launch() # 启动界面
27
+
28
+ demo = gr.Interface(title="图片相似度推理",
29
+ css="",
30
+ fn=work11,
31
+ inputs=[gr.Image(type="filepath", label="图片1"), gr.Image(type="filepath", label="图片2")],
32
+ outputs=[gr.Textbox(lines=3, label="推理结果")])
33
+ #
34
+ # # demo = gr.Interface(fn=work, inputs="image,text", outputs="text")
35
+ #
36
+ demo.launch()
image_feature.py CHANGED
@@ -5,7 +5,7 @@ from PIL import Image
5
  from torch.nn.functional import cosine_similarity
6
  from transformers import AutoImageProcessor, AutoModel
7
 
8
- # from transformers import pipeline
9
 
10
  # import transformers
11
  #
@@ -53,6 +53,34 @@ processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
53
  model = AutoModel.from_pretrained("google/vit-base-patch16-224").to(DEVICE)
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  # 推理
57
  def infer2(url):
58
  # image_real = Image.open(requests.get(img_urls[0], stream=True).raw).convert("RGB")
@@ -80,5 +108,3 @@ def infer1(image1, image2):
80
  finally:
81
  # 无论是否发生异常,都会执行此代码块
82
  print("这是finally块")
83
-
84
- # tensor([0.6061], device='cuda:0', grad_fn=<SumBackward1>)
 
5
  from torch.nn.functional import cosine_similarity
6
  from transformers import AutoImageProcessor, AutoModel
7
 
8
+ from transformers import pipeline
9
 
10
  # import transformers
11
  #
 
53
  model = AutoModel.from_pretrained("google/vit-base-patch16-224").to(DEVICE)
54
 
55
 
56
+ # tensor([0.6061], device='cuda:0', grad_fn=<SumBackward1>)
57
+
58
+
59
+ # 推理
60
+ def infer3(url1, url2):
61
+ try:
62
+ image_real = Image.open(requests.get(url1, stream=True).raw).convert("RGB")
63
+ image_gen = Image.open(requests.get(url2, stream=True).raw).convert("RGB")
64
+
65
+ pipe = pipeline(task="image-feature-extraction", model_name="google/vit-base-patch16-384", device=DEVICE,
66
+ pool=True)
67
+
68
+ outputs = pipe([image_real, image_gen])
69
+
70
+ similarity_score = cosine_similarity(torch.Tensor(outputs[0]), torch.Tensor(outputs[1]), dim=1)
71
+
72
+ t_cpu = similarity_score.cpu()
73
+
74
+ # 然后提取这个值
75
+ return t_cpu.item()
76
+
77
+ except Exception as e:
78
+ print(f"发生了一个错误: {e}")
79
+ finally:
80
+ # 无论是否发生异常,都会执行此代码块
81
+ print("这是finally块")
82
+
83
+
84
  # 推理
85
  def infer2(url):
86
  # image_real = Image.open(requests.get(img_urls[0], stream=True).raw).convert("RGB")
 
108
  finally:
109
  # 无论是否发生异常,都会执行此代码块
110
  print("这是finally块")