thomasht86 committed
Commit 2df1399 • 1 Parent(s): ecc0caa

Upload colpali.py with huggingface_hub

Files changed (1)
  1. colpali.py +521 -0
colpali.py ADDED
@@ -0,0 +1,521 @@
#!/usr/bin/env python3

import torch
from PIL import Image
import numpy as np
from typing import cast, Union, Tuple
import pprint
from pathlib import Path
import base64
from io import BytesIO
import matplotlib
import re

from colpali_engine.models import ColPali, ColPaliProcessor
from colpali_engine.utils.torch_utils import get_torch_device
from einops import rearrange
from vidore_benchmark.interpretability.plot_utils import plot_similarity_heatmap
from vidore_benchmark.interpretability.torch_utils import (
    normalize_similarity_map_per_query_token,
)
from vidore_benchmark.interpretability.vit_configs import VIT_CONFIG
from vidore_benchmark.utils.image_utils import scale_image
from vespa.application import Vespa
from vespa.io import VespaQueryResponse

matplotlib.use("Agg")

MAX_QUERY_TERMS = 64
# OUTPUT_DIR = Path(__file__).parent.parent / "output" / "sim_maps"
# OUTPUT_DIR.mkdir(exist_ok=True)

COLPALI_GEMMA_MODEL_ID = "vidore--colpaligemma-3b-pt-448-base"
COLPALI_GEMMA_MODEL_SNAPSHOT = "12c59eb7e23bc4c26876f7be7c17760d5d3a1ffa"
COLPALI_GEMMA_MODEL_PATH = (
    Path().home()
    / f".cache/huggingface/hub/models--{COLPALI_GEMMA_MODEL_ID}/snapshots/{COLPALI_GEMMA_MODEL_SNAPSHOT}"
)
COLPALI_MODEL_ID = "vidore--colpali-v1.2"
COLPALI_MODEL_SNAPSHOT = "9912ce6f8a462d8cf2269f5606eabbd2784e764f"
COLPALI_MODEL_PATH = (
    Path().home()
    / f".cache/huggingface/hub/models--{COLPALI_MODEL_ID}/snapshots/{COLPALI_MODEL_SNAPSHOT}"
)
COLPALI_GEMMA_MODEL_NAME = COLPALI_GEMMA_MODEL_ID.replace("--", "/")


def load_model() -> Tuple[ColPali, ColPaliProcessor]:
    model_name = "vidore/colpali-v1.2"

    device = get_torch_device("auto")
    print(f"Using device: {device}")

    # Load the model
    model = cast(
        ColPali,
        ColPali.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            device_map=device,
        ),
    ).eval()

    # Load the processor
    processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
    return model, processor


def load_vit_config(model):
    # Load the ViT config
    print(f"VIT config: {VIT_CONFIG}")
    vit_config = VIT_CONFIG[COLPALI_GEMMA_MODEL_NAME]
    return vit_config


# Create dummy image
dummy_image = Image.new("RGB", (448, 448), (255, 255, 255))

def gen_similarity_map(
    model, processor, device, vit_config, query, image: Union[Path, str]
):
    # Should take in the b64 image from the Vespa query result
    # and possibly the tensor representing the output_image
    if isinstance(image, Path):
        # image is a file path
        try:
            image = Image.open(image)
        except Exception as e:
            raise ValueError(f"Failed to open image from path: {e}")
    elif isinstance(image, str):
        # image is a b64 string
        try:
            image = Image.open(BytesIO(base64.b64decode(image)))
        except Exception as e:
            raise ValueError(f"Failed to open image from b64: {e}")

    # Preview the image
    scale_image(image, 512)
    # Preprocess inputs
    input_text_processed = processor.process_queries([query]).to(device)
    input_image_processed = processor.process_images([image]).to(device)
    # Forward passes
    with torch.no_grad():
        output_text = model.forward(**input_text_processed)
        output_image = model.forward(**input_image_processed)
    # output_image is the tensor that we could get from the Vespa query
    # Output image shape: torch.Size([1, 1030, 128])
    # Remove the special tokens from the output
    output_image = output_image[
        :, : processor.image_seq_length, :
    ]  # (1, n_patches_x * n_patches_y, dim)

    # Rearrange the output image tensor to explicitly represent the 2D grid of patches
    output_image = rearrange(
        output_image,
        "b (h w) c -> b h w c",
        h=vit_config.n_patch_per_dim,
        w=vit_config.n_patch_per_dim,
    )  # (1, n_patches_x, n_patches_y, dim)
    # Get the similarity map
    similarity_map = torch.einsum(
        "bnk,bijk->bnij", output_text, output_image
    )  # (1, query_tokens, n_patches_x, n_patches_y)

    # Normalize the similarity map
    similarity_map_normalized = normalize_similarity_map_per_query_token(
        similarity_map
    )  # (1, query_tokens, n_patches_x, n_patches_y)
    # List the query tokens so a token can be chosen by its index
    query_tokens = processor.tokenizer.tokenize(
        processor.decode(input_text_processed.input_ids[0])
    )
    # Choose a token
    token_idx = (
        10  # e.g. if "12: '▁Kazakhstan',", set 12 to choose the token 'Kazakhstan'
    )
    selected_token = processor.decode(input_text_processed.input_ids[0, token_idx])
    # strip whitespace
    selected_token = selected_token.strip()
    print(f"Selected token: `{selected_token}`")
    # Print the token map so the index of the desired token can be looked up
    pprint.pprint({idx: val for idx, val in enumerate(query_tokens)})
    # Resize the image to square
    input_image_square = image.resize((vit_config.resolution, vit_config.resolution))

    # Plot the similarity map
    fig, ax = plot_similarity_heatmap(
        input_image_square,
        patch_size=vit_config.patch_size,
        image_resolution=vit_config.resolution,
        similarity_map=similarity_map_normalized[0, token_idx, :, :],
    )
    ax = annotate_plot(ax, query, selected_token)
    return fig, ax

# def save_figure(fig, filename: str = "similarity_map.png"):
#     fig.savefig(
#         OUTPUT_DIR / filename,
#         bbox_inches="tight",
#         pad_inches=0,
#     )

def annotate_plot(ax, query, selected_token):
    # Add the query text
    ax.set_title(query, fontsize=18)
    # Add annotation with selected token
    ax.annotate(
        f"Selected token:`{selected_token}`",
        xy=(0.5, 0.95),
        xycoords="axes fraction",
        ha="center",
        va="center",
        fontsize=18,
        color="black",
        bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="black", lw=1),
    )
    return ax

def gen_similarity_map_new(
    processor: ColPaliProcessor,
    model: ColPali,
    device,
    vit_config,
    query: str,
    query_embs: torch.Tensor,
    token_idx_map: dict,
    token_to_show: str,
    image: Union[Path, str],
):
    if isinstance(image, Path):
        # image is a file path
        try:
            image = Image.open(image)
        except Exception as e:
            raise ValueError(f"Failed to open image from path: {e}")
    elif isinstance(image, str):
        # image is a b64 string
        try:
            image = Image.open(BytesIO(base64.b64decode(image)))
        except Exception as e:
            raise ValueError(f"Failed to open image from b64: {e}")
    token_idx = token_idx_map[token_to_show]
    print(f"Selected token: `{token_to_show}`")
    # Preview the image
    # scale_image(image, 512)
    # Preprocess inputs
    input_image_processed = processor.process_images([image]).to(device)
    # Forward pass
    with torch.no_grad():
        output_image = model.forward(**input_image_processed)
    # output_image is the tensor that we could get from the Vespa query
    # Output image shape: torch.Size([1, 1030, 128])
    # Remove the special tokens from the output
    print(f"Output image shape before dim: {output_image.shape}")
    output_image = output_image[
        :, : processor.image_seq_length, :
    ]  # (1, n_patches_x * n_patches_y, dim)
    print(f"Output image shape after dim: {output_image.shape}")
    # Rearrange the output image tensor to explicitly represent the 2D grid of patches
    output_image = rearrange(
        output_image,
        "b (h w) c -> b h w c",
        h=vit_config.n_patch_per_dim,
        w=vit_config.n_patch_per_dim,
    )  # (1, n_patches_x, n_patches_y, dim)
    # Get the similarity map
    print(f"Query embs shape: {query_embs.shape}")
    # Add 1 extra dim to start of query_embs
    query_embs = query_embs.unsqueeze(0).to(device)
    print(f"Output image shape: {output_image.shape}")
    similarity_map = torch.einsum(
        "bnk,bijk->bnij", query_embs, output_image
    )  # (1, query_tokens, n_patches_x, n_patches_y)
    print(f"Similarity map shape: {similarity_map.shape}")
    # Normalize the similarity map
    similarity_map_normalized = normalize_similarity_map_per_query_token(
        similarity_map
    )  # (1, query_tokens, n_patches_x, n_patches_y)
    print(f"Similarity map normalized shape: {similarity_map_normalized.shape}")
    # Resize the image to square
    input_image_square = image.resize((vit_config.resolution, vit_config.resolution))

    # Plot the similarity map
    fig, ax = plot_similarity_heatmap(
        input_image_square,
        patch_size=vit_config.patch_size,
        image_resolution=vit_config.resolution,
        similarity_map=similarity_map_normalized[0, token_idx, :, :],
    )
    ax = annotate_plot(ax, query, token_to_show)
    # save_figure (and OUTPUT_DIR) is currently commented out above, so skip saving
    # save_figure(fig, f"similarity_map_{token_to_show}.png")
    return fig, ax

def get_query_embeddings_and_token_map(
    processor, model, query, image
) -> Tuple[torch.Tensor, dict]:
    inputs = processor.process_queries([query]).to(model.device)
    with torch.no_grad():
        embeddings_query = model(**inputs)
        q_emb = embeddings_query.to("cpu")[0]  # Extract the single embedding
    query_tokens = processor.tokenizer.tokenize(processor.decode(inputs.input_ids[0]))
    print(query_tokens)
    # Map each token string to its index so a token can be selected by name
    token_to_idx = {val: idx for idx, val in enumerate(query_tokens)}
    return q_emb, token_to_idx

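# Note (illustrative): q_emb holds one 128-dim ColPali embedding per query token
# (cf. the "torch.Size([1, 1030, 128])" shape comments above), and token_to_idx maps
# each token string to its row index in q_emb, e.g. a token like "▁water" from the
# sample query under __main__ maps to its position in the tokenized query.
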
def format_query_results(query, response, hits=5) -> dict:
    query_time = response.json.get("timing", {}).get("searchtime", -1)
    query_time = round(query_time, 2)
    count = response.json.get("root", {}).get("fields", {}).get("totalCount", 0)
    result_text = f"Query text: '{query}', query time {query_time}s, count={count}, top results:\n"
    print(result_text)
    return response.json

async def query_vespa_default(
    app: Vespa,
    query: str,
    q_emb: torch.Tensor,
    hits: int = 3,
    timeout: str = "10s",
    **kwargs,
) -> dict:
    async with app.asyncio(connections=1, total_timeout=120) as session:
        query_embedding = format_q_embs(q_emb)
        response: VespaQueryResponse = await session.query(
            body={
                "yql": "select id,title,url,image,page_number,text from pdf_page where userQuery();",
                "ranking": "default",
                "query": query,
                "timeout": timeout,
                "hits": hits,
                "input.query(qt)": query_embedding,
                "presentation.timing": True,
                **kwargs,
            },
        )
        assert response.is_successful(), response.json
        return format_query_results(query, response)


def float_to_binary_embedding(float_query_embedding: dict) -> dict:
    binary_query_embeddings = {}
    for k, v in float_query_embedding.items():
        binary_vector = (
            np.packbits(np.where(np.array(v) > 0, 1, 0)).astype(np.int8).tolist()
        )
        binary_query_embeddings[k] = binary_vector
        if len(binary_query_embeddings) >= MAX_QUERY_TERMS:
            print(f"Warning: Query has {MAX_QUERY_TERMS} terms or more. Truncating.")
            break
    return binary_query_embeddings

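# Worked example (illustrative): if the first 8 dimensions of a query-token vector have
# signs +, -, +, +, -, +, -, +, then np.where(... > 0, 1, 0) yields [1, 0, 1, 1, 0, 1, 0, 1],
# which np.packbits packs MSB-first into the byte 181 (-75 after the int8 cast).
# A 128-dim float vector therefore shrinks to 16 int8 values per query token.
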

def create_nn_query_strings(
    binary_query_embeddings: dict, target_hits_per_query_tensor: int = 20
) -> Tuple[str, dict]:
    # Query tensors for nearest neighbor calculations
    nn_query_dict = {}
    for i in range(len(binary_query_embeddings)):
        nn_query_dict[f"input.query(rq{i})"] = binary_query_embeddings[i]
    nn = " OR ".join(
        [
            f"({{targetHits:{target_hits_per_query_tensor}}}nearestNeighbor(embedding,rq{i}))"
            for i in range(len(binary_query_embeddings))
        ]
    )
    return nn, nn_query_dict

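# For example, with two query-token tensors and the default target_hits_per_query_tensor=20,
# the returned YQL fragment is
#   ({targetHits:20}nearestNeighbor(embedding,rq0)) OR ({targetHits:20}nearestNeighbor(embedding,rq1))
# and nn_query_dict holds the corresponding "input.query(rq0)" / "input.query(rq1)" tensors.
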

def format_q_embs(q_embs: torch.Tensor) -> dict:
    float_query_embedding = {k: v.tolist() for k, v in enumerate(q_embs)}
    return float_query_embedding


async def query_vespa_nearest_neighbor(
    app: Vespa,
    query: str,
    q_emb: torch.Tensor,
    target_hits_per_query_tensor: int = 20,
    hits: int = 3,
    timeout: str = "10s",
    **kwargs,
) -> dict:
    # Hyperparameter for speed vs. accuracy
    async with app.asyncio(connections=1, total_timeout=180) as session:
        float_query_embedding = format_q_embs(q_emb)
        binary_query_embeddings = float_to_binary_embedding(float_query_embedding)

        # Mixed tensors for MaxSim calculations
        query_tensors = {
            "input.query(qtb)": binary_query_embeddings,
            "input.query(qt)": float_query_embedding,
        }
        nn_string, nn_query_dict = create_nn_query_strings(
            binary_query_embeddings, target_hits_per_query_tensor
        )
        query_tensors.update(nn_query_dict)
        response: VespaQueryResponse = await session.query(
            body={
                **query_tensors,
                "presentation.timing": True,
                "yql": f"select id,title,text,url,image,page_number from pdf_page where {nn_string}",
                "ranking.profile": "retrieval-and-rerank",
                "timeout": timeout,
                "hits": hits,
                **kwargs,
            },
        )
        assert response.is_successful(), response.json
        return format_query_results(query, response)

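# Interpretation (hedged, inferred from the tensor names and rank profile above): the packed
# binary embeddings (qtb) presumably drive the cheap nearestNeighbor candidate retrieval,
# while the "retrieval-and-rerank" profile is expected to re-score candidates with a
# MaxSim-style calculation over the float embeddings (qt); target_hits_per_query_tensor
# then trades retrieval speed against accuracy.
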

def is_special_token(token: str) -> bool:
    # Pattern for tokens that start with '<', numbers, whitespace, or single characters
    pattern = re.compile(r"^<.*$|^\d+$|^\s+$|^.$")
    if pattern.match(token):
        return True
    return False

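# Examples of tokens treated as "special" by the pattern above: anything starting with "<"
# (e.g. "<bos>"), purely numeric tokens such as "2023", whitespace-only tokens, and any
# single-character token. These are skipped when per-token similarity maps are generated below.
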

async def get_result_from_query(
    app: Vespa,
    processor: ColPaliProcessor,
    model: ColPali,
    query: str,
    nn=False,
    gen_sim_map=False,
):
    # Get the query embeddings and token map
    print(query)
    q_embs, token_to_idx = get_query_embeddings_and_token_map(
        processor, model, query, dummy_image
    )
    print(token_to_idx)

    if nn:
        result = await query_vespa_nearest_neighbor(app, query, q_embs)
    else:
        result = await query_vespa_default(app, query, q_embs)
    # Print score, title and id of the results
    for idx, child in enumerate(result["root"]["children"]):
        print(
            f"Result {idx+1}: {child['relevance']}, {child['fields']['title']}, {child['fields']['id']}"
        )

    if gen_sim_map:
        for single_result in result["root"]["children"]:
            img = single_result["fields"]["image"]
            for token in token_to_idx:
                if is_special_token(token):
                    print(f"Skipping special token: {token}")
                    continue
                fig, ax = gen_similarity_map_new(
                    processor,
                    model,
                    model.device,
                    load_vit_config(model),
                    query,
                    q_embs,
                    token_to_idx,
                    token,
                    img,
                )
                # Ensure the Agg canvas is rendered before reading its RGB buffer
                fig.canvas.draw()
                sim_map = base64.b64encode(fig.canvas.tostring_rgb()).decode("utf-8")
                single_result["fields"][f"sim_map_{token}"] = sim_map
    return result


def get_result_dummy(query: str, nn: bool = False):
    result = {}
    result["timing"] = {}
    result["timing"]["querytime"] = 0.23700000000000002
    result["timing"]["summaryfetchtime"] = 0.001
    result["timing"]["searchtime"] = 0.23900000000000002
    result["root"] = {}
    result["root"]["id"] = "toplevel"
    result["root"]["relevance"] = 1
    result["root"]["fields"] = {}
    result["root"]["fields"]["totalCount"] = 59
    result["root"]["coverage"] = {}
    result["root"]["coverage"]["coverage"] = 100
    result["root"]["coverage"]["documents"] = 155
    result["root"]["coverage"]["full"] = True
    result["root"]["coverage"]["nodes"] = 1
    result["root"]["coverage"]["results"] = 1
    result["root"]["coverage"]["resultsFull"] = 1
    result["root"]["children"] = []
    elt0 = {}
    elt0["id"] = "index:colpalidemo_content/0/424c85e7dece761d226f060f"
    elt0["relevance"] = 2354.050122871995
    elt0["source"] = "colpalidemo_content"
    elt0["fields"] = {}
    elt0["fields"]["id"] = "a767cb1868be9a776cd56b768347b089"
    elt0["fields"]["url"] = (
        "https://static.conocophillips.com/files/resources/conocophillips-2023-sustainability-report.pdf"
    )
    elt0["fields"]["title"] = "ConocoPhillips 2023 Sustainability Report"
    elt0["fields"]["page_number"] = 50
    elt0["fields"]["image"] = "empty for now - is base64 encoded image"
    result["root"]["children"].append(elt0)
    elt1 = {}
    elt1["id"] = "index:colpalidemo_content/0/b927c4979f0beaf0d7fab8e9"
    elt1["relevance"] = 2313.7529950886965
    elt1["source"] = "colpalidemo_content"
    elt1["fields"] = {}
    elt1["fields"]["id"] = "9f2fc0aa02c9561adfaa1451c875658f"
    elt1["fields"]["url"] = (
        "https://static.conocophillips.com/files/resources/conocophillips-2023-managing-climate-related-risks.pdf"
    )
    elt1["fields"]["title"] = "ConocoPhillips Managing Climate Related Risks"
    elt1["fields"]["page_number"] = 44
    elt1["fields"]["image"] = "empty for now - is base64 encoded image"
    result["root"]["children"].append(elt1)
    elt2 = {}
    elt2["id"] = "index:colpalidemo_content/0/9632d72238829d6afefba6c9"
    elt2["relevance"] = 2312.230182081461
    elt2["source"] = "colpalidemo_content"
    elt2["fields"] = {}
    elt2["fields"]["id"] = "d638ded1ddcb446268b289b3f65430fd"
    elt2["fields"]["url"] = (
        "https://static.conocophillips.com/files/resources/24-0976-sustainability-highlights_nature.pdf"
    )
    elt2["fields"]["title"] = (
        "ConocoPhillips Sustainability Highlights - Nature (24-0976)"
    )
    elt2["fields"]["page_number"] = 0
    elt2["fields"]["image"] = "empty for now - is base64 encoded image"
    result["root"]["children"].append(elt2)
    return result


if __name__ == "__main__":
    model, processor = load_model()
    vit_config = load_vit_config(model)
    query = "How many percent of source water is fresh water?"
    image_filepath = (
        Path(__file__).parent.parent
        / "static"
        / "assets"
        / "ConocoPhillips Sustainability Highlights - Nature (24-0976).png"
    )
    gen_similarity_map(
        model, processor, model.device, vit_config, query=query, image=image_filepath
    )
    result = get_result_dummy("dummy query")
    print(result)
    print("Done")