Update app.py
Browse files
app.py
CHANGED
@@ -9,7 +9,7 @@ from qdrant_client import QdrantClient
|
|
9 |
from qdrant_client.http.models import Filter, FieldCondition, MatchValue
|
10 |
import os
|
11 |
|
12 |
-
# โหลด EfficientNet-
|
13 |
model = EfficientNet.from_pretrained('efficientnet-b7')
|
14 |
model.eval()
|
15 |
|
@@ -54,32 +54,43 @@ def get_image_embedding(image_path_or_url):
|
|
54 |
embedding = model(image_tensor)
|
55 |
return embedding
|
56 |
|
57 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
def search_similar_images(image_upload, image_url):
|
59 |
# กำหนดแหล่งภาพ
|
60 |
if image_upload is not None:
|
61 |
-
|
62 |
elif image_url:
|
63 |
-
|
64 |
-
image = Image.open(BytesIO(response.content)).convert("RGB")
|
65 |
else:
|
66 |
return [], "❌ Please upload an image or provide an image URL."
|
67 |
|
68 |
-
|
|
|
|
|
69 |
|
70 |
# Query จาก Qdrant
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
76 |
|
77 |
if not result:
|
78 |
return [], "❌ No matching images found."
|
79 |
|
80 |
gallery_data = []
|
81 |
for r in result:
|
82 |
-
image_url = r.payload
|
83 |
name = r.payload.get('name', 'Unknown')
|
84 |
score = round(r.score, 3)
|
85 |
caption = f"{name}\n(score: {score})"
|
@@ -87,8 +98,9 @@ def search_similar_images(image_upload, image_url):
|
|
87 |
|
88 |
return gallery_data, f"✅ Found {len(result)} similar images."
|
89 |
|
|
|
90 |
# Gradio Interface
|
91 |
-
with gr.Blocks(title="Image Similarity Search with EfficientNet-
|
92 |
gr.Markdown("# 🔍 EfficientNet-B7 Image Search")
|
93 |
gr.Markdown("Upload an image or paste a URL to find similar images using Qdrant.")
|
94 |
|
@@ -108,4 +120,4 @@ with gr.Blocks(title="Image Similarity Search with EfficientNet-B2") as demo:
|
|
108 |
)
|
109 |
|
110 |
if __name__ == "__main__":
|
111 |
-
demo.launch()
|
|
|
9 |
from qdrant_client.http.models import Filter, FieldCondition, MatchValue
|
10 |
import os
|
11 |
|
12 |
+
# โหลด EfficientNet-B7
|
13 |
model = EfficientNet.from_pretrained('efficientnet-b7')
|
14 |
model.eval()
|
15 |
|
|
|
54 |
embedding = model(image_tensor)
|
55 |
return embedding
|
56 |
|
57 |
+
# NEW: ดึง embedding จาก PIL.Image โดยตรง
|
58 |
+
def get_image_embedding_from_pil(image: Image.Image):
|
59 |
+
image_tensor = preprocess(image).unsqueeze(0)
|
60 |
+
with torch.no_grad():
|
61 |
+
embedding = model(image_tensor)
|
62 |
+
return embedding
|
63 |
+
|
64 |
+
# ปรับฟังก์ชันหลักให้เลือกเส้นทางที่ถูกต้อง
|
65 |
def search_similar_images(image_upload, image_url):
|
66 |
# กำหนดแหล่งภาพ
|
67 |
if image_upload is not None:
|
68 |
+
query_vector = get_image_embedding_from_pil(image_upload)
|
69 |
elif image_url:
|
70 |
+
query_vector = get_image_embedding(image_url)
|
|
|
71 |
else:
|
72 |
return [], "❌ Please upload an image or provide an image URL."
|
73 |
|
74 |
+
# ป้องกันกรณีที่ embedding ผิดพลาด
|
75 |
+
if query_vector is None:
|
76 |
+
return [], "❌ Could not process the image."
|
77 |
|
78 |
# Query จาก Qdrant
|
79 |
+
try:
|
80 |
+
result = qdrant_client.search(
|
81 |
+
collection_name="image_collection_b7",
|
82 |
+
query_vector=query_vector[0].tolist(),
|
83 |
+
with_payload=True,
|
84 |
+
)
|
85 |
+
except Exception as e:
|
86 |
+
return [], f"❌ Qdrant search failed: {str(e)}"
|
87 |
|
88 |
if not result:
|
89 |
return [], "❌ No matching images found."
|
90 |
|
91 |
gallery_data = []
|
92 |
for r in result:
|
93 |
+
image_url = r.payload.get('image_url', '')
|
94 |
name = r.payload.get('name', 'Unknown')
|
95 |
score = round(r.score, 3)
|
96 |
caption = f"{name}\n(score: {score})"
|
|
|
98 |
|
99 |
return gallery_data, f"✅ Found {len(result)} similar images."
|
100 |
|
101 |
+
|
102 |
# Gradio Interface
|
103 |
+
with gr.Blocks(title="Image Similarity Search with EfficientNet-B7") as demo:
|
104 |
gr.Markdown("# 🔍 EfficientNet-B7 Image Search")
|
105 |
gr.Markdown("Upload an image or paste a URL to find similar images using Qdrant.")
|
106 |
|
|
|
120 |
)
|
121 |
|
122 |
if __name__ == "__main__":
|
123 |
+
demo.launch()
|