AdrienB134 committed on
Commit
79fd59c
1 Parent(s): cc33a9b
Files changed (1)
  1. app.py +82 -4
app.py CHANGED
@@ -13,7 +13,83 @@ from pdf2image import convert_from_path
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
- from transformers import AutoProcessor
+ from transformers import AutoProcessor, Idefics3ForConditionalGeneration
+ import re
+ import time
+ from PIL import Image
+ import torch
+ import subprocess
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+
+
+ ## Load idefics
+ id_processor = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
+
+ id_model = Idefics3ForConditionalGeneration.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3",
+     torch_dtype=torch.bfloat16,
+     #_attn_implementation="flash_attention_2"
+ ).to("cuda")
+
+ BAD_WORDS_IDS = id_processor.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
+ EOS_WORDS_IDS = [id_processor.tokenizer.eos_token_id]
+
+ @spaces.GPU
+ def model_inference(
+     images, text, assistant_prefix=None, decoding_strategy="Greedy", temperature=0.4, max_new_tokens=512,
+     repetition_penalty=1.2, top_p=0.8
+ ):
+     if text == "" and not images:
+         gr.Error("Please input a query and optionally image(s).")
+
+     if text == "" and images:
+         gr.Error("Please input a text query along with the image(s).")
+
+     if isinstance(images, Image.Image):
+         images = [images]
+
+
+     resulting_messages = [
+         {
+             "role": "user",
+             "content": [{"type": "image"}] + [
+                 {"type": "text", "text": text}
+             ]
+         }
+     ]
+
+     if assistant_prefix:
+         text = f"{assistant_prefix} {text}"
+
+
+     prompt = id_processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
+     inputs = id_processor(text=prompt, images=[images], return_tensors="pt")
+     inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+     generation_args = {
+         "max_new_tokens": max_new_tokens,
+         "repetition_penalty": repetition_penalty,
+
+     }
+
+     assert decoding_strategy in [
+         "Greedy",
+         "Top P Sampling",
+     ]
+     if decoding_strategy == "Greedy":
+         generation_args["do_sample"] = False
+     elif decoding_strategy == "Top P Sampling":
+         generation_args["temperature"] = temperature
+         generation_args["do_sample"] = True
+         generation_args["top_p"] = top_p
+
+
+     generation_args.update(inputs)
+
+     # Generate
+     generated_ids = id_model.generate(**generation_args)
+
+     generated_texts = id_processor.batch_decode(generated_ids[:, generation_args["input_ids"].size(1):], skip_special_tokens=True)
+     return generated_texts[0]

# Load model
model_name = "vidore/colpali-v1.2"
@@ -96,7 +172,7 @@ def index_gpu(images, ds):
            embeddings_doc = model(**batch_doc)
        ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
    return f"Uploaded and converted {len(images)} pages", ds, images
-
+
@spaces.GPU
def answer_gpu():
    return 0
@@ -116,6 +192,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
            message = gr.Textbox("Files not yet uploaded", label="Status")
            embeds = gr.State(value=[])
            imgs = gr.State(value=[])
+             img_chunk = gr.State(value=[])

        with gr.Column(scale=3):
            gr.Markdown("## 2️⃣ Search")
@@ -133,10 +210,11 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
    output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)

    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
-     search_button.click(search, inputs=[query, embeds, imgs, k], outputs=[output_gallery])
+     search_button.click(search, inputs=[query, embeds, imgs, k], outputs=[output_gallery, img_chunk])

    answer_button = gr.Button("Answer", variant="primary")
-     answer_button.click(answer_gpu, inputs=[])
+     output = gr.Textbox(label="Output")
+     answer_button.click(model_inference, inputs=[img_chunk, query], outputs=output)

if __name__ == "__main__":
    demo.queue(max_size=10).launch(debug=True)
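
With this change the Answer button runs Idefics3 over the pages returned by search (held in the img_chunk state) using the query textbox as the prompt. For reference, a minimal sketch (not part of this commit) of calling the new model_inference step directly, outside the Gradio UI. It assumes app.py is importable as app, so id_model and id_processor are already loaded on a CUDA device, and the PDF path and question below are hypothetical placeholders:

# Minimal sketch, not part of commit 79fd59c.
from pdf2image import convert_from_path
from app import model_inference  # function added in this commit (module name assumed);
                                 # importing app loads the models at import time

pages = convert_from_path("example.pdf")   # render PDF pages to PIL images
top_page = pages[0]                        # stand-in for a ColPali search hit

answer = model_inference(
    images=[top_page],
    text="What does the table on this page show?",
    decoding_strategy="Greedy",
    max_new_tokens=256,
)
print(answer)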