Spaces:
Running
on
T4
Running
on
T4
Upload folder using huggingface_hub
Browse files- .gitignore +3 -1
- backend/colpali.py +4 -62
- backend/modelmanager.py +1 -1
- backend/vespa_app.py +82 -24
- colpali-with-snippets/security/clients.pem +7 -7
- frontend/app.py +8 -7
- main.py +132 -130
.gitignore
CHANGED
@@ -8,4 +8,6 @@ template/
|
|
8 |
*.json
|
9 |
output/
|
10 |
pdfs/
|
11 |
-
static/
|
|
|
|
|
|
8 |
*.json
|
9 |
output/
|
10 |
pdfs/
|
11 |
+
static/full_images/
|
12 |
+
static/sim_maps/
|
13 |
+
embeddings/
|
backend/colpali.py
CHANGED
@@ -7,7 +7,7 @@ from typing import cast, Generator
|
|
7 |
from pathlib import Path
|
8 |
import base64
|
9 |
from io import BytesIO
|
10 |
-
from typing import Union, Tuple, List
|
11 |
import matplotlib
|
12 |
import matplotlib.cm as cm
|
13 |
import re
|
@@ -49,7 +49,7 @@ def load_model() -> Tuple[ColPali, ColPaliProcessor]:
|
|
49 |
|
50 |
# Load the processor
|
51 |
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
|
52 |
-
return model, processor
|
53 |
|
54 |
|
55 |
def load_vit_config(model):
|
@@ -63,7 +63,6 @@ def gen_similarity_maps(
|
|
63 |
model: ColPali,
|
64 |
processor: ColPaliProcessor,
|
65 |
device,
|
66 |
-
vit_config,
|
67 |
query: str,
|
68 |
query_embs: torch.Tensor,
|
69 |
token_idx_map: dict,
|
@@ -88,7 +87,7 @@ def gen_similarity_maps(
|
|
88 |
Tuple[int, str, str]: A tuple containing the image index, the selected token, and the base64-encoded image.
|
89 |
|
90 |
"""
|
91 |
-
|
92 |
# Process images and store original images and sizes
|
93 |
processed_images = []
|
94 |
original_images = []
|
@@ -254,7 +253,7 @@ def gen_similarity_maps(
|
|
254 |
|
255 |
# Store the base64-encoded image
|
256 |
result_per_image[token] = blended_img_base64
|
257 |
-
yield idx, token, blended_img_base64
|
258 |
end3 = time.perf_counter()
|
259 |
print(f"Blending images took: {end3 - start3} s")
|
260 |
|
@@ -287,60 +286,3 @@ def is_special_token(token: str) -> bool:
|
|
287 |
if (len(token) < 3) or pattern.match(token):
|
288 |
return True
|
289 |
return False
|
290 |
-
|
291 |
-
|
292 |
-
def add_sim_maps_to_result(
|
293 |
-
result: Dict[str, Any],
|
294 |
-
model: ColPali,
|
295 |
-
processor: ColPaliProcessor,
|
296 |
-
query: str,
|
297 |
-
q_embs: Any,
|
298 |
-
token_to_idx: Dict[str, int],
|
299 |
-
query_id: str,
|
300 |
-
result_cache,
|
301 |
-
) -> Dict[str, Any]:
|
302 |
-
print("Adding similarity maps to result - query_id:", query_id)
|
303 |
-
vit_config = load_vit_config(model)
|
304 |
-
imgs: List[str] = []
|
305 |
-
vespa_sim_maps: List[str] = []
|
306 |
-
for single_result in result["root"]["children"]:
|
307 |
-
img = single_result["fields"]["blur_image"]
|
308 |
-
if img:
|
309 |
-
imgs.append(img)
|
310 |
-
vespa_sim_map = single_result["fields"].get("summaryfeatures", None)
|
311 |
-
if vespa_sim_map:
|
312 |
-
vespa_sim_maps.append(vespa_sim_map)
|
313 |
-
if not imgs:
|
314 |
-
return result
|
315 |
-
if len(imgs) != len(vespa_sim_maps):
|
316 |
-
print(
|
317 |
-
"Number of images and similarity maps do not match. Skipping similarity map generation."
|
318 |
-
)
|
319 |
-
return result
|
320 |
-
sim_map_imgs_generator = gen_similarity_maps(
|
321 |
-
model=model,
|
322 |
-
processor=processor,
|
323 |
-
device=model.device if hasattr(model, "device") else "cpu",
|
324 |
-
vit_config=vit_config,
|
325 |
-
query=query,
|
326 |
-
query_embs=q_embs,
|
327 |
-
token_idx_map=token_to_idx,
|
328 |
-
images=imgs,
|
329 |
-
vespa_sim_maps=vespa_sim_maps,
|
330 |
-
)
|
331 |
-
for img_idx, token, sim_mapb64 in sim_map_imgs_generator:
|
332 |
-
print(f"Created sim map for image {img_idx} and token {token}")
|
333 |
-
if (
|
334 |
-
len(result["root"]["children"]) > img_idx
|
335 |
-
and "fields" in result["root"]["children"][img_idx]
|
336 |
-
):
|
337 |
-
result["root"]["children"][img_idx]["fields"][f"sim_map_{token}"] = (
|
338 |
-
sim_mapb64
|
339 |
-
)
|
340 |
-
# Update result_cache with the new sim_map
|
341 |
-
result_cache.set(query_id, result)
|
342 |
-
else:
|
343 |
-
print(
|
344 |
-
f"Could not add sim map to result for image {img_idx} and token {token}"
|
345 |
-
)
|
346 |
-
return result
|
|
|
7 |
from pathlib import Path
|
8 |
import base64
|
9 |
from io import BytesIO
|
10 |
+
from typing import Union, Tuple, List
|
11 |
import matplotlib
|
12 |
import matplotlib.cm as cm
|
13 |
import re
|
|
|
49 |
|
50 |
# Load the processor
|
51 |
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
|
52 |
+
return model, processor, device
|
53 |
|
54 |
|
55 |
def load_vit_config(model):
|
|
|
63 |
model: ColPali,
|
64 |
processor: ColPaliProcessor,
|
65 |
device,
|
|
|
66 |
query: str,
|
67 |
query_embs: torch.Tensor,
|
68 |
token_idx_map: dict,
|
|
|
87 |
Tuple[int, str, str]: A tuple containing the image index, the selected token, and the base64-encoded image.
|
88 |
|
89 |
"""
|
90 |
+
vit_config = load_vit_config(model)
|
91 |
# Process images and store original images and sizes
|
92 |
processed_images = []
|
93 |
original_images = []
|
|
|
253 |
|
254 |
# Store the base64-encoded image
|
255 |
result_per_image[token] = blended_img_base64
|
256 |
+
yield idx, token, token_idx, blended_img_base64
|
257 |
end3 = time.perf_counter()
|
258 |
print(f"Blending images took: {end3 - start3} s")
|
259 |
|
|
|
286 |
if (len(token) < 3) or pattern.match(token):
|
287 |
return True
|
288 |
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
backend/modelmanager.py
CHANGED
@@ -17,7 +17,7 @@ class ModelManager:
|
|
17 |
|
18 |
def initialize_model_and_processor(self):
|
19 |
if self.model is None or self.processor is None: # Ensure no reinitialization
|
20 |
-
self.model, self.processor = load_model()
|
21 |
if self.model is None or self.processor is None:
|
22 |
print("Failed to initialize model or processor at startup")
|
23 |
else:
|
|
|
17 |
|
18 |
def initialize_model_and_processor(self):
|
19 |
if self.model is None or self.processor is None: # Ensure no reinitialization
|
20 |
+
self.model, self.processor, self.device = load_model()
|
21 |
if self.model is None or self.processor is None:
|
22 |
print("Failed to initialize model or processor at startup")
|
23 |
else:
|
backend/vespa_app.py
CHANGED
@@ -1,18 +1,19 @@
|
|
1 |
import os
|
2 |
import time
|
3 |
from typing import Any, Dict, Tuple
|
4 |
-
|
5 |
import numpy as np
|
6 |
import torch
|
7 |
from dotenv import load_dotenv
|
8 |
from vespa.application import Vespa
|
9 |
from vespa.io import VespaQueryResponse
|
|
|
10 |
|
11 |
|
12 |
class VespaQueryClient:
|
13 |
MAX_QUERY_TERMS = 64
|
14 |
VESPA_SCHEMA_NAME = "pdf_page"
|
15 |
-
SELECT_FIELDS = "id,title,url,blur_image,page_number,snippet,text
|
16 |
|
17 |
def __init__(self):
|
18 |
"""
|
@@ -73,6 +74,12 @@ class VespaQueryClient:
|
|
73 |
self.app.wait_for_application_up()
|
74 |
print(f"Connected to Vespa at {self.vespa_app_url}")
|
75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
def format_query_results(
|
77 |
self, query: str, response: VespaQueryResponse, hits: int = 5
|
78 |
) -> dict:
|
@@ -100,6 +107,7 @@ class VespaQueryClient:
|
|
100 |
q_emb: torch.Tensor,
|
101 |
hits: int = 3,
|
102 |
timeout: str = "10s",
|
|
|
103 |
**kwargs,
|
104 |
) -> dict:
|
105 |
"""
|
@@ -121,9 +129,9 @@ class VespaQueryClient:
|
|
121 |
response: VespaQueryResponse = await session.query(
|
122 |
body={
|
123 |
"yql": (
|
124 |
-
f"select {self.
|
125 |
),
|
126 |
-
"ranking": "default",
|
127 |
"query": query,
|
128 |
"timeout": timeout,
|
129 |
"hits": hits,
|
@@ -146,6 +154,7 @@ class VespaQueryClient:
|
|
146 |
q_emb: torch.Tensor,
|
147 |
hits: int = 3,
|
148 |
timeout: str = "10s",
|
|
|
149 |
**kwargs,
|
150 |
) -> dict:
|
151 |
"""
|
@@ -167,9 +176,9 @@ class VespaQueryClient:
|
|
167 |
response: VespaQueryResponse = await session.query(
|
168 |
body={
|
169 |
"yql": (
|
170 |
-
f"select {self.
|
171 |
),
|
172 |
-
"ranking": "bm25",
|
173 |
"query": query,
|
174 |
"timeout": timeout,
|
175 |
"hits": hits,
|
@@ -266,30 +275,54 @@ class VespaQueryClient:
|
|
266 |
Returns:
|
267 |
Dict[str, Any]: The query results.
|
268 |
"""
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
|
|
278 |
else:
|
279 |
-
raise ValueError(f"Unsupported ranking: {
|
280 |
-
|
281 |
# Print score, title id, and text of the results
|
282 |
if "root" not in result or "children" not in result["root"]:
|
283 |
result["root"] = {"children": []}
|
284 |
return result
|
285 |
-
for idx, child in enumerate(result["root"]["children"]):
|
286 |
-
print(
|
287 |
-
f"Result {idx+1}: {child['relevance']}, {child['fields']['title']}, {child['fields']['id']}"
|
288 |
-
)
|
289 |
for single_result in result["root"]["children"]:
|
290 |
print(single_result["fields"].keys())
|
291 |
return result
|
292 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
293 |
async def get_full_image_from_vespa(self, doc_id: str) -> str:
|
294 |
"""
|
295 |
Retrieve the full image from Vespa for a given document ID.
|
@@ -317,6 +350,23 @@ class VespaQueryClient:
|
|
317 |
)
|
318 |
return response.json["root"]["children"][0]["fields"]["full_image"]
|
319 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
320 |
async def get_suggestions(self, query: str) -> list:
|
321 |
async with self.app.asyncio(connections=1) as session:
|
322 |
start = time.perf_counter()
|
@@ -348,6 +398,12 @@ class VespaQueryClient:
|
|
348 |
flat_questions = [item for sublist in questions for item in sublist]
|
349 |
return flat_questions
|
350 |
|
|
|
|
|
|
|
|
|
|
|
|
|
351 |
async def query_vespa_nearest_neighbor(
|
352 |
self,
|
353 |
query: str,
|
@@ -355,6 +411,7 @@ class VespaQueryClient:
|
|
355 |
target_hits_per_query_tensor: int = 20,
|
356 |
hits: int = 3,
|
357 |
timeout: str = "10s",
|
|
|
358 |
**kwargs,
|
359 |
) -> dict:
|
360 |
"""
|
@@ -385,15 +442,16 @@ class VespaQueryClient:
|
|
385 |
binary_query_embeddings, target_hits_per_query_tensor
|
386 |
)
|
387 |
query_tensors.update(nn_query_dict)
|
388 |
-
|
389 |
response: VespaQueryResponse = await session.query(
|
390 |
body={
|
391 |
**query_tensors,
|
392 |
"presentation.timing": True,
|
393 |
"yql": (
|
394 |
-
f"select {self.
|
|
|
|
|
|
|
395 |
),
|
396 |
-
"ranking.profile": "retrieval-and-rerank",
|
397 |
"timeout": timeout,
|
398 |
"hits": hits,
|
399 |
"query": query,
|
|
|
1 |
import os
|
2 |
import time
|
3 |
from typing import Any, Dict, Tuple
|
4 |
+
import asyncio
|
5 |
import numpy as np
|
6 |
import torch
|
7 |
from dotenv import load_dotenv
|
8 |
from vespa.application import Vespa
|
9 |
from vespa.io import VespaQueryResponse
|
10 |
+
from .colpali import is_special_token
|
11 |
|
12 |
|
13 |
class VespaQueryClient:
|
14 |
MAX_QUERY_TERMS = 64
|
15 |
VESPA_SCHEMA_NAME = "pdf_page"
|
16 |
+
SELECT_FIELDS = "id,title,url,blur_image,page_number,snippet,text"
|
17 |
|
18 |
def __init__(self):
|
19 |
"""
|
|
|
74 |
self.app.wait_for_application_up()
|
75 |
print(f"Connected to Vespa at {self.vespa_app_url}")
|
76 |
|
77 |
+
def get_fields(self, sim_map: bool = False):
|
78 |
+
if not sim_map:
|
79 |
+
return self.SELECT_FIELDS
|
80 |
+
else:
|
81 |
+
return "summaryfeatures"
|
82 |
+
|
83 |
def format_query_results(
|
84 |
self, query: str, response: VespaQueryResponse, hits: int = 5
|
85 |
) -> dict:
|
|
|
107 |
q_emb: torch.Tensor,
|
108 |
hits: int = 3,
|
109 |
timeout: str = "10s",
|
110 |
+
sim_map: bool = False,
|
111 |
**kwargs,
|
112 |
) -> dict:
|
113 |
"""
|
|
|
129 |
response: VespaQueryResponse = await session.query(
|
130 |
body={
|
131 |
"yql": (
|
132 |
+
f"select {self.get_fields(sim_map=sim_map)} from {self.VESPA_SCHEMA_NAME} where userQuery();"
|
133 |
),
|
134 |
+
"ranking": self.get_rank_profile("default", sim_map),
|
135 |
"query": query,
|
136 |
"timeout": timeout,
|
137 |
"hits": hits,
|
|
|
154 |
q_emb: torch.Tensor,
|
155 |
hits: int = 3,
|
156 |
timeout: str = "10s",
|
157 |
+
sim_map: bool = False,
|
158 |
**kwargs,
|
159 |
) -> dict:
|
160 |
"""
|
|
|
176 |
response: VespaQueryResponse = await session.query(
|
177 |
body={
|
178 |
"yql": (
|
179 |
+
f"select {self.get_fields(sim_map=sim_map)} from {self.VESPA_SCHEMA_NAME} where userQuery();"
|
180 |
),
|
181 |
+
"ranking": self.get_rank_profile("bm25", sim_map),
|
182 |
"query": query,
|
183 |
"timeout": timeout,
|
184 |
"hits": hits,
|
|
|
275 |
Returns:
|
276 |
Dict[str, Any]: The query results.
|
277 |
"""
|
278 |
+
rank_method = ranking.split("_")[0]
|
279 |
+
sim_map: bool = len(ranking.split("_")) > 1 and ranking.split("_")[1] == "sim"
|
280 |
+
if rank_method == "nn+colpali":
|
281 |
+
result = await self.query_vespa_nearest_neighbor(
|
282 |
+
query, q_embs, sim_map=sim_map
|
283 |
+
)
|
284 |
+
elif rank_method == "bm25+colpali":
|
285 |
+
result = await self.query_vespa_default(query, q_embs, sim_map=sim_map)
|
286 |
+
elif rank_method == "bm25":
|
287 |
+
result = await self.query_vespa_bm25(query, q_embs, sim_map=sim_map)
|
288 |
else:
|
289 |
+
raise ValueError(f"Unsupported ranking: {rank_method}")
|
|
|
290 |
# Print score, title id, and text of the results
|
291 |
if "root" not in result or "children" not in result["root"]:
|
292 |
result["root"] = {"children": []}
|
293 |
return result
|
|
|
|
|
|
|
|
|
294 |
for single_result in result["root"]["children"]:
|
295 |
print(single_result["fields"].keys())
|
296 |
return result
|
297 |
|
298 |
+
def get_sim_maps_from_query(
|
299 |
+
self, query: str, q_embs: torch.Tensor, ranking: str, token_to_idx: dict
|
300 |
+
):
|
301 |
+
"""
|
302 |
+
Get similarity maps from Vespa based on the ranking method.
|
303 |
+
|
304 |
+
Args:
|
305 |
+
query (str): The query text.
|
306 |
+
q_embs (torch.Tensor): Query embeddings.
|
307 |
+
ranking (str): The ranking method to use.
|
308 |
+
token_to_idx (dict): Token to index mapping.
|
309 |
+
|
310 |
+
Returns:
|
311 |
+
Dict[str, Any]: The query results.
|
312 |
+
"""
|
313 |
+
# Get the result by calling asyncio.run
|
314 |
+
result = asyncio.run(
|
315 |
+
self.get_result_from_query(query, q_embs, ranking, token_to_idx)
|
316 |
+
)
|
317 |
+
vespa_sim_maps = []
|
318 |
+
for single_result in result["root"]["children"]:
|
319 |
+
vespa_sim_map = single_result["fields"].get("summaryfeatures", None)
|
320 |
+
if vespa_sim_map is not None:
|
321 |
+
vespa_sim_maps.append(vespa_sim_map)
|
322 |
+
else:
|
323 |
+
raise ValueError("No sim_map found in Vespa response")
|
324 |
+
return vespa_sim_maps
|
325 |
+
|
326 |
async def get_full_image_from_vespa(self, doc_id: str) -> str:
|
327 |
"""
|
328 |
Retrieve the full image from Vespa for a given document ID.
|
|
|
350 |
)
|
351 |
return response.json["root"]["children"][0]["fields"]["full_image"]
|
352 |
|
353 |
+
def get_results_children(self, result: VespaQueryResponse) -> list:
|
354 |
+
return result["root"]["children"]
|
355 |
+
|
356 |
+
def results_to_search_results(
|
357 |
+
self, result: VespaQueryResponse, token_to_idx: dict
|
358 |
+
) -> list:
|
359 |
+
# Initialize sim_map_ fields in the result
|
360 |
+
fields_to_add = [
|
361 |
+
f"sim_map_{token}_{idx}"
|
362 |
+
for idx, token in enumerate(token_to_idx.keys())
|
363 |
+
if not is_special_token(token)
|
364 |
+
]
|
365 |
+
for child in result["root"]["children"]:
|
366 |
+
for sim_map_key in fields_to_add:
|
367 |
+
child["fields"][sim_map_key] = None
|
368 |
+
return self.get_results_children(result)
|
369 |
+
|
370 |
async def get_suggestions(self, query: str) -> list:
|
371 |
async with self.app.asyncio(connections=1) as session:
|
372 |
start = time.perf_counter()
|
|
|
398 |
flat_questions = [item for sublist in questions for item in sublist]
|
399 |
return flat_questions
|
400 |
|
401 |
+
def get_rank_profile(self, ranking: str, sim_map: bool) -> str:
|
402 |
+
if sim_map:
|
403 |
+
return f"{ranking}_sim"
|
404 |
+
else:
|
405 |
+
return ranking
|
406 |
+
|
407 |
async def query_vespa_nearest_neighbor(
|
408 |
self,
|
409 |
query: str,
|
|
|
411 |
target_hits_per_query_tensor: int = 20,
|
412 |
hits: int = 3,
|
413 |
timeout: str = "10s",
|
414 |
+
sim_map: bool = False,
|
415 |
**kwargs,
|
416 |
) -> dict:
|
417 |
"""
|
|
|
442 |
binary_query_embeddings, target_hits_per_query_tensor
|
443 |
)
|
444 |
query_tensors.update(nn_query_dict)
|
|
|
445 |
response: VespaQueryResponse = await session.query(
|
446 |
body={
|
447 |
**query_tensors,
|
448 |
"presentation.timing": True,
|
449 |
"yql": (
|
450 |
+
f"select {self.get_fields(sim_map=sim_map)} from {self.VESPA_SCHEMA_NAME} where {nn_string} or userQuery()"
|
451 |
+
),
|
452 |
+
"ranking.profile": self.get_rank_profile(
|
453 |
+
"retrieval-and-rerank", sim_map
|
454 |
),
|
|
|
455 |
"timeout": timeout,
|
456 |
"hits": hits,
|
457 |
"query": query,
|
colpali-with-snippets/security/clients.pem
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
-----BEGIN CERTIFICATE-----
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-----END CERTIFICATE-----
|
|
|
1 |
-----BEGIN CERTIFICATE-----
|
2 |
+
MIIBODCB36ADAgECAhEAr37yU2TKTDMdQW1txTaMSjAKBggqhkjOPQQDAjAeMRww
|
3 |
+
GgYDVQQDExNjbG91ZC52ZXNwYS5leGFtcGxlMB4XDTI0MTAxNzA5NDY1M1oXDTM0
|
4 |
+
MTAxNTA5NDY1M1owHjEcMBoGA1UEAxMTY2xvdWQudmVzcGEuZXhhbXBsZTBZMBMG
|
5 |
+
ByqGSM49AgEGCCqGSM49AwEHA0IABPQjpb7RFvtnw288EY5eolq2v+0qC0h4JeW5
|
6 |
+
jCchXp4KUa5ufqeqyTcAxsfLn3BloPFDJ7Vb2gct9tZONa7xvc4wCgYIKoZIzj0E
|
7 |
+
AwIDSAAwRQIgR3wU3NUS02Behd0ojxo5sa5NVi0HhNW8RoAy0UyoGnACIQDWOqq+
|
8 |
+
zdKHJDorFuMWeMKKUe0cVQXZV3RvU5ssuXyEnw==
|
9 |
-----END CERTIFICATE-----
|
frontend/app.py
CHANGED
@@ -323,14 +323,13 @@ def SimMapButtonReady(query_id, idx, token, img_src):
|
|
323 |
)
|
324 |
|
325 |
|
326 |
-
def SimMapButtonPoll(query_id, idx, token):
|
327 |
return Button(
|
328 |
Lucide(icon="loader-circle", size="15", cls="animate-spin"),
|
329 |
size="sm",
|
330 |
disabled=True,
|
331 |
-
hx_get=f"/get_sim_map?query_id={query_id}&idx={idx}&token={token}",
|
332 |
-
|
333 |
-
hx_trigger=f"every {(idx+1)*0.3:.2f}s",
|
334 |
hx_swap="outerHTML",
|
335 |
cls="pointer-events-auto text-xs h-5 rounded-none px-2",
|
336 |
)
|
@@ -352,7 +351,6 @@ def SearchResult(results: list, query_id: Optional[str] = None):
|
|
352 |
fields = result["fields"] # Extract the 'fields' part of each result
|
353 |
blur_image_base64 = f"data:image/jpeg;base64,{fields['blur_image']}"
|
354 |
|
355 |
-
# Filter sim_map fields that are words with 4 or more characters
|
356 |
sim_map_fields = {
|
357 |
key: value
|
358 |
for key, value in fields.items()
|
@@ -370,14 +368,17 @@ def SearchResult(results: list, query_id: Optional[str] = None):
|
|
370 |
SimMapButtonReady(
|
371 |
query_id=query_id,
|
372 |
idx=idx,
|
373 |
-
token=key.split("_")[-
|
374 |
img_src=sim_map_base64,
|
375 |
)
|
376 |
)
|
377 |
else:
|
378 |
sim_map_buttons.append(
|
379 |
SimMapButtonPoll(
|
380 |
-
query_id=query_id,
|
|
|
|
|
|
|
381 |
)
|
382 |
)
|
383 |
|
|
|
323 |
)
|
324 |
|
325 |
|
326 |
+
def SimMapButtonPoll(query_id, idx, token, token_idx):
|
327 |
return Button(
|
328 |
Lucide(icon="loader-circle", size="15", cls="animate-spin"),
|
329 |
size="sm",
|
330 |
disabled=True,
|
331 |
+
hx_get=f"/get_sim_map?query_id={query_id}&idx={idx}&token={token}&token_idx={token_idx}",
|
332 |
+
hx_trigger="every 0.5s",
|
|
|
333 |
hx_swap="outerHTML",
|
334 |
cls="pointer-events-auto text-xs h-5 rounded-none px-2",
|
335 |
)
|
|
|
351 |
fields = result["fields"] # Extract the 'fields' part of each result
|
352 |
blur_image_base64 = f"data:image/jpeg;base64,{fields['blur_image']}"
|
353 |
|
|
|
354 |
sim_map_fields = {
|
355 |
key: value
|
356 |
for key, value in fields.items()
|
|
|
368 |
SimMapButtonReady(
|
369 |
query_id=query_id,
|
370 |
idx=idx,
|
371 |
+
token=key.split("_")[-2],
|
372 |
img_src=sim_map_base64,
|
373 |
)
|
374 |
)
|
375 |
else:
|
376 |
sim_map_buttons.append(
|
377 |
SimMapButtonPoll(
|
378 |
+
query_id=query_id,
|
379 |
+
idx=idx,
|
380 |
+
token=key.split("_")[-2],
|
381 |
+
token_idx=int(key.split("_")[-1]),
|
382 |
)
|
383 |
)
|
384 |
|
main.py
CHANGED
@@ -1,26 +1,33 @@
|
|
1 |
import asyncio
|
2 |
-
import base64
|
3 |
-
import io
|
4 |
import os
|
5 |
import time
|
6 |
-
from concurrent.futures import ThreadPoolExecutor
|
7 |
-
from functools import partial
|
8 |
from pathlib import Path
|
|
|
9 |
import uuid
|
10 |
-
import hashlib
|
11 |
-
|
12 |
import google.generativeai as genai
|
13 |
-
from fasthtml.common import
|
14 |
-
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
from vespa.application import Vespa
|
|
|
|
|
|
|
17 |
|
18 |
-
from backend.
|
19 |
-
from backend.colpali import (
|
20 |
-
add_sim_maps_to_result,
|
21 |
-
get_query_embeddings_and_token_map,
|
22 |
-
is_special_token,
|
23 |
-
)
|
24 |
from backend.modelmanager import ModelManager
|
25 |
from backend.vespa_app import VespaQueryClient
|
26 |
from frontend.app import (
|
@@ -77,10 +84,6 @@ app, rt = fast_app(
|
|
77 |
),
|
78 |
)
|
79 |
vespa_app: Vespa = VespaQueryClient()
|
80 |
-
result_cache = LRUCache(max_size=20) # Each result can be ~10MB
|
81 |
-
task_cache = LRUCache(
|
82 |
-
max_size=1000
|
83 |
-
) # Map from query_id to boolean value - False if not all results are ready.
|
84 |
thread_pool = ThreadPoolExecutor()
|
85 |
# Gemini config
|
86 |
|
@@ -95,9 +98,11 @@ But, you should NOT include backticks (`) or HTML tags in your response.
|
|
95 |
gemini_model = genai.GenerativeModel(
|
96 |
"gemini-1.5-flash-8b", system_instruction=GEMINI_SYSTEM_PROMPT
|
97 |
)
|
98 |
-
STATIC_DIR = Path(
|
99 |
-
IMG_DIR = STATIC_DIR / "
|
100 |
-
|
|
|
|
|
101 |
|
102 |
|
103 |
@app.on_event("startup")
|
@@ -112,9 +117,9 @@ async def keepalive():
|
|
112 |
return
|
113 |
|
114 |
|
115 |
-
def generate_query_id(
|
116 |
-
hash_input = (
|
117 |
-
return
|
118 |
|
119 |
|
120 |
@rt("/static/{filepath:path}")
|
@@ -135,7 +140,7 @@ def get():
|
|
135 |
|
136 |
|
137 |
@rt("/search")
|
138 |
-
def get(
|
139 |
# Extract the 'query' and 'ranking' parameters from the URL
|
140 |
query_value = request.query_params.get("query", "").strip()
|
141 |
ranking_value = request.query_params.get("ranking", "nn+colpali")
|
@@ -160,12 +165,7 @@ def get(session, request):
|
|
160 |
)
|
161 |
)
|
162 |
# Generate a unique query_id based on the query and ranking value
|
163 |
-
|
164 |
-
session["query_id"] = generate_query_id(
|
165 |
-
session["session_id"], query_value, ranking_value
|
166 |
-
)
|
167 |
-
query_id = session.get("query_id")
|
168 |
-
print(f"Query id in /search: {query_id}")
|
169 |
# Show the loading message if a query is provided
|
170 |
return Layout(
|
171 |
Main(Search(request), data_overlayscrollbars_initialize=True, cls="border-t"),
|
@@ -176,23 +176,31 @@ def get(session, request):
|
|
176 |
) # Show SearchBox and Loading message initially
|
177 |
|
178 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
@rt("/fetch_results")
|
180 |
-
async def get(session, request, query: str,
|
181 |
if "hx-request" not in request.headers:
|
182 |
return RedirectResponse("/search")
|
183 |
|
184 |
-
#
|
185 |
-
|
186 |
-
print(
|
187 |
-
f"/fetch_results: Fetching results for query: {query}, ranking: {ranking_value}"
|
188 |
-
)
|
189 |
-
# Generate a unique query_id based on the query and ranking value
|
190 |
-
print(f"Sesssion in /fetch_results: {session}")
|
191 |
-
if "query_id" not in session:
|
192 |
-
session["query_id"] = generate_query_id(
|
193 |
-
session["session_id"], query_value, ranking_value
|
194 |
-
)
|
195 |
-
query_id = session.get("query_id")
|
196 |
print(f"Query id in /fetch_results: {query_id}")
|
197 |
# Run the embedding and query against Vespa app
|
198 |
model = app.manager.model
|
@@ -204,30 +212,21 @@ async def get(session, request, query: str, nn: bool = True):
|
|
204 |
result = await vespa_app.get_result_from_query(
|
205 |
query=query,
|
206 |
q_embs=q_embs,
|
207 |
-
ranking=
|
208 |
token_to_idx=token_to_idx,
|
209 |
)
|
210 |
end = time.perf_counter()
|
211 |
print(
|
212 |
f"Search results fetched in {end - start:.2f} seconds, Vespa says searchtime was {result['timing']['searchtime']} seconds"
|
213 |
)
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
for sim_map_key in fields_to_add:
|
222 |
-
child["fields"][sim_map_key] = None
|
223 |
-
result_cache.set(query_id, result)
|
224 |
-
# Start generating the similarity map in the background
|
225 |
-
asyncio.create_task(
|
226 |
-
generate_similarity_map(
|
227 |
-
model, processor, query, q_embs, token_to_idx, result, query_id
|
228 |
-
)
|
229 |
)
|
230 |
-
search_results = get_results_children(result)
|
231 |
return SearchResult(search_results, query_id)
|
232 |
|
233 |
|
@@ -247,78 +246,84 @@ async def poll_vespa_keepalive():
|
|
247 |
print(f"Vespa keepalive: {time.time()}")
|
248 |
|
249 |
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
sim_map_task = partial(
|
255 |
-
add_sim_maps_to_result,
|
256 |
-
result=result,
|
257 |
-
model=model,
|
258 |
-
processor=processor,
|
259 |
query=query,
|
260 |
q_embs=q_embs,
|
|
|
261 |
token_to_idx=token_to_idx,
|
262 |
-
query_id=query_id,
|
263 |
-
result_cache=result_cache,
|
264 |
)
|
265 |
-
|
266 |
-
|
267 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
268 |
|
269 |
|
270 |
@app.get("/get_sim_map")
|
271 |
-
async def get_sim_map(query_id: str, idx: int, token: str):
|
272 |
"""
|
273 |
Endpoint that each of the sim map button polls to get the sim map image
|
274 |
when it is ready. If it is not ready, returns a SimMapButtonPoll, that
|
275 |
continues to poll every 1 second.
|
276 |
"""
|
277 |
-
|
278 |
-
if
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
return SimMapButtonPoll(query_id=query_id, idx=idx, token=token)
|
284 |
else:
|
285 |
-
sim_map_key = f"sim_map_{token}"
|
286 |
-
sim_map_b64 = search_results[idx]["fields"].get(sim_map_key, None)
|
287 |
-
if sim_map_b64 is None:
|
288 |
-
return SimMapButtonPoll(query_id=query_id, idx=idx, token=token)
|
289 |
-
sim_map_img_src = f"data:image/png;base64,{sim_map_b64}"
|
290 |
return SimMapButtonReady(
|
291 |
-
query_id=query_id, idx=idx, token=token, img_src=
|
292 |
)
|
293 |
|
294 |
|
295 |
-
async def update_full_image_cache(docid: str, query_id: str, idx: int, image_data: str):
|
296 |
-
result = None
|
297 |
-
max_wait = 20 # seconds. If horribly slow network latency.
|
298 |
-
start_time = time.time()
|
299 |
-
while result is None and time.time() - start_time < max_wait:
|
300 |
-
result = result_cache.get(query_id)
|
301 |
-
if result is None:
|
302 |
-
await asyncio.sleep(0.1)
|
303 |
-
try:
|
304 |
-
result["root"]["children"][idx]["fields"]["full_image"] = image_data
|
305 |
-
except KeyError as err:
|
306 |
-
print(f"Error updating full image cache: {err}")
|
307 |
-
result_cache.set(query_id, result)
|
308 |
-
print(f"Full image cache updated for query_id {query_id}")
|
309 |
-
return
|
310 |
-
|
311 |
-
|
312 |
@app.get("/full_image")
|
313 |
async def full_image(docid: str, query_id: str, idx: int):
|
314 |
"""
|
315 |
Endpoint to get the full quality image for a given result id.
|
316 |
"""
|
317 |
-
|
318 |
-
|
319 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
320 |
return Img(
|
321 |
-
src=f"data:image/
|
322 |
alt="something",
|
323 |
cls="result-image w-full h-full object-contain",
|
324 |
)
|
@@ -338,28 +343,25 @@ async def get_suggestions(request):
|
|
338 |
|
339 |
async def message_generator(query_id: str, query: str):
|
340 |
images = []
|
341 |
-
|
342 |
-
all_images_ready = False
|
343 |
max_wait = 10 # seconds
|
344 |
start_time = time.time()
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
if img is not None:
|
354 |
-
images.append(img)
|
355 |
-
if len(images) == len(search_results):
|
356 |
-
all_images_ready = True
|
357 |
-
break
|
358 |
else:
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
|
|
|
|
|
|
363 |
if not images:
|
364 |
yield "event: message\ndata: I am sorry, I do not have enough information in the image to answer your question.\n\n"
|
365 |
yield "event: close\ndata: \n\n"
|
|
|
1 |
import asyncio
|
|
|
|
|
2 |
import os
|
3 |
import time
|
|
|
|
|
4 |
from pathlib import Path
|
5 |
+
from concurrent.futures import ThreadPoolExecutor
|
6 |
import uuid
|
|
|
|
|
7 |
import google.generativeai as genai
|
8 |
+
from fasthtml.common import (
|
9 |
+
Div,
|
10 |
+
Img,
|
11 |
+
Main,
|
12 |
+
P,
|
13 |
+
Script,
|
14 |
+
Link,
|
15 |
+
fast_app,
|
16 |
+
HighlightJS,
|
17 |
+
FileResponse,
|
18 |
+
RedirectResponse,
|
19 |
+
Aside,
|
20 |
+
StreamingResponse,
|
21 |
+
JSONResponse,
|
22 |
+
serve,
|
23 |
+
)
|
24 |
+
from shad4fast import ShadHead
|
25 |
from vespa.application import Vespa
|
26 |
+
import base64
|
27 |
+
from fastcore.parallel import threaded
|
28 |
+
from PIL import Image
|
29 |
|
30 |
+
from backend.colpali import get_query_embeddings_and_token_map, gen_similarity_maps
|
|
|
|
|
|
|
|
|
|
|
31 |
from backend.modelmanager import ModelManager
|
32 |
from backend.vespa_app import VespaQueryClient
|
33 |
from frontend.app import (
|
|
|
84 |
),
|
85 |
)
|
86 |
vespa_app: Vespa = VespaQueryClient()
|
|
|
|
|
|
|
|
|
87 |
thread_pool = ThreadPoolExecutor()
|
88 |
# Gemini config
|
89 |
|
|
|
98 |
gemini_model = genai.GenerativeModel(
|
99 |
"gemini-1.5-flash-8b", system_instruction=GEMINI_SYSTEM_PROMPT
|
100 |
)
|
101 |
+
STATIC_DIR = Path("static")
|
102 |
+
IMG_DIR = STATIC_DIR / "full_images"
|
103 |
+
SIM_MAP_DIR = STATIC_DIR / "sim_maps"
|
104 |
+
os.makedirs(IMG_DIR, exist_ok=True)
|
105 |
+
os.makedirs(SIM_MAP_DIR, exist_ok=True)
|
106 |
|
107 |
|
108 |
@app.on_event("startup")
|
|
|
117 |
return
|
118 |
|
119 |
|
120 |
+
def generate_query_id(query, ranking_value):
|
121 |
+
hash_input = (query + ranking_value).encode("utf-8")
|
122 |
+
return hash(hash_input)
|
123 |
|
124 |
|
125 |
@rt("/static/{filepath:path}")
|
|
|
140 |
|
141 |
|
142 |
@rt("/search")
|
143 |
+
def get(request):
|
144 |
# Extract the 'query' and 'ranking' parameters from the URL
|
145 |
query_value = request.query_params.get("query", "").strip()
|
146 |
ranking_value = request.query_params.get("ranking", "nn+colpali")
|
|
|
165 |
)
|
166 |
)
|
167 |
# Generate a unique query_id based on the query and ranking value
|
168 |
+
query_id = generate_query_id(query_value, ranking_value)
|
|
|
|
|
|
|
|
|
|
|
169 |
# Show the loading message if a query is provided
|
170 |
return Layout(
|
171 |
Main(Search(request), data_overlayscrollbars_initialize=True, cls="border-t"),
|
|
|
176 |
) # Show SearchBox and Loading message initially
|
177 |
|
178 |
|
179 |
+
@rt("/fetch_results2")
|
180 |
+
def get(query: str, ranking: str):
|
181 |
+
# 1. Get the results from Vespa (without sim_maps and full_images)
|
182 |
+
# Call search-endpoint in Vespa sync.
|
183 |
+
|
184 |
+
# 2. Kick off tasks to fetch sim_maps and full_images
|
185 |
+
# Sim maps - call search endpoint async.
|
186 |
+
# (A) New rank_profile that does not calculate sim_maps.
|
187 |
+
# (A) Make vespa endpoints take select_fields as a parameter.
|
188 |
+
# One sim map per image per token.
|
189 |
+
# the filename query_id_result_idx_token_idx.png
|
190 |
+
# Full image. based on the doc_id.
|
191 |
+
# Each of these tasks saves to disk.
|
192 |
+
# Need a cleanup task to delete old files.
|
193 |
+
# Polling endpoints for sim_maps and full_images checks if file exists and returns it.
|
194 |
+
pass
|
195 |
+
|
196 |
+
|
197 |
@rt("/fetch_results")
|
198 |
+
async def get(session, request, query: str, ranking: str):
|
199 |
if "hx-request" not in request.headers:
|
200 |
return RedirectResponse("/search")
|
201 |
|
202 |
+
# Get the hash of the query and ranking value
|
203 |
+
query_id = generate_query_id(query, ranking)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
print(f"Query id in /fetch_results: {query_id}")
|
205 |
# Run the embedding and query against Vespa app
|
206 |
model = app.manager.model
|
|
|
212 |
result = await vespa_app.get_result_from_query(
|
213 |
query=query,
|
214 |
q_embs=q_embs,
|
215 |
+
ranking=ranking,
|
216 |
token_to_idx=token_to_idx,
|
217 |
)
|
218 |
end = time.perf_counter()
|
219 |
print(
|
220 |
f"Search results fetched in {end - start:.2f} seconds, Vespa says searchtime was {result['timing']['searchtime']} seconds"
|
221 |
)
|
222 |
+
search_results = vespa_app.results_to_search_results(result, token_to_idx)
|
223 |
+
get_and_store_sim_maps(
|
224 |
+
query_id=query_id,
|
225 |
+
query=query,
|
226 |
+
q_embs=q_embs,
|
227 |
+
ranking=ranking,
|
228 |
+
token_to_idx=token_to_idx,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
)
|
|
|
230 |
return SearchResult(search_results, query_id)
|
231 |
|
232 |
|
|
|
246 |
print(f"Vespa keepalive: {time.time()}")
|
247 |
|
248 |
|
249 |
+
@threaded
def get_and_store_sim_maps(query_id, query: str, q_embs, ranking, token_to_idx):
    """Fetch similarity-map data from Vespa and save one blended PNG per token.

    Queries Vespa with the ``<ranking>_sim`` rank profile, waits (up to five
    seconds) for the full-size result images ``<query_id>_<idx>.jpg`` to show
    up in ``IMG_DIR``, then feeds them to ``gen_similarity_maps`` and writes
    every yielded image to ``SIM_MAP_DIR`` as
    ``<query_id>_<idx>_<token_idx>.png``.

    NOTE(review): wrapped in ``@threaded`` — presumably this runs off the
    request thread; confirm against the decorator's definition.

    Returns:
        ``False`` if the full images were not all on disk within the wait
        window, ``True`` once all generated sim maps have been written.
    """
    vespa_sim_maps = vespa_app.get_sim_maps_from_query(
        query=query,
        q_embs=q_embs,
        ranking=ranking + "_sim",
        token_to_idx=token_to_idx,
    )

    # One expected full image per Vespa hit.
    num_results = len(vespa_sim_maps)
    img_paths = [
        IMG_DIR / f"{query_id}_{result_idx}.jpg" for result_idx in range(num_results)
    ]

    def all_images_on_disk():
        return all(os.path.exists(path) for path in img_paths)

    # All images should be downloaded, but best to wait 5 secs
    max_wait = 5
    wait_started = time.time()
    while not all_images_on_disk() and time.time() - wait_started < max_wait:
        time.sleep(0.2)

    if not all_images_on_disk():
        print(f"Images not ready in 5 seconds for query_id: {query_id}")
        return False

    sim_maps = gen_similarity_maps(
        model=app.manager.model,
        processor=app.manager.processor,
        device=app.manager.device,
        query=query,
        query_embs=q_embs,
        token_idx_map=token_to_idx,
        images=img_paths,
        vespa_sim_maps=vespa_sim_maps,
    )
    for idx, token, token_idx, blended_img_base64 in sim_maps:
        png_path = SIM_MAP_DIR / f"{query_id}_{idx}_{token_idx}.png"
        png_bytes = base64.b64decode(blended_img_base64)
        with open(png_path, "wb") as f:
            f.write(png_bytes)
        print(
            f"Sim map saved to disk for query_id: {query_id}, idx: {idx}, token: {token}"
        )
    return True
|
289 |
|
290 |
|
291 |
@app.get("/get_sim_map")
async def get_sim_map(query_id: str, idx: int, token: str, token_idx: int):
    """
    Polling endpoint behind each sim map button.

    Looks for ``<query_id>_<idx>_<token_idx>.png`` in ``SIM_MAP_DIR``. When
    the file is there, a ``SimMapButtonReady`` pointing at it is returned;
    otherwise a ``SimMapButtonPoll`` is returned so the client asks again
    (the polling interval is defined by that component, not here).
    """
    sim_map_path = SIM_MAP_DIR / f"{query_id}_{idx}_{token_idx}.png"

    # Happy path first: the sim map has already been written to disk.
    if os.path.exists(sim_map_path):
        return SimMapButtonReady(
            query_id=query_id, idx=idx, token=token, img_src=sim_map_path
        )

    print(f"Sim map not ready for query_id: {query_id}, idx: {idx}, token: {token}")
    return SimMapButtonPoll(
        query_id=query_id, idx=idx, token=token, token_idx=token_idx
    )
|
308 |
|
309 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
310 |
@app.get("/full_image")
async def full_image(docid: str, query_id: str, idx: int):
    """
    Endpoint to get the full quality image for a given result id.

    The image is cached on disk as ``IMG_DIR/<query_id>_<idx>.jpg``. On a
    cache miss it is fetched from Vespa (as a base64 string), decoded and
    stored; on a hit it is read back and re-encoded. Either way the image is
    returned inline as a base64 ``data:`` URI inside an ``Img`` element.

    Other code in this module treats the mere existence of the ``.jpg`` file
    as "image ready", so the file is written to a temporary name and moved
    into place with ``os.replace``; a reader can then never observe an empty
    or half-written JPEG.
    """
    img_path = IMG_DIR / f"{query_id}_{idx}.jpg"
    if os.path.exists(img_path):
        with open(img_path, "rb") as f:
            image_data = base64.b64encode(f.read()).decode("utf-8")
    else:
        image_data = await vespa_app.get_full_image_from_vespa(docid)
        # image data is base 64 encoded string. Decode before touching the
        # disk so a decode error leaves no file behind.
        image_bytes = base64.b64decode(image_data)
        tmp_path = f"{img_path}.tmp"
        with open(tmp_path, "wb") as f:
            f.write(image_bytes)
        # Atomic publish: the final name only appears once fully written.
        os.replace(tmp_path, img_path)
        print(f"Full image saved to disk for query_id: {query_id}, idx: {idx}")
    return Img(
        src=f"data:image/jpeg;base64,{image_data}",
        alt="something",
        cls="result-image w-full h-full object-contain",
    )
|
|
|
343 |
|
344 |
async def message_generator(query_id: str, query: str):
|
345 |
images = []
|
346 |
+
num_images = 3 # Number of images before firing chat request
|
|
|
347 |
max_wait = 10 # seconds
|
348 |
start_time = time.time()
|
349 |
+
# Check if full images are ready on disk
|
350 |
+
while len(images) < num_images and time.time() - start_time < max_wait:
|
351 |
+
for idx in range(num_images):
|
352 |
+
if not os.path.exists(IMG_DIR / f"{query_id}_{idx}.jpg"):
|
353 |
+
print(
|
354 |
+
f"Message generator: Full image not ready for query_id: {query_id}, idx: {idx}"
|
355 |
+
)
|
356 |
+
continue
|
|
|
|
|
|
|
|
|
|
|
357 |
else:
|
358 |
+
print(
|
359 |
+
f"Message generator: image ready for query_id: {query_id}, idx: {idx}"
|
360 |
+
)
|
361 |
+
images.append(Image.open(IMG_DIR / f"{query_id}_{idx}.jpg"))
|
362 |
+
await asyncio.sleep(0.2)
|
363 |
+
# yield message with number of images ready
|
364 |
+
yield f"event: message\ndata: Generating response based on {len(images)} images.\n\n"
|
365 |
if not images:
|
366 |
yield "event: message\ndata: I am sorry, I do not have enough information in the image to answer your question.\n\n"
|
367 |
yield "event: close\ndata: \n\n"
|