hackerloi45 committed on
Commit
294389b
·
1 Parent(s): ebbc741
Files changed (1)
  1. app.py +135 -111
app.py CHANGED
@@ -1,121 +1,160 @@
 import os
-import uuid
 import gradio as gr
-import numpy as np
-from PIL import Image
 from qdrant_client import QdrantClient
-from qdrant_client.http.models import VectorParams, Distance, PointStruct
+from qdrant_client.http import models as rest
+from qdrant_client.http.models import Distance, VectorParams
 from sentence_transformers import SentenceTransformer
+from transformers import CLIPProcessor, CLIPModel
+from PIL import Image
+import uuid

-# ===============================
-# Config
-# ===============================
-UPLOAD_DIR = "uploaded_images"
-os.makedirs(UPLOAD_DIR, exist_ok=True)
-
+# ----------------------------
+# Qdrant Setup
+# ----------------------------
+QDRANT_HOST = "localhost"
+QDRANT_PORT = 6333
 COLLECTION = "lost_and_found"
-qclient = QdrantClient(":memory:")

-# Load CLIP model
-encoder = SentenceTransformer("clip-ViT-B-32")
+qclient = QdrantClient(QDRANT_HOST, port=QDRANT_PORT)
+
+# Load Models
+text_model = SentenceTransformer("all-MiniLM-L6-v2")
+clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

-# Safe vector size
-VECTOR_SIZE = encoder.get_sentence_embedding_dimension()
-if VECTOR_SIZE is None:
-    VECTOR_SIZE = len(encoder.encode(["test"])[0])
+# Embedding sizes
+TEXT_VECTOR_SIZE = text_model.get_sentence_embedding_dimension()
+IMAGE_VECTOR_SIZE = clip_model.config.projection_dim

-# Create collection
-if not qclient.collection_exists(COLLECTION):
+# Create collection if not exists
+try:
+    qclient.get_collection(COLLECTION)
+except Exception:
     qclient.create_collection(
-        collection_name=COLLECTION,
-        vectors_config=VectorParams(size=int(VECTOR_SIZE), distance=Distance.COSINE),
+        COLLECTION,
+        vectors_config={
+            "text": VectorParams(size=TEXT_VECTOR_SIZE, distance=Distance.COSINE),
+            "image": VectorParams(size=IMAGE_VECTOR_SIZE, distance=Distance.COSINE),
+        },
     )

-# ===============================
-# Encode function
-# ===============================
+# ----------------------------
+# Encoding Helpers
+# ----------------------------
 def encode_text(text: str):
-    return np.asarray(encoder.encode([text])[0], dtype=float)
+    return text_model.encode([text])[0]

-def encode_image(img: Image.Image):
-    return np.asarray(encoder.encode(img.convert("RGB")), dtype=float)
+def encode_image(image: Image.Image):
+    inputs = clip_processor(images=image, return_tensors="pt")
+    with torch.no_grad():
+        emb = clip_model.get_image_features(**inputs)
+    return emb[0].cpu().numpy()

-# ===============================
-# Add Item
-# ===============================
+# ----------------------------
+# Add Found Item
+# ----------------------------
 def add_item(text, image, uploader_name, uploader_phone):
     try:
-        vector = None
-        img_path = None
+        if not text and image is None:
+            return "❌ Please provide a description or an image."

-        if isinstance(image, Image.Image):
-            img_id = str(uuid.uuid4())
-            img_path = os.path.join(UPLOAD_DIR, f"{img_id}.png")
-            image.save(img_path)
-            vector = encode_image(image)
-        elif text and text.strip():
-            vector = encode_text(text)
+        text_vector = encode_text(text) if text else None
+        image_vector = encode_image(image) if image is not None else None

-        if vector is None:
-            return "❌ Please provide either an image or description."
+        # Save uploaded image
+        img_path = None
+        if image is not None:
+            os.makedirs("uploaded_images", exist_ok=True)
+            img_id = str(uuid.uuid4()) + ".png"
+            img_path = os.path.join("uploaded_images", img_id)
+            image.save(img_path)

-        payload = {
-            "text": text or "",
-            "uploader_name": uploader_name or "N/A",
-            "uploader_phone": uploader_phone or "N/A",
-            "image_path": img_path,
-        }
+        vectors = {}
+        if text_vector is not None:
+            vectors["text"] = text_vector.tolist()
+        if image_vector is not None:
+            vectors["image"] = image_vector.tolist()

         qclient.upsert(
             collection_name=COLLECTION,
-            points=[PointStruct(id=str(uuid.uuid4()), vector=vector.tolist(), payload=payload)],
-            wait=True,
+            points=[
+                rest.PointStruct(
+                    id=str(uuid.uuid4()),
+                    vector=vectors,
+                    payload={
+                        "text": text,
+                        "image_path": img_path,
+                        "uploader_name": uploader_name or "N/A",
+                        "uploader_phone": uploader_phone or "N/A",
+                    },
+                )
+            ],
         )
-        return "✅ Item added!"
+        return "✅ Item added successfully!"
     except Exception as e:
-        return f"❌ Error: {e}"
+        return f"❌ Error adding item: {e}"

-# ===============================
-# Search Items
-# ===============================
+# ----------------------------
+# Search Lost Items
+# ----------------------------
 def search_items(text, image, max_results, min_score):
     try:
         vector = None
+        query_text = text.strip() if text else ""
+
         if isinstance(image, Image.Image):
             vector = encode_image(image)
-        elif text and text.strip():
+        elif text:
             vector = encode_text(text)

-        if vector is None:
-            return "❌ Provide text or image.", []
+        results = []

-        results = qclient.search(
-            collection_name=COLLECTION,
-            query_vector=vector.tolist(),
-            limit=int(max_results),
-            score_threshold=float(min_score),
-            with_payload=True,
-        )
+        # 1. Vector search
+        if vector is not None:
+            results = qclient.search(
+                collection_name=COLLECTION,
+                query_vector=vector.tolist(),
+                limit=int(max_results),
+                score_threshold=float(min_score),
+                with_payload=True,
+            )
+
+        # 2. Fallback text search on payload
+        if query_text:
+            keyword_results = qclient.scroll(
+                collection_name=COLLECTION,
+                scroll_filter=rest.Filter(
+                    must=[rest.FieldCondition(
+                        key="text",
+                        match=rest.MatchText(text=query_text)
+                    )]
+                ),
+                limit=100,
+                with_payload=True
+            )[0]
+
+            existing_ids = {r.id for r in results}
+            for km in keyword_results:
+                if km.id not in existing_ids:
+                    km.score = 1.0
+                    results.append(km)

         if not results:
             return "No matches found.", []

-        text_out = []
-        gallery = []
-
-        for r in results:
+        # Format output
+        text_out, gallery = [], []
+        for r in results[:max_results]:
             payload = r.payload or {}
             score = getattr(r, "score", 0)
             uploader_name = payload.get("uploader_name", "N/A")
             uploader_phone = payload.get("uploader_phone", "N/A")
-
             desc = (
                 f"id:{r.id} | score:{score:.3f} | "
                 f"text:{payload.get('text','')} | "
                 f"finder:{uploader_name} ({uploader_phone})"
             )
             text_out.append(desc)
-
             img_path = payload.get("image_path")
             if img_path and os.path.exists(img_path):
                 gallery.append(img_path)
@@ -124,55 +163,40 @@ def search_items(text, image, max_results, min_score):
     except Exception as e:
         return f"❌ Error: {e}", []

-# ===============================
-# Clear DB
-# ===============================
-def clear_database():
-    try:
-        if qclient.collection_exists(COLLECTION):
-            qclient.delete_collection(COLLECTION)
-        qclient.create_collection(
-            collection_name=COLLECTION,
-            vectors_config=VectorParams(size=int(VECTOR_SIZE), distance=Distance.COSINE),
-        )
-        for f in os.listdir(UPLOAD_DIR):
-            try:
-                os.remove(os.path.join(UPLOAD_DIR, f))
-            except:
-                pass
-        return "🗑️ Database cleared!"
-    except Exception as e:
-        return f"❌ Error clearing DB: {e}"
-
-# ===============================
+# ----------------------------
 # Gradio UI
-# ===============================
-with gr.Blocks() as demo:
-    gr.Markdown("## 🔎 Lost & Found System")
-
+# ----------------------------
+with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
     with gr.Tab("➕ Add Found Item"):
-        text_in = gr.Textbox(label="Description (optional)")
-        img_in = gr.Image(type="pil", label="Upload Image (optional)")
+        desc_in = gr.Textbox(label="Description", placeholder="Describe the item...")
+        img_in = gr.Image(label="Upload Image", type="pil")
         uploader_name = gr.Textbox(label="Finder's Name")
         uploader_phone = gr.Textbox(label="Finder's Phone")
-        add_btn = gr.Button("Add to Database")
-        add_status = gr.Textbox(label="Status")
-        add_btn.click(add_item, inputs=[text_in, img_in, uploader_name, uploader_phone], outputs=[add_status])
+        add_btn = gr.Button("Add Item")
+        add_out = gr.Textbox(label="Status")
+        add_btn.click(
+            add_item,
+            inputs=[desc_in, img_in, uploader_name, uploader_phone],
+            outputs=[add_out]
+        )

     with gr.Tab("🔍 Search Lost Item"):
-        search_text = gr.Textbox(label="Search by Text (optional)")
-        search_img = gr.Image(type="pil", label="Search by Image (optional)")
-        max_results = gr.Slider(1, 20, value=5, step=1, label="Max Results")
-        min_score = gr.Slider(0.0, 1.0, value=0.3, step=0.01, label="Min Similarity Score")
+        text_in = gr.Textbox(label="Search by Text (optional)")
+        img_in_search = gr.Image(label="Search by Image (optional)", type="pil")
+        max_res = gr.Slider(1, 20, value=5, step=1, label="Max Results")
+        min_score = gr.Slider(0, 1, value=0.3, step=0.01, label="Min Similarity Score")
         search_btn = gr.Button("Search")
-        search_text_out = gr.Textbox(label="Search Results (Text)")
-        search_gallery = gr.Gallery(label="Search Results", columns=2, height="auto")
-        search_btn.click(search_items, inputs=[search_text, search_img, max_results, min_score], outputs=[search_text_out, search_gallery])
+        result_text = gr.Textbox(label="Search Results (Text)")
+        result_gallery = gr.Gallery(label="Search Results (Images)").style(grid=3)
+        search_btn.click(
+            search_items,
+            inputs=[text_in, img_in_search, max_res, min_score],
+            outputs=[result_text, result_gallery]
+        )

-    with gr.Tab("🗑️ Admin"):
-        clear_btn = gr.Button("Clear Database")
-        clear_out = gr.Textbox(label="Status")
-        clear_btn.click(clear_database, outputs=[clear_out])
+    with gr.Tab("⚙️ Admin"):
+        gr.Markdown("Admin dashboard placeholder...")

 if __name__ == "__main__":
+    import torch
     demo.launch(server_name="0.0.0.0", server_port=7860)
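Two notes on the new version of app.py. First, encode_image uses torch.no_grad(), but import torch only appears inside the if __name__ == "__main__": block at the bottom; that happens to work when the file is run directly (the import executes before any Gradio callback fires), but it raises NameError if app.py is imported as a module, so a module-level import torch next to the other imports is the safer placement. Second, the collection is now created with two named vectors ("text" and "image"), while search_items still passes a bare query_vector=vector.tolist(); Qdrant expects a query against a multi-vector collection to say which named vector to search, so that call is likely to be rejected at runtime. A minimal sketch of the named-vector form accepted by qdrant-client, reusing the locals of search_items above (illustrative only, not part of this commit):

    # Pick the vector space that matches the query modality; the
    # (name, values) tuple is qdrant-client shorthand for a NamedVector.
    vector_name = "image" if isinstance(image, Image.Image) else "text"
    results = qclient.search(
        collection_name=COLLECTION,
        query_vector=(vector_name, vector.tolist()),
        limit=int(max_results),
        score_threshold=float(min_score),
        with_payload=True,
    )

Finally, gr.Gallery(...).style(grid=3) is the Gradio 3.x styling API; on Gradio 4.x the column count goes to the constructor instead, the form the removed code already used (gr.Gallery(label="...", columns=2)).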