p3nguknight committed on
Commit
185aa79
•
0 Parent(s):

Initial commit

Files changed (7)
  1. .gitattributes +35 -0
  2. .gitignore +3 -0
  3. README.md +11 -0
  4. app.py +226 -0
  5. packages.txt +1 -0
  6. plants_and_people.pdf +0 -0
  7. requirements.txt +7 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
+ __pycache__
+ NOTES.md
+ .venv/
README.md ADDED
@@ -0,0 +1,11 @@
+ ---
+ title: Colpali Pixtral
+ emoji: 🖺
+ colorFrom: purple
+ colorTo: blue
+ sdk: gradio
+ sdk_version: 4.44.0
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ ---
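The YAML front matter above is the Spaces configuration card: `sdk: gradio` at `sdk_version: 4.44.0` and `app_file: app.py` tell Hugging Face how to launch the demo. The same Gradio pin reappears in requirements.txt below; if you mirror the Space outside Hugging Face, a quick sanity check (an illustrative snippet, not part of the repo):

```python
# Illustrative check that the local Gradio matches the version the Space was built against.
import gradio as gr

assert gr.__version__ == "4.44.0", f"expected Gradio 4.44.0, found {gr.__version__}"
```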
app.py ADDED
@@ -0,0 +1,226 @@
+ import base64
+ import os
+ from pathlib import Path
+ from typing import cast
+
+ import gradio as gr
+ import spaces
+ import torch
+ from colpali_engine.models.paligemma.colpali import ColPali, ColPaliProcessor
+ from huggingface_hub import snapshot_download
+ from mistral_common.protocol.instruct.messages import (
+     ImageURLChunk,
+     TextChunk,
+     UserMessage,
+ )
+ from mistral_common.protocol.instruct.request import ChatCompletionRequest
+ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+ from mistral_inference.generate import generate
+ from mistral_inference.transformer import Transformer
+ from pdf2image import convert_from_path
+ from torch.utils.data import DataLoader
+ from tqdm import tqdm
+
+ # Download the Pixtral weights once at startup.
+ models_path = Path.home().joinpath("pixtral", "Pixtral")
+ models_path.mkdir(parents=True, exist_ok=True)
+
+ snapshot_download(
+     repo_id="mistral-community/pixtral-12b-240910",
+     allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
+     local_dir=models_path,
+ )
+
+
+ def image_to_base64(image_path):
+     # Encode an image file as a data URL (always labeled JPEG, regardless of actual format).
+     with open(image_path, "rb") as img:
+         encoded_string = base64.b64encode(img.read()).decode("utf-8")
+     return f"data:image/jpeg;base64,{encoded_string}"
+
+
+ @spaces.GPU
+ def model_inference(
+     images,
+     text,
+ ):
+     # Answer the query with Pixtral, conditioned on the retrieved page images.
+     tokenizer = MistralTokenizer.from_file(f"{models_path}/tekken.json")
+     model = Transformer.from_folder(models_path)
+
+     # Gallery items are (image, caption) pairs, so i[0] is the image path.
+     messages = [
+         UserMessage(
+             content=[ImageURLChunk(image_url=image_to_base64(i[0])) for i in images]
+             + [TextChunk(text=text)]
+         )
+     ]
+
+     completion_request = ChatCompletionRequest(messages=messages)
+
+     encoded = tokenizer.encode_chat_completion(completion_request)
+
+     images = encoded.images
+     tokens = encoded.tokens
+
+     out_tokens, _ = generate(
+         [tokens],
+         model,
+         images=[images],
+         max_tokens=512,
+         temperature=0.45,
+         eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
+     )
+     result = tokenizer.decode(out_tokens[0])
+     return result
+
+
+ @spaces.GPU
+ def search(query: str, ds, images, k):
+     # Embed the query with ColPali and return the k best-matching page images.
+     model_name = "vidore/colpali-v1.2"
+     token = os.environ.get("HF_TOKEN")
+     model = ColPali.from_pretrained(
+         "vidore/colpaligemma-3b-pt-448-base",
+         torch_dtype=torch.bfloat16,
+         device_map="cuda",
+         token=token,
+     ).eval()
+
+     model.load_adapter(model_name)
+     model = model.eval()
+     processor = cast(
+         ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name, token=token)
+     )
+
+     qs = []
+     with torch.no_grad():
+         batch_query = processor.process_queries([query])
+         batch_query = {key: v.to("cuda") for key, v in batch_query.items()}
+         embeddings_query = model(**batch_query)
+         qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
+
+     # Late-interaction (MaxSim) scoring of the query against every indexed page.
+     scores = processor.score(qs, ds)
+     top_k_indices = scores.argsort(axis=1)[0][-k:]
+     results = []
+     for idx in top_k_indices:
+         results.append(images[idx])  # , f"Page {idx}"
+     del model
+     del processor
+     torch.cuda.empty_cache()
+     return results
+
+
+ def index(files, ds):
+     images = convert_files(files)
+     return index_gpu(images, ds)
+
+
+ def convert_files(files):
+     # Rasterize every page of every uploaded PDF.
+     images = []
+     for f in files:
+         images.extend(convert_from_path(f, thread_count=4))
+
+     if len(images) >= 150:
+         raise gr.Error("The number of images in the dataset should be less than 150.")
+     return images
+
+
+ @spaces.GPU
+ def index_gpu(images, ds):
+     # Embed every page image with ColPali and append the embeddings to ds.
+     model_name = "vidore/colpali-v1.2"
+     token = os.environ.get("HF_TOKEN")
+     model = ColPali.from_pretrained(
+         "vidore/colpaligemma-3b-pt-448-base",
+         torch_dtype=torch.bfloat16,
+         device_map="cuda",
+         token=token,
+     ).eval()
+
+     model.load_adapter(model_name)
+     model = model.eval()
+     processor = cast(
+         ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name, token=token)
+     )
+
+     # Run inference over the document pages in small batches.
+     dataloader = DataLoader(
+         images,
+         batch_size=4,
+         shuffle=False,
+         collate_fn=lambda x: processor.process_images(x),
+     )
+
+     for batch_doc in tqdm(dataloader):
+         with torch.no_grad():
+             batch_doc = {k: v.to("cuda") for k, v in batch_doc.items()}
+             embeddings_doc = model(**batch_doc)
+         ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
+     del model
+     del processor
+     torch.cuda.empty_cache()
+     return f"Uploaded and converted {len(images)} pages", ds, images
+
+
+ def get_example():
+     return [
+         [["plants_and_people.pdf"], "What is the global population in 2050?"],
+         [["plants_and_people.pdf"], "Where was Teosinte domesticated?"],
+     ]
+
+
+ css = """
+ #col-container {
+     margin: 0 auto;
+     max-width: 600px;
+ }
+ """
+ file = gr.File(file_types=["pdf"], file_count="multiple", label="pdfs")
+ query = gr.Textbox(placeholder="Enter your query here", label="query")
+
+ with gr.Blocks(title="ColPali + Pixtral", theme=gr.themes.Soft(), css=css) as demo:
+     with gr.Column(elem_id="col-container"):
+         gr.Markdown("# ColPali + Pixtral")
+
+     with gr.Row():
+         gr.Examples(
+             examples=get_example(),
+             inputs=[file, query],
+         )
+
+     with gr.Row():
+         with gr.Column(scale=2):
+             gr.Markdown("## Upload PDFs")
+
+             file.render()
+             message = gr.Textbox("Files not yet uploaded", label="Status")
+             convert_button = gr.Button("🔄 Index documents")
+             embeds = gr.State(value=[])
+             imgs = gr.State(value=[])
+             img_chunk = gr.State(value=[])
+
+         with gr.Column(scale=3):
+             gr.Markdown("## Search with ColPali")
+             query.render()
+             k = gr.Slider(
+                 minimum=1, maximum=4, step=1, label="Number of results", value=1
+             )
+             search_button = gr.Button("🔍 Search", variant="primary")
+
+     # Define the actions
+
+     output_gallery = gr.Gallery(
+         label="Retrieved Documents", height=600, show_label=True
+     )
+
+     convert_button.click(
+         index, inputs=[file, embeds], outputs=[message, embeds, imgs]
+     )
+     search_button.click(
+         search, inputs=[query, embeds, imgs, k], outputs=[output_gallery]
+     )
+
+     gr.Markdown("## Get your answer with Pixtral")
+     answer_button = gr.Button("Answer", variant="primary")
+     output = gr.Markdown(label="Output")
+     answer_button.click(
+         model_inference, inputs=[output_gallery, query], outputs=output
+     )
+
+ if __name__ == "__main__":
+     demo.queue(max_size=10).launch()
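The retrieval step above hinges on a single call, `processor.score(qs, ds)`, which computes ColPali's late-interaction (MaxSim) score between the multi-vector query and page embeddings. Below is a minimal sketch of that computation on toy tensors, to show what the score means; the shapes and variable names are illustrative, not `colpali_engine` internals:

```python
import torch

# Toy multi-vector embeddings: 1 query with 5 token vectors, 2 pages with 7 each (dim 128).
q = torch.randn(1, 5, 128)  # (num_queries, query_tokens, dim)
d = torch.randn(2, 7, 128)  # (num_pages, page_tokens, dim)

# MaxSim: for each query token, take the similarity to its best-matching page token,
# then sum over query tokens to get one score per (query, page) pair.
sim = torch.einsum("qtd,pkd->qptk", q, d)    # (queries, pages, q_tokens, p_tokens)
scores = sim.max(dim=-1).values.sum(dim=-1)  # (queries, pages)
best_page = scores.argmax(dim=1)             # index of the top page per query
```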
packages.txt ADDED
@@ -0,0 +1 @@
+ poppler-utils
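packages.txt lists apt packages the Space installs at build time; poppler-utils provides the `pdftoppm`/`pdfcairo` binaries that pdf2image shells out to in `convert_from_path`. A quick way to confirm poppler is on PATH locally (an illustrative snippet, not part of the repo):

```python
# pdf2image requires the poppler binaries; pdftoppm ships with poppler-utils.
import shutil

assert shutil.which("pdftoppm"), "poppler-utils missing: PDF-to-image conversion will fail"
```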
plants_and_people.pdf ADDED
Binary file (487 kB)
 
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ gradio==4.44.0
+ transformers @ git+https://github.com/huggingface/transformers@78b2929
+ huggingface_hub==0.25.0
+ pdf2image==1.17.0
+ spaces==0.30.2
+ colpali_engine==0.3.0
+ mistral_inference==1.4.0