manu committed on
Commit 9c66171 • 1 Parent(s): 01531d8

Update app.py

Files changed (1)
  1. app.py +26 -12
app.py CHANGED
@@ -15,13 +15,32 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import AutoProcessor
 
+# Load model
+model_name = "vidore/colpali"
+token = os.environ.get("HF_TOKEN")
+model = ColPali.from_pretrained(
+    "google/paligemma-3b-mix-448", torch_dtype=torch.bfloat16, device_map="cuda", token = token).eval()
+
+model.load_adapter(model_name)
+processor = AutoProcessor.from_pretrained(model_name, token = token)
+
+mock_image = Image.new("RGB", (448, 448), (255, 255, 255))
+
 
 @spaces.GPU
 def search(query: str, ds, images, k):
+
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    if device != model.device:
+        model.to(device)
+
+    print(f"model device: {model.device}")
+
+
     qs = []
     with torch.no_grad():
         batch_query = process_queries(processor, [query], mock_image)
-        batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
+        batch_query = {k: v.to(device) for k, v in batch_query.items()}
         embeddings_query = model(**batch_query)
         qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
 
@@ -55,29 +74,24 @@ def index(files, ds):
         collate_fn=lambda x: process_images(processor, x),
     )
 
+
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    if device != model.device:
+        model.to(device)
+
     print(f"model device: {model.device}")
 
-    model = model.to(model.device)
 
     for batch_doc in tqdm(dataloader):
        with torch.no_grad():
-            batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
+            batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
            print(f"model device: {model.device}")
            print(f"model device: {batch_doc['input_ids']}")
            embeddings_doc = model(**batch_doc)
            ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
    return f"Uploaded and converted {len(images)} pages", ds, images
 
-# Load model
-model_name = "vidore/colpali"
-token = os.environ.get("HF_TOKEN")
-model = ColPali.from_pretrained(
-    "google/paligemma-3b-mix-448", torch_dtype=torch.bfloat16, device_map="cuda", token = token).eval()
-
-model.load_adapter(model_name)
-processor = AutoProcessor.from_pretrained(model_name, token = token)
 
-mock_image = Image.new("RGB", (448, 448), (255, 255, 255))
 
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models 📚")
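
For context: the pattern this commit adopts, resolving the CUDA device and moving the model and each batch to it inside the @spaces.GPU-decorated functions rather than at import time, is a common workaround on ZeroGPU Spaces, where a GPU is typically only attached while such a function runs. Below is a minimal sketch of that pattern; toy_model and embed are illustrative names, not from app.py.

import torch
import spaces  # Hugging Face Spaces SDK; provides the @spaces.GPU decorator

# Illustrative stand-in for the real ColPali model. On ZeroGPU this
# import-time code runs before any GPU is attached.
toy_model = torch.nn.Linear(8, 8).eval()

@spaces.GPU
def embed(batch: dict) -> torch.Tensor:
    # A GPU is attached only for the duration of this call, so the
    # device must be resolved here rather than at import time.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    toy_model.to(device)
    # Move every input tensor to the same device as the model.
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        out = toy_model(batch["x"])
    # Return on CPU so the result remains usable after the GPU is released.
    return out.cpu()

One design note: the committed guard compares a string ("cuda:0") against model.device, a torch.device, which in most PyTorch versions compares unequal, so model.to(device) effectively runs on every call. That is harmless, since .to() is a cheap no-op when the module is already on the target device.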