Spaces:
Runtime error
Runtime error
import os | |
import torch | |
import base64 | |
import asyncio | |
from io import BytesIO | |
import gradio as gr | |
import qdrant_client | |
from PIL import Image | |
from typing import List, Dict, Tuple | |
import llamaindex_utils | |
from rag_pipeline import async_indexDocument | |
from models import get_lora_model, enable_lora, ColPali, ColPaliProcessor | |
from utils import load_tokenizer | |
from llama_index.llms.gemini import Gemini | |
from llama_index.core.tools import RetrieverTool | |
from huggingface_hub import hf_hub_download | |
GEMINI_API_KEY = os.getenv(key="GEMINI_API_KEY") | |
QDRANT_API_KEY = os.getenv(key="QDRANT_API_KEY") | |
HF_TOKEN_KEY = os.getenv(key="HF_TOKEN_KEY") | |
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" | |
async def initialize_model() -> Dict: | |
"""Initialize models | |
Returns: | |
model_dict: Dict: Dictionary stores neccessary models | |
""" | |
if not os.path.exists("./pretrained/colpaligemma-3b-mix-448-base"): | |
os.makedirs("./pretrained/colpaligemma-3b-mix-448-base", exist_ok=True) | |
files_to_download = ["adapter_model.safetensors", | |
"config.json", | |
"model-00001-of-00002.safetensors", | |
"model-00002-of-00002.safetensors", | |
"preprocessor_config.json", | |
"tokenizer.json", | |
"tokenizer.model", | |
"tokenizer_config.json"] | |
for file in files_to_download: | |
hf_hub_download(repo_id="dnnhhuy/colpaligemma-3b-mix-448-base", | |
filename=file, | |
token=HF_TOKEN_KEY, | |
local_dir="./pretrained/colpaligemma-3b-mix-448-base") | |
model = ColPali.from_pretrained(model_dir='./pretrained/colpaligemma-3b-mix-448-base', torch_dtype=torch.bfloat16) | |
tokenizer = load_tokenizer(tokenizer_dir='./pretrained/colpaligemma-3b-mix-448-base') | |
processor = ColPaliProcessor(tokenizer=tokenizer).from_pretrained(pretrained_dir='./pretrained/colpaligemma-3b-mix-448-base') | |
model.model.language_model.model = get_lora_model(model.model.language_model.model, | |
rank=32, | |
alphas=32, | |
lora_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'down_proj', 'gate_proj', 'up_proj'], | |
training=False, | |
dropout_p=0.1, | |
torch_dtype=torch.bfloat16) | |
model.model.language_model.model = enable_lora(model.model.language_model.model, lora_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'down_proj', 'gate_proj', 'up_proj'], enabled=True) | |
model = get_lora_model(model, | |
rank=32, | |
alphas=32, | |
lora_modules=['custom_text_proj'], | |
training=False, | |
dropout_p=0.1, | |
torch_dtype=torch.bfloat16) | |
model = enable_lora(model, lora_modules=['custom_text_proj'], enabled=True) | |
model.load_lora('./pretrained/colpaligemma-3b-mix-448-base') | |
# Initialize LLM | |
generation_config = { | |
"temperature": 0.0, | |
"top_p": 0.95, | |
"top_k": 64, | |
"max_output_tokens": 1024, | |
"response_mime_type": "text/plain", | |
} | |
llm = Gemini(api_key=GEMINI_API_KEY, generation_config=generation_config) | |
# Setup Qdrant | |
# Creating Qdrant Client | |
vector_store_client = qdrant_client.AsyncQdrantClient(location="https://b3878645-ec71-426c-8afa-b8b3b7589e40.us-east4-0.gcp.cloud.qdrant.io", | |
api_key=QDRANT_API_KEY, | |
timeout=100) | |
embed_model = llamaindex_utils.ColPaliGemmaEmbedding(model=model, | |
processor=processor, | |
device=device) | |
collections = await get_collection_names(vector_store_client) | |
retrievers_dict = {} | |
for name in collections: | |
if name not in retrievers_dict: | |
retrievers_dict[name] = llamaindex_utils.ColPaliRetriever(vector_store_client=vector_store_client, | |
target_collection=name, | |
embed_model=embed_model, | |
similarity_top_k=3) | |
return {"llm": llm, | |
"vector_store_client": vector_store_client, | |
"model": model, | |
"processor": processor, | |
"embed_model": embed_model, | |
"collections": collections, | |
"retrievers_dict": retrievers_dict} | |
async def get_collection_names(vector_store_client): | |
collections = await vector_store_client.get_collections() | |
return [collection.name for collection in collections.collections] | |
async def index(files: List[str], | |
target_collection: str | |
) -> Tuple[str, gr.Dropdown, List[str], Dict[str, llamaindex_utils.ColPaliRetriever]]: | |
""" | |
Insert all image pages from files to speicified target collection to the vector store | |
and return the mapping from retriever's name to its object instance. | |
Args: | |
files (List[str]): List of file path | |
target_collection (str): Target collection to insert into the vector store | |
Returns: | |
Tuple[str, gr.Dropdown, List[str], Dict[str, llamaindex_utils.ColPaliRetriever]]: Return message, dropdown component, collections' names, dictionary mapping retriever to its object instance | |
""" | |
for file in files: | |
await async_indexDocument(file_path=file, | |
vector_store_client=model_dict["vector_store_client"], | |
target_collection=target_collection, | |
model=model_dict["model"], | |
processor=model_dict["processor"], | |
device=device) | |
if target_collection not in retrievers: | |
retrievers[target_collection] = llamaindex_utils.ColPaliRetriever(vector_store_client=model_dict["vector_store_client"], | |
target_collection=target_collection, | |
embed_model=model_dict["embed_model"], | |
similarity_top_k=3) | |
collection_names = await get_collection_names(model_dict["vector_store_client"]) | |
return (f"Uploaded and index {len(files)} files.", | |
gr.Dropdown(choices=collection_names), | |
collection_names) | |
async def search_with_llm(query: str, | |
similarity_top_k: int, | |
num_children: int) -> Tuple[str, List[Image.Image]]: | |
"""Search the result given query and list of retrievers. | |
Returns the search's response and list of images support for that response. | |
Args: | |
query (str): Query question | |
retrievers (Dict[str, llamaindex_utils.ColPaliRetriever]): Dictionary mapping between retrievers' names and their object instances | |
similarity_top_k (int): top K similarity results retrieved from the retriever | |
num_children (int): number of children for tree summarization | |
Returns: | |
Tuple[str, List[Image.Image]]: Returns the search's response and list of images support for that response. | |
""" | |
retriever_tools = [RetrieverTool.from_defaults( | |
name=key, | |
retriever=value, | |
description=f"Useful for retrieving information about {key} financials") for key, value in retrievers.items()] | |
retriever_mappings = {retriever_tool.metadata.name: retriever_tool.retriever for retriever_tool in retriever_tools} | |
fusion_retriever = llamaindex_utils.CustomFusionRetriever(llm=model_dict["llm"], | |
retriever_mappings=retriever_mappings, | |
similarity_top_k=similarity_top_k) | |
query_engine = llamaindex_utils.CustomQueryEngine(retriever_tools=[retriever_tool.metadata for retriever_tool in retriever_tools], | |
fusion_retriever=fusion_retriever, | |
llm=model_dict["llm"], | |
num_children=num_children) | |
response = await query_engine.aquery(query_str=query) | |
return response.response, [Image.open(BytesIO(base64.b64decode(image))) for image in response.source_images] | |
async def delete_collection(target_collection): | |
if await model_dict["vector_store_client"].collection_exists(collection_name=target_collection): | |
await model_dict["vector_store_client"].delete_collection(collection_name=target_collection, timeout=100) | |
choices = await get_collection_names(model_dict["vector_store_client"]) | |
return (f"Deleted collection {target_collection}", gr.Dropdown(choices=choices), choices) | |
else: | |
choices = await get_collection_names(model_dict["vector_store_client"]) | |
return (f"Collection {target_collection} is not found.", gr.Dropdown(choices=choices), choices) | |
def build_gui(): | |
with gr.Blocks() as demo: | |
gr.Markdown("# Image Based RAG System using ColPali ππ") | |
with gr.Row(equal_height=True): | |
with gr.Column(): | |
gr.Markdown("## 1οΈ. Upload PDFs") | |
files = gr.File(file_types=["pdf"], | |
file_count="multiple", | |
interactive=True) | |
choices = gr.State(value=model_dict["collections"]) | |
gr.Markdown("## 2οΈ. Index the PDFs and upload") | |
target_collection = gr.Dropdown(choices=choices.value, | |
allow_custom_value=True, | |
label="Collection name", | |
show_label=True, | |
interactive=True) | |
message_box = gr.Textbox(value="File not yet uploaded", | |
show_label=False, | |
interactive=False) | |
with gr.Row(equal_height=True): | |
delete_button = gr.Button("ποΈ Delete collection") | |
convert_button = gr.Button("π Convert and upload") | |
# Define the actions for conversion | |
convert_button.click(index, inputs=[files, target_collection], outputs=[message_box, target_collection, choices]) | |
# Define the actions for delete collection | |
delete_button.click(delete_collection, inputs=[target_collection], outputs=[message_box, target_collection, choices]) | |
with gr.Column(): | |
gr.Markdown("## 3οΈ. Enter your question") | |
query = gr.Textbox(placeholder="Enter your query to match", | |
lines=15, | |
max_lines=20, | |
autoscroll=True) | |
with gr.Accordion(label="Additional Settings", open=False): | |
similarity_top_k = gr.Slider(minimum=1, | |
maximum=10, | |
value=3, | |
step=1.0, | |
label="Top K similarity retrieved from the retriever") | |
num_children = gr.Slider(minimum=1, | |
maximum=10, | |
value=3, | |
step=1.0, | |
label="Set number of children for Tree Summarization") | |
search_button = gr.Button("π Search") | |
gr.Markdown("## 4οΈ. ColPali Retrieval") | |
with gr.Row(equal_height=True): | |
output_text = gr.Textbox(label="Query result", | |
show_label=True, | |
placeholder="Response from query", | |
lines=8, | |
max_lines=20, | |
interactive=False) | |
output_imgs = gr.Gallery(label="Most relevant images is...", | |
show_fullscreen_button=True, | |
show_label=True, | |
show_download_button=True, | |
interactive=False) | |
# Action for search button | |
search_button.click( | |
search_with_llm, | |
inputs=[query, similarity_top_k, num_children], | |
outputs=[output_text, output_imgs]) | |
return demo | |
async def amain(): | |
global model_dict, retrievers | |
model_dict = await initialize_model() | |
retrievers = model_dict["retrievers_dict"] | |
demo = build_gui() | |
demo.queue().launch(debug=True, share=False) | |
if __name__ == "__main__": | |
asyncio.run(amain()) | |