import os
from io import BytesIO

import gradio as gr
import requests
import torch
from efficientnet_pytorch import EfficientNet
from PIL import Image
from qdrant_client import QdrantClient
from qdrant_client.http.models import Filter, FieldCondition, MatchValue
from torchvision import transforms

# Load the pretrained EfficientNet-B7 model once at import time and switch it
# to inference mode.
model = EfficientNet.from_pretrained('efficientnet-b7')
model.eval()

# Qdrant connection settings come from the environment; a missing variable
# raises KeyError at startup rather than on the first search.
url = os.environ['QDRANT_URL']
api_key = os.environ['QDRANT_API_KEY']
qdrant_client = QdrantClient(url=url, api_key=api_key)

# Seconds to wait when downloading a remote image, so a slow or unresponsive
# host cannot block a request handler indefinitely.
REQUEST_TIMEOUT = 10

# Image preprocessing: resize the short side to 256, center-crop to 224x224,
# convert to a float tensor, and normalize with the standard ImageNet
# mean/std values.
# NOTE(review): this presumably must match the preprocessing used when the
# Qdrant collection was indexed -- confirm before changing any of these steps.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def _is_http_url(value):
    """Return True if *value* starts with an ``http://`` or ``https://`` scheme."""
    return value.startswith(("http://", "https://"))


def _to_batch(image):
    """Convert a PIL image to RGB and return a preprocessed batch of one.

    Forcing RGB matters: uploads can be RGBA or grayscale, which would give a
    tensor with the wrong channel count for the 3-channel ``Normalize`` step.
    The result has shape (1, 3, 224, 224).
    """
    return preprocess(image.convert('RGB')).unsqueeze(0)


def _embed(image_tensor):
    """Run the model on a preprocessed batch without tracking gradients."""
    with torch.no_grad():
        return model(image_tensor)


def preprocess_image(image_path_or_url):
    """Load an image from a URL or a local path and turn it into model input.

    Args:
        image_path_or_url: an ``http://``/``https://`` URL, or a local file path.

    Returns:
        A tensor of shape (1, 3, 224, 224), or ``None`` if the image could not
        be downloaded, opened, or decoded (the error is printed).
    """
    try:
        if _is_http_url(image_path_or_url):
            response = requests.get(image_path_or_url, timeout=REQUEST_TIMEOUT)
            # Fail on 4xx/5xx instead of trying to decode an HTML error page.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
        else:
            image = Image.open(image_path_or_url)
        # convert() forces PIL to actually decode, so decode errors land here.
        image = image.convert('RGB')
    except Exception as e:
        print(f"Error: {e}")
        return None
    return _to_batch(image)


def get_image_embedding(image_path_or_url):
    """Compute the model output for an image given by URL or local path.

    Returns:
        The raw ``model(...)`` output tensor with a leading batch dimension of
        1, or ``None`` if the image could not be loaded.
        NOTE(review): for the pretrained EfficientNet head this is presumably
        the classifier output, not a pooled feature vector -- confirm it matches
        how the Qdrant collection was built.
    """
    image_tensor = preprocess_image(image_path_or_url)
    if image_tensor is None:
        return None
    return _embed(image_tensor)


def get_image_embedding_from_pil(image: Image.Image):
    """Compute the model output for an in-memory PIL image (e.g. a Gradio upload).

    The image is converted to RGB first, so RGBA/grayscale uploads work.
    """
    return _embed(_to_batch(image))


def search_similar_images(image_upload, image_url):
    """Gradio handler: embed the query image and search Qdrant for neighbours.

    Args:
        image_upload: PIL image from the upload widget, or ``None``.
        image_url: text from the URL box; takes effect only if nothing was
            uploaded. Only ``http://``/``https://`` URLs are accepted, so a web
            user cannot make the server open arbitrary local file paths.

    Returns:
        ``(gallery_items, status_message)`` where ``gallery_items`` is a list
        of ``(image_url, caption)`` tuples for ``gr.Gallery``.
    """
    # An uploaded image takes precedence over the URL box.
    if image_upload is not None:
        query_vector = get_image_embedding_from_pil(image_upload)
    else:
        image_url = (image_url or "").strip()
        if not image_url:
            return [], "❌ Please upload an image or provide an image URL."
        if not _is_http_url(image_url):
            return [], "❌ Please enter a valid http:// or https:// image URL."
        query_vector = get_image_embedding(image_url)

    # Guard against images that could not be fetched or decoded.
    if query_vector is None:
        return [], "❌ Could not process the image."

    # Query Qdrant with the single embedding in the batch.
    try:
        result = qdrant_client.search(
            collection_name="image_collection_b7_new",
            query_vector=query_vector[0].tolist(),
            with_payload=True,
        )
    except Exception as e:
        return [], f"❌ Qdrant search failed: {str(e)}"

    if not result:
        return [], "❌ No matching images found."

    gallery_data = []
    for hit in result:
        payload = hit.payload or {}
        hit_url = payload.get('image_url', '')
        name = payload.get('name', 'Unknown')
        score = round(hit.score, 3)
        caption = f"{name}\n(score: {score})"
        gallery_data.append((hit_url, caption))

    return gallery_data, f"✅ Found {len(result)} similar images."


# Gradio interface
with gr.Blocks(title="Image Similarity Search with EfficientNet-B7") as demo:
    gr.Markdown("# 🔍 EfficientNet-B7 Image Search")
    gr.Markdown("Upload an image or paste a URL to find similar images using Qdrant.")

    with gr.Row():
        image_input = gr.Image(label="Upload Image", type="pil")
        image_url_input = gr.Textbox(label="Or enter Image URL")

    search_btn = gr.Button("Search Similar Images")
    gallery_output = gr.Gallery(label="Similar Images", columns=3, height="auto", show_label=False)
    status_output = gr.Textbox(label="Status")

    search_btn.click(
        fn=search_similar_images,
        inputs=[image_input, image_url_input],
        outputs=[gallery_output, status_output],
    )

if __name__ == "__main__":
    demo.launch()