import base64
import mimetypes
import os
from pathlib import Path
from typing import cast

import gradio as gr
import spaces
import torch
from colpali_engine.models.paligemma.colpali import ColPali, ColPaliProcessor
from huggingface_hub import snapshot_download
from mistral_common.protocol.instruct.messages import (
    ImageURLChunk,
    TextChunk,
    UserMessage,
)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_inference.generate import generate
from mistral_inference.transformer import Transformer
from pdf2image import convert_from_path
from torch.utils.data import DataLoader
from tqdm import tqdm

# Local directory holding the Pixtral weights and tokenizer (downloaded below).
models_path = Path.home().joinpath("pixtral", "Pixtral")
models_path.mkdir(parents=True, exist_ok=True)

snapshot_download(
    repo_id="mistral-community/pixtral-12b-240910",
    allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
    local_dir=models_path,
)

# ColPali base weights and the retrieval adapter loaded on top of them.
COLPALI_BASE_MODEL = "vidore/colpaligemma-3b-pt-448-base"
COLPALI_ADAPTER = "vidore/colpali-v1.2"

# Indexing is rejected once this many pages (or more) have been converted.
MAX_PAGES = 150


def image_to_base64(image_path):
    """Read an image file and return its contents as a base64 ``data:`` URL.

    The MIME type is guessed from the file extension, because the retrieved
    gallery images are not guaranteed to be JPEG. ``image/jpeg`` is used when
    the extension is unknown.
    """
    mime_type, _ = mimetypes.guess_type(str(image_path))
    if mime_type is None or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"
    with open(image_path, "rb") as img:
        encoded_string = base64.b64encode(img.read()).decode("utf-8")
    return f"data:{mime_type};base64,{encoded_string}"


def _load_colpali():
    """Load the ColPali model (base weights + adapter) onto CUDA with its processor.

    Returns:
        A ``(model, processor)`` tuple; the model is in eval mode.
    """
    token = os.environ.get("HF_TOKEN")
    model = ColPali.from_pretrained(
        COLPALI_BASE_MODEL,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        token=token,
    ).eval()
    model.load_adapter(COLPALI_ADAPTER)
    # Re-apply eval mode so any modules added by the adapter are covered too.
    model = model.eval()
    processor = cast(
        ColPaliProcessor,
        ColPaliProcessor.from_pretrained(COLPALI_ADAPTER, token=token),
    )
    return model, processor


@spaces.GPU
def model_inference(
    images,
    text,
):
    """Answer *text* with Pixtral, conditioning on the retrieved page images.

    Args:
        images: Value of the results gallery. Each item's first element is
            read as an image file path. Empty/``None`` when no search has
            been run yet.
        text: The user's question.

    Returns:
        The decoded completion (generation capped at 512 new tokens).

    Raises:
        gr.Error: If there are no retrieved images to answer from.
    """
    if not images:
        raise gr.Error("Please run a search first so there are pages to answer from.")

    tokenizer = MistralTokenizer.from_file(f"{models_path}/tekken.json")
    model = Transformer.from_folder(models_path)
    try:
        messages = [
            UserMessage(
                content=[ImageURLChunk(image_url=image_to_base64(i[0])) for i in images]
                + [TextChunk(text=text)]
            )
        ]
        completion_request = ChatCompletionRequest(messages=messages)
        encoded = tokenizer.encode_chat_completion(completion_request)

        out_tokens, _ = generate(
            [encoded.tokens],
            model,
            images=[encoded.images],
            max_tokens=512,
            temperature=0.45,
            eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
        )
    finally:
        # Release the model even if tokenization/generation fails.
        del model
        torch.cuda.empty_cache()
    return tokenizer.decode(out_tokens[0])


@spaces.GPU
def search(query: str, ds, images, k):
    """Return the *k* indexed pages most relevant to *query*, best match first.

    Args:
        query: Free-text search query.
        ds: Per-page ColPali embeddings, as produced by ``index_gpu``.
        images: Page images aligned index-for-index with *ds*.
        k: Number of results to return (the UI slider value).

    Returns:
        A list of at most *k* page images, highest score first.

    Raises:
        gr.Error: If no documents have been indexed yet.
    """
    if not ds:
        raise gr.Error("Please index documents before searching.")
    k = int(k)

    model, processor = _load_colpali()
    try:
        with torch.no_grad():
            batch_query = processor.process_queries([query])
            batch_query = {name: t.to("cuda") for name, t in batch_query.items()}
            embeddings_query = model(**batch_query)
        qs = list(torch.unbind(embeddings_query.to("cpu")))

        # Presumably shaped (num_queries, num_pages); row 0 is our single query.
        scores = processor.score(qs, ds)
        # Descending so the best-matching page is shown first in the gallery.
        top_k_indices = scores[0].argsort(descending=True)[:k].tolist()
        return [images[idx] for idx in top_k_indices]
    finally:
        del model, processor
        torch.cuda.empty_cache()


