# RAG-ColPali / llamaindex_utils.py
import torch
import json
import asyncio
import qdrant_client
from PIL import Image
from pydantic import PrivateAttr, Field
from typing import Union, Optional, List, Any, Dict, Set, Tuple, cast
from dataclasses import dataclass
from llama_index.core.vector_stores.types import VectorStoreQueryResult
from llama_index.core.vector_stores.utils import (
legacy_metadata_dict_to_node,
metadata_dict_to_node,
)
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.retrievers import BaseRetriever
from llama_index.core import QueryBundle, PromptTemplate
from llama_index.core.schema import NodeWithScore, TextNode
from llama_index.core.llms import LLM
from llama_index.core.question_gen import LLMQuestionGenerator
from llama_index.core.tools import ToolMetadata
from llama_index.core.output_parsers.utils import parse_json_markdown
from llama_index.core.question_gen.types import SubQuestion
from models import ColPali, ColPaliProcessor
from prompt_templates import (DEFAULT_GEN_PROMPT_TMPL,
DEFAULT_FINAL_ANSWER_PROMPT_TMPL,
DEFAULT_SUB_QUESTION_PROMPT_TMPL,
DEFAULT_SYNTHESIZE_PROMPT_TMPL)
from qdrant_client.http.models import Payload
from collections import defaultdict
def parse_to_query_result(response: List[Any]) -> VectorStoreQueryResult:
"""
Convert vector store response to VectorStoreQueryResult.
    Args:
        response (List[Any]): List of results returned from the vector store.
"""
nodes = []
similarities = []
ids = []
for point in response:
payload = cast(Payload, point.payload)
try:
node = metadata_dict_to_node(payload)
except Exception:
metadata, node_info, relationships = legacy_metadata_dict_to_node(
payload
)
node = TextNode(
id_=str(point.id),
text=payload.get("text"),
metadata=metadata,
start_char_idx=node_info.get("start", None),
end_char_idx=node_info.get("end", None),
relationships=relationships,
)
nodes.append(node)
ids.append(str(point.id))
try:
similarities.append(point.score)
except AttributeError:
# certain requests do not return a score
similarities.append(1.0)
return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
class ColPaliGemmaEmbedding(BaseEmbedding):
_model: ColPali = PrivateAttr()
_processor: ColPaliProcessor = PrivateAttr()
    device: Union[torch.device, str] = Field(default="cpu",
                                             description="Device to use")
def __init__(self,
model: ColPali,
processor: ColPaliProcessor,
device: Optional[str] = 'cpu',
**kwargs):
super().__init__(device=device,
**kwargs)
self._model = model.to(device).eval()
self._processor = processor
@classmethod
def class_name(cls) -> str:
return "ColPaliGemmaEmbedding"
def _get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding.
Args:
query (str): Query String
"""
with torch.no_grad():
processed_query = self._processor.process_queries([query])
processed_query = {k: v.to(self.device) for k, v in processed_query.items()}
query_embeddings = self._model(**processed_query)
return query_embeddings.to('cpu')[0]
def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding.
Args:
text (str): Text String
"""
with torch.no_grad():
processed_query = self._processor.process_queries([text])
processed_query = {k: v.to(self.device) for k, v in processed_query.items()}
query_embeddings = self._model(**processed_query)
return query_embeddings.to('cpu')[0]
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get text embeddings.
Args:
texts (List[str]): List of text string
"""
        with torch.no_grad():
            processed_queries = self._processor.process_queries(texts)
            # Move every tensor in the batch onto the model's device before the forward pass.
            processed_queries = {k: v.to(self.device) for k, v in processed_queries.items()}
            query_embeddings = self._model(**processed_queries)
        return query_embeddings.to('cpu')
async def _aget_query_embedding(self, query: str) -> List[float]:
return self._get_query_embedding(query)
async def _aget_text_embedding(self, text: str) -> List[float]:
return self._get_text_embedding(text)
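# Minimal usage sketch for ColPaliGemmaEmbedding, kept as an uncalled function so
# importing this module stays side-effect free. Whether the local `models` module
# exposes a from_pretrained-style loader is an assumption, and the checkpoint id
# below is a placeholder, not something this repo ships.
def _example_colpali_embedding() -> None:
    # Hypothetical loading step; adapt to however models.ColPali is actually constructed.
    model = ColPali.from_pretrained("vidore/colpali")               # placeholder checkpoint id
    processor = ColPaliProcessor.from_pretrained("vidore/colpali")  # placeholder checkpoint id
    embed_model = ColPaliGemmaEmbedding(model=model, processor=processor, device="cpu")
    # ColPali is a late-interaction model: the result is one vector per query token,
    # not a single pooled vector.
    embedding = embed_model._get_query_embedding("What does the revenue chart show?")
    print(embedding.shape)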
class ColPaliRetriever(BaseRetriever):
def __init__(self,
                 vector_store_client: Union[qdrant_client.QdrantClient, qdrant_client.AsyncQdrantClient],
target_collection: str,
embed_model: ColPaliGemmaEmbedding,
query_mode: str = 'default',
similarity_top_k: int = 3,
) -> None:
self._vector_store_client = vector_store_client
self._target_collection = target_collection
self._embed_model = embed_model
self._query_mode = query_mode
self._similarity_top_k = similarity_top_k
super().__init__()
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
"""Get retrived nodes from the vector store by retriever given query string.
Args:
query_bundle (QueryBundle): QueryBundle class includes query string
Returns:
List[NodeWithScore]: List of retrieved nodes.
"""
if query_bundle.embedding is None:
query_embedding = self._embed_model._get_query_embedding(query_bundle.query_str)
else:
query_embedding = query_bundle.embedding
query_embedding = query_embedding.cpu().float().numpy().tolist()
# Get nodes from vector store
response = self._vector_store_client.query_points(collection_name=self._target_collection,
query=query_embedding,
limit=self._similarity_top_k).points
# Parse to structured output nodes
query_result = parse_to_query_result(response)
nodes_with_scores = []
for idx, node in enumerate(query_result.nodes):
score = None
if query_result.similarities is not None:
score = query_result.similarities[idx]
nodes_with_scores.append(NodeWithScore(node=node, score=score))
return nodes_with_scores
async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
"""Asynchronously get retrived nodes from the vector store by retriever given query string.
Args:
query_bundle (QueryBundle): QueryBundle class includes query string
Returns:
List[NodeWithScore]: List of retrieved nodes.
"""
if query_bundle.embedding is None:
query_embedding = await self._embed_model._aget_query_embedding(query_bundle.query_str)
else:
query_embedding = query_bundle.embedding
query_embedding = query_embedding.cpu().float().numpy().tolist()
# Get nodes from vector store
responses = await self._vector_store_client.query_points(collection_name=self._target_collection,
query=query_embedding,
limit=self._similarity_top_k)
responses = responses.points
# Parse to structured output nodes
query_result = parse_to_query_result(responses)
nodes_with_scores = []
for idx, node in enumerate(query_result.nodes):
score = None
if query_result.similarities is not None:
score = query_result.similarities[idx]
nodes_with_scores.append(NodeWithScore(node=node, score=score))
return nodes_with_scores
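# Hedged sketch of wiring ColPaliRetriever to a Qdrant collection; the endpoint
# URL and collection name are placeholders, not values from this repo.
def _example_colpali_retriever(embed_model: ColPaliGemmaEmbedding) -> List[NodeWithScore]:
    client = qdrant_client.QdrantClient(url="http://localhost:6333")  # placeholder endpoint
    retriever = ColPaliRetriever(vector_store_client=client,
                                 target_collection="pdf_pages",       # placeholder collection
                                 embed_model=embed_model,
                                 similarity_top_k=3)
    # BaseRetriever.retrieve accepts a plain string and wraps it in a QueryBundle.
    return retriever.retrieve("How did revenue change year over year?")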
def fuse_results(retrieved_nodes: List[NodeWithScore], similarity_top_k: int) -> List[NodeWithScore]:
"""Fuse retrieved nodes using Reciprocal Rank
Args:
retrieved_nodes (List[NodeWithScore]): List of nodes.
similarity_top_k (int): get top K nodes.
Returns:
List[NodeWithScore]: List of nodes after fused
"""
k = 60.0
fused_scores = {}
text_to_node = {}
for rank, node_with_score in enumerate(sorted(retrieved_nodes, key=lambda x: x.score or 0.0, reverse=True)):
text = node_with_score.node.get_content(metadata_mode='all')
text_to_node[text] = node_with_score
fused_scores[text] = fused_scores.get(text, 0.0) + 1.0 / (rank + k)
# Sort results by calculated score
reranked_results = dict(sorted(fused_scores.items(), key=lambda x: x[1], reverse=True))
reranked_nodes: List[NodeWithScore] = []
for text, score in reranked_results.items():
reranked_nodes.append(text_to_node[text])
reranked_nodes[-1].score = score
return reranked_nodes[:similarity_top_k]
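# Worked example of the fusion above: with k = 60, a node seen at sorted
# positions 0 and 1 scores 1/60 + 1/61 ~= 0.033, beating a node seen once at
# position 2 (1/62 ~= 0.016). The nodes here are hypothetical stand-ins.
def _example_fuse_results() -> List[NodeWithScore]:
    node_a = NodeWithScore(node=TextNode(text="node A"), score=0.9)
    node_b = NodeWithScore(node=TextNode(text="node B"), score=0.8)
    # "node A" appears twice, as if two rewritten queries both retrieved it, so it
    # accumulates two reciprocal-rank contributions and ends up ranked first.
    return fuse_results([node_a, node_b, node_a], similarity_top_k=2)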
def generate_queries(llm: LLM, query: str, num_queries: int) -> List[str]:
    """Generate num_queries rewritten queries from the original query.
    Args:
        llm (LLM): LLM model
        query (str): Query string
        num_queries (int): Number of queries to generate
    Returns:
        List[str]: List of generated query strings
    """
query_prompt = PromptTemplate(DEFAULT_GEN_PROMPT_TMPL)
generate_queries = llm.predict(query_prompt,
num_queries=num_queries,
query=query)
generate_queries = generate_queries.split('\n')
return generate_queries
async def agenerate_queries(llm: LLM, query: str, num_queries: int) -> List[str]:
    """Asynchronously generate num_queries rewritten queries from the original query.
    Args:
        llm (LLM): LLM model
        query (str): Query string
        num_queries (int): Number of queries to generate
    Returns:
        List[str]: List of generated query strings
    """
query_prompt = PromptTemplate(DEFAULT_GEN_PROMPT_TMPL)
generate_queries = await llm.apredict(query_prompt,
num_queries=num_queries,
query=query)
generate_queries = generate_queries.split('\n')
return generate_queries
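# Both helpers split the LLM output on newlines, which assumes
# DEFAULT_GEN_PROMPT_TMPL instructs the model to emit one rewritten query per
# line; callers skip the empty strings this can produce.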
# Tree Summarization
def synthesize_results(queries: List[SubQuestion], contexts: Dict[str, Set[str]], llm: LLM, num_children: int) -> Tuple[str, List[str]]:
    """Summarize the per-node answers generated by the LLM via tree summarization.
    Args:
        queries (List[SubQuestion]): Generated sub-questions
        contexts (Dict[str, Set[str]]): Dictionary mapping each context string to its set of source images
        llm (LLM): LLM model
        num_children (int): Number of children per node in the summarization tree
    Returns:
        Tuple[str, List[str]]: Synthesized text and its list of source images.
    """
qa_prompt = PromptTemplate(DEFAULT_SYNTHESIZE_PROMPT_TMPL)
new_contexts = defaultdict(set)
keys = list(contexts.keys())
for idx in range(0, len(keys), num_children):
contexts_batch = keys[idx: idx + num_children]
context_str = '\n\n'.join([f"{i + 1}. {text}" for i, text in enumerate(contexts_batch)])
fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str="\n".join([query.sub_question for query in queries]))
combined_result = llm.complete(fmt_qa_prompt)
# Parse json string to dictionary
json_dict = parse_json_markdown(str(combined_result))
if len(json_dict['choices']) > 0:
for choice in json_dict['choices']:
new_contexts[json_dict['summarized_text']] = new_contexts[json_dict['summarized_text']].union(contexts[contexts_batch[choice - 1]])
else:
new_contexts[json_dict['summarized_text']] = set()
if len(new_contexts) == 1:
synthesized_text = list(new_contexts.keys())[0]
return synthesized_text, list(new_contexts[synthesized_text])
else:
return synthesize_results(queries, new_contexts, llm, num_children=num_children)
async def asynthesize_results(queries: List[SubQuestion], contexts: Dict[str, Set[str]], llm: LLM, num_children: int) -> Tuple[str, List[str]]:
    """Asynchronously summarize the per-node answers generated by the LLM via tree summarization.
    Args:
        queries (List[SubQuestion]): Generated sub-questions
        contexts (Dict[str, Set[str]]): Dictionary mapping each context string to its set of source images
        llm (LLM): LLM model
        num_children (int): Number of children per node in the summarization tree
    Returns:
        Tuple[str, List[str]]: Synthesized text and its list of source images.
    """
qa_prompt = PromptTemplate(DEFAULT_SYNTHESIZE_PROMPT_TMPL)
fmt_qa_prompts = []
keys = list(contexts.keys())
contexts_batches = []
for idx in range(0, len(keys), num_children):
contexts_batch = keys[idx: idx + num_children]
        context_str = '\n\n'.join([f"{i + 1}. {text}" for i, text in enumerate(contexts_batch)])
fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str="\n".join([query.sub_question for query in queries]))
fmt_qa_prompts.append(fmt_qa_prompt)
contexts_batches.append(contexts_batch)
tasks = []
async with asyncio.TaskGroup() as tg:
for fmt_qa_prompt in fmt_qa_prompts:
task = tg.create_task(llm.acomplete(fmt_qa_prompt))
tasks.append(task)
responses = [str(task.result()) for task in tasks]
new_contexts = defaultdict(set)
for idx, response in enumerate(responses):
# Parse json string to dictionary
json_dict = parse_json_markdown(response)
if len(json_dict["choices"]) > 0:
for choice in json_dict["choices"]:
new_contexts[json_dict["summarized_text"]] = new_contexts[json_dict["summarized_text"]].union(contexts[contexts_batches[idx][choice - 1]])
else:
new_contexts[json_dict["summarized_text"]] = set()
if len(new_contexts) == 1:
synthesized_text = list(new_contexts.keys())[0]
return synthesized_text, list(new_contexts[synthesized_text])
else:
return await asynthesize_results(queries, new_contexts, llm, num_children=num_children)
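# Both synthesize helpers assume the LLM answers with markdown-fenced JSON of the
# form {"summarized_text": "...", "choices": [1, 3]}, where "choices" holds
# 1-based indices into the numbered context batch (hence the `choice - 1` lookups).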
class CustomFusionRetriever(BaseRetriever):
def __init__(self,
llm,
retriever_mappings: Dict[str, BaseRetriever],
similarity_top_k: int = 3,
                 num_generated_queries: int = 3,
) -> None:
self._retriever_mappings = retriever_mappings
self._similarity_top_k = similarity_top_k
self._num_generated_queries = num_generated_queries
self._llm = llm
super().__init__()
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
"""Retrieve self._similarity_top_k content nodes given query
Args:
query_bundle (QueryBundle): query bundle include query string
"""
# Get data from query bundle
query_dict = json.loads(query_bundle.query_str)
original_query = query_dict['sub_question']
tool_name = query_dict['tool_name']
# Rewrite original query to n queries
generated_queries = generate_queries(self._llm, original_query, num_queries=self._num_generated_queries)
# For each generated query, retrieve relevant nodes
retrieved_nodes = []
for query in generated_queries:
if len(query) == 0:
continue
retrieved_nodes.extend(self._retriever_mappings[tool_name].retrieve(query))
# Fuse retrieved nodes using reciprocal rank
fused_results = fuse_results(retrieved_nodes,
similarity_top_k=self._similarity_top_k)
return fused_results
async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
"""Asynchronously retrieve self._similarity_top_k content nodes given query
Args:
query_bundle (QueryBundle): query bundle include query string
"""
# Get data from query bundle
query_dict = json.loads(query_bundle.query_str)
original_query = query_dict['sub_question']
tool_name = query_dict['tool_name']
# Rewrite original query to n queries
generated_queries = await agenerate_queries(llm=self._llm, query=original_query, num_queries=self._num_generated_queries)
# For each generated query, retrieve relevant nodes
tasks = []
async with asyncio.TaskGroup() as tg:
for query in generated_queries:
if len(query) == 0:
continue
task = tg.create_task(self._retriever_mappings[tool_name].aretrieve(query))
tasks.append(task)
retrieved_nodes = [node for task in tasks for node in task.result()]
# Fuse retrieved nodes using reciprocal rank
fused_results = fuse_results(retrieved_nodes,
similarity_top_k=self._similarity_top_k)
return fused_results
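# Hedged sketch: the fusion retriever expects query_str to be a JSON-encoded
# SubQuestion whose tool_name matches a key in retriever_mappings; the values
# below are placeholders.
def _example_fusion_retrieve(retriever: CustomFusionRetriever) -> List[NodeWithScore]:
    sub_query = SubQuestion(sub_question="What was Q3 revenue?",  # placeholder question
                            tool_name="finance_docs")             # placeholder tool name
    return retriever.retrieve(QueryBundle(query_str=sub_query.model_dump_json()))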
@dataclass
class Response:
response: str
source_images: Optional[List] = None
def __str__(self):
return self.response
class CustomQueryEngine:
def __init__(self,
retriever_tools: List[ToolMetadata],
fusion_retriever: BaseRetriever,
                 qa_prompt: Optional[PromptTemplate] = None,
                 llm: Optional[LLM] = None,
num_children: int = 3):
self._qa_prompt = qa_prompt if qa_prompt else PromptTemplate(DEFAULT_FINAL_ANSWER_PROMPT_TMPL)
self._llm = llm
self._num_children = num_children
self._sub_question_generator = LLMQuestionGenerator.from_defaults(llm=self._llm,
prompt_template_str=DEFAULT_SUB_QUESTION_PROMPT_TMPL)
self._fusion_retriever = fusion_retriever
self._retriever_tools = retriever_tools
def query(self, query_str: str) -> Response:
# Generate sub queries
sub_queries = self._sub_question_generator.generate(tools=self._retriever_tools,
query=QueryBundle(query_str=query_str))
if len(sub_queries) == 0:
response_template = PromptTemplate("Cannot answer the query: {query_str}")
return Response(response=response_template.format(query_str=query_str), source_images=[])
else:
# Dictionary to map response -> source_images
response2images_mapping = defaultdict(set)
            # For each sub query, retrieve relevant image nodes.
            # With the fusion retriever, each sub query is rewritten into n queries -> relevant nodes are retrieved per generated query
            # -> all retrieved nodes are fused with reciprocal rank -> the top k results are kept.
for sub_query in sub_queries:
retrieved_nodes = self._fusion_retriever.retrieve(QueryBundle(query_str=sub_query.model_dump_json()))
                # Use the (multimodal) LLM to answer the sub query from each retrieved image node
                for retrieved_node in retrieved_nodes:
                    answer = str(self._llm.complete([sub_query.sub_question,
                                                     Image.open(retrieved_node.node.resolve_image())]))
                    response2images_mapping[answer].add(retrieved_node.node.image)
# Synthesize results
synthesized_text, source_images = synthesize_results(queries=sub_queries,
contexts=response2images_mapping,
llm=self._llm,
num_children=self._num_children)
final_answer = self._llm.predict(self._qa_prompt,
context_str=synthesized_text,
query_str=query_str)
response_template = PromptTemplate("Retrieved Information:\n"
"------------------------\n"
"{retrieved_information}\n"
"-------------------------\n\n"
"Answer:\n"
"{final_answer}")
return Response(response=response_template.format(retrieved_information=synthesized_text, final_answer=final_answer), source_images=source_images)
    async def aquery(self, query_str: str) -> Response:
sub_queries = await self._sub_question_generator.agenerate(tools=self._retriever_tools,
query=QueryBundle(query_str=query_str))
if len(sub_queries) == 0:
response_template = PromptTemplate("Cannot answer the query: {query_str}")
return Response(response=response_template.format(query_str=query_str), source_images=[])
else:
retrieved_subquestion_nodes = []
async with asyncio.TaskGroup() as tg:
for sub_query in sub_queries:
task = tg.create_task(self._fusion_retriever.aretrieve(QueryBundle(query_str=sub_query.model_dump_json())))
retrieved_subquestion_nodes.append([sub_query.sub_question, task])
retrieved_subquestion_nodes = [[sub_question, task.result()] for sub_question, task in retrieved_subquestion_nodes]
answers = []
            # For each sub query, retrieve relevant image nodes.
            # With the fusion retriever, each sub query is rewritten into n queries -> relevant nodes are retrieved per generated query
            # -> all retrieved nodes are fused with reciprocal rank -> the top k results are kept.
async with asyncio.TaskGroup() as tg:
for sub_question, retrieved_nodes in retrieved_subquestion_nodes:
for retrieved_node in retrieved_nodes:
task = tg.create_task(self._llm.acomplete([sub_question, Image.open(retrieved_node.node.resolve_image())]))
answers.append([task, retrieved_node.node.image])
# Dictionary to map response -> source_images
response2images_mapping = defaultdict(set)
for task, image in answers:
response2images_mapping[str(task.result())].add(image)
# Synthesize results
synthesized_text, source_images = await asynthesize_results(queries=sub_queries,
contexts=response2images_mapping,
llm=self._llm,
num_children=self._num_children)
final_answer = await self._llm.apredict(self._qa_prompt,
context_str=synthesized_text,
query_str=query_str)
response_template = PromptTemplate("Retrieved Information:\n"
"------------------------\n"
"{retrieved_information}\n"
"-------------------------\n\n"
"Answer:\n"
"{final_answer}")
return Response(response=response_template.format(retrieved_information=synthesized_text, final_answer=final_answer), source_images=source_images)
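# End-to-end usage sketch. The tool name and description are placeholders, and the
# llm passed in must support the list-valued complete/acomplete calls used above
# (text plus PIL image), which this sketch simply assumes.
def _example_query_engine(llm: LLM, fusion_retriever: CustomFusionRetriever) -> Response:
    tools = [ToolMetadata(name="finance_docs",  # must match a retriever_mappings key
                          description="Figures and tables extracted from finance PDFs")]
    engine = CustomQueryEngine(retriever_tools=tools,
                               fusion_retriever=fusion_retriever,
                               llm=llm,
                               num_children=3)
    return engine.query("Summarize the revenue trend across all reports.")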