# Hugging Face Space file. Last commit by yonigozlan (HF staff):
# "Update to 1.3" (0e139b0, verified)
import os
import gradio as gr
import spaces
import torch
from pdf2image import convert_from_path
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import ColPaliForRetrieval, ColPaliProcessor
@spaces.GPU
def install_fa2():
    """Install the ``flash-attn`` package with pip.

    pip is invoked through the interpreter that is running this script
    (``sys.executable -m pip``), so the package lands in the environment
    that will actually import it. The install is best-effort: a non-zero
    exit status is reported on stdout and no exception is raised.
    """
    # Imported locally so the function does not rely on extra file-level imports.
    import subprocess
    import sys

    print("Install FA2")
    completed = subprocess.run(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "flash-attn",
            "--no-build-isolation",
        ],
        check=False,
    )
    if completed.returncode != 0:
        print(f"flash-attn installation failed (exit code {completed.returncode})")
# install_fa2()
# Checkpoint used for both the retrieval model and its processor.
model_name = "vidore/colpali-v1.3-hf"

# Load the weights in bfloat16 directly onto the first CUDA device.
model = ColPaliForRetrieval.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",  # or "mps" if on Apple Silicon
    # attn_implementation="flash_attention_2", # should work on A100
)
# Switch to inference mode; eval() returns the module itself.
model = model.eval()

# The processor prepares text queries and page images for the model.
processor = ColPaliProcessor.from_pretrained(model_name)
@spaces.GPU
def search(query: str, ds, images, k):
    """Return the ``k`` indexed pages that best match ``query``.

    Args:
        query: Free-text search query.
        ds: List of per-page document embeddings, as returned by
            ``index_gpu``.
        images: Page images; ``images[i]`` is the page shown for ``ds[i]``.
        k: Requested number of results; capped at ``len(ds)``.

    Returns:
        A list of ``(image, caption)`` tuples, highest score first. The
        caption is ``"Page <idx>"`` where ``idx`` is the zero-based position
        of the page in ``images``.

    Raises:
        gr.Error: If no documents have been indexed yet.
    """
    # Scoring against an empty index fails deep inside the processor; give
    # the user an actionable message instead.
    if len(ds) == 0:
        raise gr.Error("Please upload and index documents before searching.")

    # topk() needs an integer, while UI widgets may hand over a float.
    k = min(int(k), len(ds))

    # Compare device objects: a plain string never equals a torch.device.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if model.device != device:
        model.to(device)

    with torch.no_grad():
        batch_query = processor(text=[query]).to(model.device)
        query_embeddings = model(**batch_query).embeddings
    # One CPU tensor per query (a single query here).
    qs = list(torch.unbind(query_embeddings.to("cpu")))

    scores = processor.score_retrieval(qs, ds)
    top_k_indices = scores[0].topk(k).indices.tolist()

    return [(images[idx], f"Page {idx}") for idx in top_k_indices]
def index(files, ds):
    """Convert the uploaded files to page images, then embed them.

    Args:
        files: Uploaded files, forwarded to ``convert_files``.
        ds: Current embeddings state, forwarded to ``index_gpu``.

    Returns:
        Whatever ``index_gpu`` returns: ``(status message, embeddings, images)``.
    """
    print("Converting files")
    pages = convert_files(files)
    page_count = len(pages)
    print(f"Files converted with {page_count} images.")
    return index_gpu(pages, ds)
def convert_files(files):
    """Render every page of the uploaded PDFs to an image.

    Args:
        files: Iterable of uploaded PDFs; each item is passed as-is to
            ``convert_from_path``.

    Returns:
        A flat list with one image per page, in upload order.

    Raises:
        gr.Error: If no file was provided, or if the total number of pages
            reaches the 150-page limit.
    """
    max_pages = 150

    # With nothing uploaded the file input is empty/None; report that
    # clearly instead of failing on iteration.
    if not files:
        raise gr.Error("Please upload at least one PDF file before indexing.")

    images = []
    for f in files:
        images.extend(convert_from_path(f, thread_count=4))
        # Check after each file so an oversized upload fails fast rather
        # than after every remaining PDF has been rendered.
        if len(images) >= max_pages:
            raise gr.Error(
                f"The number of images in the dataset should be less than {max_pages}."
            )
    return images
@spaces.GPU
def index_gpu(images, ds):
    """Embed page images with the ColPali model.

    Args:
        images: Page images to embed.
        ds: Previous embeddings state. Accepted for interface compatibility
            but deliberately not extended: the returned ``images`` replaces
            the stored page images, so carrying old embeddings over would
            leave embeddings and images misaligned in ``search``.

    Returns:
        A tuple ``(status message, embeddings, images)`` where
        ``embeddings`` is a list of CPU tensors, one per image and in the
        same order as ``images``.
    """
    # Compare device objects: a plain string never equals a torch.device.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if model.device != device:
        model.to(device)

    # run inference - docs
    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        # Only preprocess here; the batch is moved to the device once, below.
        collate_fn=lambda x: processor(images=x),
    )

    embeddings = []
    with torch.no_grad():
        for batch_doc in tqdm(dataloader):
            batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc).embeddings
            # Keep the embeddings on CPU, split into one tensor per page.
            embeddings.extend(torch.unbind(embeddings_doc.to("cpu")))

    return f"Uploaded and converted {len(images)} pages", embeddings, images
# Gradio UI: the left column uploads and indexes PDFs, the right column
# searches them. Emoji in the labels were mis-decoded UTF-8 and are restored.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# ColPali: Efficient Document Retrieval with Vision Language Models 📚"
    )
    gr.Markdown("""Demo to test the Transformers 🤗 implementation of ColPali on PDF documents.<br>
ColPali is the model implemented from the [ColPali paper](https://arxiv.org/abs/2407.01449).<br>
This demo allows you to upload PDF files and search for the most relevant pages based on your query.
Refresh the page if you change documents!<br>
⚠️ This demo uses a model trained exclusively on A4 PDFs in portrait mode, containing english text. Performance is expected to drop for other page formats and languages.<br>
Other models will be released with better robustness towards different languages and document formats!
Demo by [manu](https://huggingface.co/spaces/manu/ColPali-demo)
""")
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 1️⃣ Upload PDFs")
            file = gr.File(file_count="multiple", label="Upload PDFs")
            convert_button = gr.Button("🔄 Index documents")
            message = gr.Textbox("Files not yet uploaded", label="Status")
            # Per-session state: page embeddings and the matching page images.
            embeds = gr.State(value=[])
            imgs = gr.State(value=[])
        with gr.Column(scale=3):
            gr.Markdown("## 2️⃣ Search")
            query = gr.Textbox(placeholder="Enter your query here", label="Query")
            k = gr.Slider(
                minimum=1, maximum=10, step=1, label="Number of results", value=5
            )
            # Define the actions
            search_button = gr.Button("🔍 Search", variant="primary")
            output_gallery = gr.Gallery(
                label="Retrieved Documents", height=600, show_label=True
            )
    # Indexing fills the status box and both state holders; searching reads
    # the states and fills the gallery.
    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
    search_button.click(
        search, inputs=[query, embeds, imgs, k], outputs=[output_gallery]
    )
if __name__ == "__main__":
    # Allow at most 10 pending requests in the queue, then start the app
    # with debug output enabled.
    queued_demo = demo.queue(max_size=10)
    queued_demo.launch(debug=True)