def index(files, ds):
    """Convert uploaded PDFs to page images, then embed them with ColPali.

    The PDF conversion runs outside the ``@spaces.GPU``-decorated
    ``index_gpu``; only the embedding step is executed inside it.
    """
    images = convert_files(files)
    return index_gpu(images, ds)


def convert_files(files):
    """Rasterise every page of the uploaded PDFs into images.

    Raises:
        gr.Error: If no files were uploaded, or if the page count reaches
            ``MAX_PAGES``.
    """
    if not files:
        raise gr.Error("Please upload at least one PDF before indexing.")
    images = []
    for f in files:
        images.extend(convert_from_path(f, thread_count=4))
        # Fail as soon as the limit is hit instead of converting the rest.
        if len(images) >= MAX_PAGES:
            raise gr.Error(
                f"The number of images in the dataset should be less than {MAX_PAGES}."
            )
    return images


@spaces.GPU
def index_gpu(images, ds):
    """Embed *images* with ColPali.

    Args:
        images: Page images to embed.
        ds: Embeddings from a previous indexing run (session state). They are
            intentionally discarded. The returned embeddings match *images*
            one-to-one, and *images* replaces the page-image state, so both
            must be rebuilt together to stay aligned.

    Returns:
        ``(status message, embeddings list, images)``.
    """
    embeddings = []
    model, processor = _load_colpali()
    try:
        dataloader = DataLoader(
            images,
            batch_size=4,
            shuffle=False,
            collate_fn=processor.process_images,
        )
        with torch.no_grad():
            for batch_doc in tqdm(dataloader):
                batch_doc = {name: t.to("cuda") for name, t in batch_doc.items()}
                embeddings_doc = model(**batch_doc)
                embeddings.extend(torch.unbind(embeddings_doc.to("cpu")))
    finally:
        del model, processor
        torch.cuda.empty_cache()
    return f"Uploaded and converted {len(images)} pages", embeddings, images


def get_example():
    """Return example ``[files, query]`` rows for the ``gr.Examples`` widget."""
    return [
        [["plants_and_people.pdf"], "What is the global population in 2050 ?"],
        [["plants_and_people.pdf"], "Where was domesticated Teosinte ?"],
    ]


css = """
#col-container {
    margin: 0 auto;
    max-width: 600px;
}
"""

# Created outside the Blocks context so gr.Examples can reference them
# before they are rendered in the layout below.
file = gr.File(file_types=["pdf"], file_count="multiple", label="pdfs")
query = gr.Textbox(placeholder="Enter your query here", label="query")

with gr.Blocks(title="ColPali + Pixtral", theme=gr.themes.Soft(), css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# ColPali + Pixtral")

        with gr.Row():
            gr.Examples(
                examples=get_example(),
                inputs=[file, query],
            )

        with gr.Row():
            with gr.Column(scale=2):
                gr.Markdown("## Upload PDFs")
                file.render()

                message = gr.Textbox("Files not yet uploaded", label="Status")
                convert_button = gr.Button("🔄 Index documents")
                # Per-session state: page embeddings and their page images,
                # kept index-aligned by index_gpu.
                embeds = gr.State(value=[])
                imgs = gr.State(value=[])

            with gr.Column(scale=3):
                gr.Markdown("## Search with ColPali")
                query.render()
                k = gr.Slider(
                    minimum=1, maximum=4, step=1, label="Number of results", value=1
                )
                search_button = gr.Button("🔍 Search", variant="primary")

        output_gallery = gr.Gallery(
            label="Retrieved Documents", height=600, show_label=True
        )

        # Define the actions
        convert_button.click(
            index, inputs=[file, embeds], outputs=[message, embeds, imgs]
        )
        search_button.click(
            search, inputs=[query, embeds, imgs, k], outputs=[output_gallery]
        )

        gr.Markdown("## Get your answer with Pixtral")
        answer_button = gr.Button("Answer", variant="primary")
        output = gr.Markdown(label="Output")
        answer_button.click(
            model_inference, inputs=[output_gallery, query], outputs=output
        )

if __name__ == "__main__":
    demo.queue(max_size=10).launch()