import torch
import json
import asyncio
import qdrant_client
from PIL import Image
from pydantic import PrivateAttr, Field
from typing import Union, Optional, List, Any, Dict, Set, Tuple, cast
from dataclasses import dataclass
from collections import defaultdict
from qdrant_client.http.models import Payload
from llama_index.core.vector_stores.types import VectorStoreQueryResult
from llama_index.core.vector_stores.utils import (
    legacy_metadata_dict_to_node,
    metadata_dict_to_node,
)
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.retrievers import BaseRetriever
from llama_index.core import QueryBundle, PromptTemplate
from llama_index.core.schema import NodeWithScore, TextNode
from llama_index.core.llms import LLM
from llama_index.core.question_gen import LLMQuestionGenerator
from llama_index.core.tools import ToolMetadata
from llama_index.core.output_parsers.utils import parse_json_markdown
from llama_index.core.question_gen.types import SubQuestion
from models import ColPali, ColPaliProcessor
from prompt_templates import (DEFAULT_GEN_PROMPT_TMPL,
                              DEFAULT_FINAL_ANSWER_PROMPT_TMPL,
                              DEFAULT_SUB_QUESTION_PROMPT_TMPL,
                              DEFAULT_SYNTHESIZE_PROMPT_TMPL)


def parse_to_query_result(response: List[Any]) -> VectorStoreQueryResult:
    """Convert a vector store response to a VectorStoreQueryResult.

    Args:
        response (List[Any]): List of results returned from the vector store.
    """
    nodes = []
    similarities = []
    ids = []
    for point in response:
        payload = cast(Payload, point.payload)
        try:
            node = metadata_dict_to_node(payload)
        except Exception:
            metadata, node_info, relationships = legacy_metadata_dict_to_node(
                payload
            )
            node = TextNode(
                id_=str(point.id),
                text=payload.get("text"),
                metadata=metadata,
                start_char_idx=node_info.get("start", None),
                end_char_idx=node_info.get("end", None),
                relationships=relationships,
            )
        nodes.append(node)
        ids.append(str(point.id))
        try:
            similarities.append(point.score)
        except AttributeError:
            # certain requests do not return a score
            similarities.append(1.0)
    return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
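

# A minimal sketch of the input `parse_to_query_result` expects, assuming the
# points were written with llama-index's `node_to_metadata_dict` (the inverse
# of the `metadata_dict_to_node` call above). `_example_parse_points` is a
# hypothetical helper, not part of the original pipeline.
def _example_parse_points() -> VectorStoreQueryResult:
    from qdrant_client.http.models import ScoredPoint
    from llama_index.core.vector_stores.utils import node_to_metadata_dict

    node = TextNode(id_="n1", text="ColPali embeds PDF pages as patch vectors.")
    point = ScoredPoint(
        id="n1",
        version=0,
        score=0.42,
        payload=node_to_metadata_dict(node),  # serialized node -> Qdrant payload
    )
    result = parse_to_query_result([point])
    assert result.ids == ["n1"] and result.similarities == [0.42]
    return result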


class ColPaliGemmaEmbedding(BaseEmbedding):
    _model: ColPali = PrivateAttr()
    _processor: ColPaliProcessor = PrivateAttr()
    device: Union[torch.device, str] = Field(default="cpu",
                                             description="Device to use")

    def __init__(self,
                 model: ColPali,
                 processor: ColPaliProcessor,
                 device: Optional[str] = 'cpu',
                 **kwargs):
        super().__init__(device=device, **kwargs)
        self._model = model.to(device).eval()
        self._processor = processor

    @classmethod
    def class_name(cls) -> str:
        return "ColPaliGemmaEmbedding"

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding.

        Args:
            query (str): Query string
        """
        with torch.no_grad():
            processed_query = self._processor.process_queries([query])
            processed_query = {k: v.to(self.device) for k, v in processed_query.items()}
            query_embeddings = self._model(**processed_query)
        return query_embeddings.to('cpu')[0]

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding.

        Args:
            text (str): Text string
        """
        with torch.no_grad():
            processed_query = self._processor.process_queries([text])
            processed_query = {k: v.to(self.device) for k, v in processed_query.items()}
            query_embeddings = self._model(**processed_query)
        return query_embeddings.to('cpu')[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings.

        Args:
            texts (List[str]): List of text strings
        """
        with torch.no_grad():
            processed_queries = self._processor.process_queries(texts)
            # Move the processed batch to the target device before the forward pass.
            processed_queries = {k: v.to(self.device) for k, v in processed_queries.items()}
            query_embeddings = self._model(**processed_queries)
        return query_embeddings.to('cpu')

    async def _aget_query_embedding(self, query: str) -> List[float]:
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return self._get_text_embedding(text)
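

# A minimal construction sketch, assuming the local `models.ColPali` and
# `models.ColPaliProcessor` mirror the Hugging Face `from_pretrained` API and
# that a ColPali checkpoint such as "vidore/colpali-v1.2" (an illustrative
# name, not taken from this file) is available locally or on the Hub.
def _example_build_embed_model(device: str = "cpu") -> ColPaliGemmaEmbedding:
    model = ColPali.from_pretrained("vidore/colpali-v1.2", torch_dtype=torch.float32)
    processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.2")
    return ColPaliGemmaEmbedding(model=model, processor=processor, device=device)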


class ColPaliRetriever(BaseRetriever):
    def __init__(self,
                 vector_store_client: Union[qdrant_client.QdrantClient, qdrant_client.AsyncQdrantClient],
                 target_collection: str,
                 embed_model: ColPaliGemmaEmbedding,
                 query_mode: str = 'default',
                 similarity_top_k: int = 3,
                 ) -> None:
        self._vector_store_client = vector_store_client
        self._target_collection = target_collection
        self._embed_model = embed_model
        self._query_mode = query_mode
        self._similarity_top_k = similarity_top_k
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes from the vector store for the given query string.

        Args:
            query_bundle (QueryBundle): QueryBundle containing the query string.
        Returns:
            List[NodeWithScore]: List of retrieved nodes.
        """
        if query_bundle.embedding is None:
            query_embedding = self._embed_model._get_query_embedding(query_bundle.query_str)
        else:
            query_embedding = query_bundle.embedding
        query_embedding = query_embedding.cpu().float().numpy().tolist()
        # Get nodes from the vector store
        response = self._vector_store_client.query_points(collection_name=self._target_collection,
                                                          query=query_embedding,
                                                          limit=self._similarity_top_k).points
        # Parse into structured output nodes
        query_result = parse_to_query_result(response)
        nodes_with_scores = []
        for idx, node in enumerate(query_result.nodes):
            score = None
            if query_result.similarities is not None:
                score = query_result.similarities[idx]
            nodes_with_scores.append(NodeWithScore(node=node, score=score))
        return nodes_with_scores

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Asynchronously retrieve nodes from the vector store for the given query string.

        Args:
            query_bundle (QueryBundle): QueryBundle containing the query string.
        Returns:
            List[NodeWithScore]: List of retrieved nodes.
        """
        if query_bundle.embedding is None:
            query_embedding = await self._embed_model._aget_query_embedding(query_bundle.query_str)
        else:
            query_embedding = query_bundle.embedding
        query_embedding = query_embedding.cpu().float().numpy().tolist()
        # Get nodes from the vector store
        responses = await self._vector_store_client.query_points(collection_name=self._target_collection,
                                                                 query=query_embedding,
                                                                 limit=self._similarity_top_k)
        responses = responses.points
        # Parse into structured output nodes
        query_result = parse_to_query_result(responses)
        nodes_with_scores = []
        for idx, node in enumerate(query_result.nodes):
            score = None
            if query_result.similarities is not None:
                score = query_result.similarities[idx]
            nodes_with_scores.append(NodeWithScore(node=node, score=score))
        return nodes_with_scores
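

# A minimal wiring sketch, assuming an existing Qdrant collection that stores
# multi-vector ColPali page embeddings. The collection name "pdf_pages" and the
# in-memory client are illustrative assumptions; a real deployment would point
# at a populated Qdrant instance.
def _example_build_retriever() -> ColPaliRetriever:
    client = qdrant_client.QdrantClient(":memory:")
    embed_model = _example_build_embed_model()
    return ColPaliRetriever(vector_store_client=client,
                            target_collection="pdf_pages",
                            embed_model=embed_model,
                            similarity_top_k=3)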


def fuse_results(retrieved_nodes: List[NodeWithScore], similarity_top_k: int) -> List[NodeWithScore]:
    """Fuse retrieved nodes using Reciprocal Rank Fusion (RRF).

    Args:
        retrieved_nodes (List[NodeWithScore]): List of nodes.
        similarity_top_k (int): Number of top nodes to keep.
    Returns:
        List[NodeWithScore]: List of nodes after fusion.
    """
    k = 60.0
    fused_scores = {}
    text_to_node = {}
    for rank, node_with_score in enumerate(sorted(retrieved_nodes, key=lambda x: x.score or 0.0, reverse=True)):
        text = node_with_score.node.get_content(metadata_mode='all')
        text_to_node[text] = node_with_score
        fused_scores[text] = fused_scores.get(text, 0.0) + 1.0 / (rank + k)
    # Sort results by fused score
    reranked_results = dict(sorted(fused_scores.items(), key=lambda x: x[1], reverse=True))
    reranked_nodes: List[NodeWithScore] = []
    for text, score in reranked_results.items():
        reranked_nodes.append(text_to_node[text])
        reranked_nodes[-1].score = score
    return reranked_nodes[:similarity_top_k]
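

# A small self-contained check of the RRF arithmetic above: a node seen at
# ranks 0 and 2 scores 1/60 + 1/62 ≈ 0.0328, beating a node seen once at
# rank 1 (1/61 ≈ 0.0164). `_example_rrf` is a hypothetical helper for
# illustration only.
def _example_rrf() -> List[NodeWithScore]:
    a, b = TextNode(text="page A"), TextNode(text="page B")
    hits = [NodeWithScore(node=a, score=0.9),
            NodeWithScore(node=b, score=0.8),
            NodeWithScore(node=a, score=0.7)]
    fused = fuse_results(hits, similarity_top_k=2)
    assert fused[0].node.get_content() == "page A"
    return fused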


def generate_queries(llm: LLM, query: str, num_queries: int) -> List[str]:
    """Generate `num_queries` rewritten queries from the original query.

    Args:
        llm (LLM): LLM model
        query (str): Query string
        num_queries (int): Number of queries to generate
    Returns:
        List[str]: List of generated queries, one per line of the LLM output
    """
    query_prompt = PromptTemplate(DEFAULT_GEN_PROMPT_TMPL)
    generated_queries = llm.predict(query_prompt,
                                    num_queries=num_queries,
                                    query=query)
    return generated_queries.split('\n')


async def agenerate_queries(llm: LLM, query: str, num_queries: int) -> List[str]:
    """Asynchronously generate `num_queries` rewritten queries from the original query.

    Args:
        llm (LLM): LLM model
        query (str): Query string
        num_queries (int): Number of queries to generate
    Returns:
        List[str]: List of generated queries, one per line of the LLM output
    """
    query_prompt = PromptTemplate(DEFAULT_GEN_PROMPT_TMPL)
    generated_queries = await llm.apredict(query_prompt,
                                           num_queries=num_queries,
                                           query=query)
    return generated_queries.split('\n')


# Tree Summarization
def synthesize_results(queries: List[SubQuestion], contexts: Dict[str, Set[str]], llm: LLM, num_children: int) -> Tuple[str, List[str]]:
    """Summarize the answers generated by the LLM using tree summarization.

    Args:
        queries (List[SubQuestion]): Generated sub-questions
        contexts (Dict[str, Set[str]]): Maps a context string to its set of source images
        llm (LLM): LLM model
        num_children (int): Number of children per node in the summarization tree
    Returns:
        Tuple[str, List[str]]: Synthesized text and its list of source images.
    """
    qa_prompt = PromptTemplate(DEFAULT_SYNTHESIZE_PROMPT_TMPL)
    new_contexts = defaultdict(set)
    keys = list(contexts.keys())
    for idx in range(0, len(keys), num_children):
        contexts_batch = keys[idx: idx + num_children]
        context_str = '\n\n'.join([f"{i + 1}. {text}" for i, text in enumerate(contexts_batch)])
        fmt_qa_prompt = qa_prompt.format(context_str=context_str,
                                         query_str="\n".join([query.sub_question for query in queries]))
        combined_result = llm.complete(fmt_qa_prompt)
        # Parse the JSON string into a dictionary
        json_dict = parse_json_markdown(str(combined_result))
        if len(json_dict['choices']) > 0:
            for choice in json_dict['choices']:
                # `choices` holds 1-based indices into the current batch of contexts.
                new_contexts[json_dict['summarized_text']] |= contexts[contexts_batch[choice - 1]]
        else:
            new_contexts[json_dict['summarized_text']] = set()
    if len(new_contexts) == 1:
        synthesized_text = list(new_contexts.keys())[0]
        return synthesized_text, list(new_contexts[synthesized_text])
    else:
        # Recurse until a single summarized context remains.
        return synthesize_results(queries, new_contexts, llm, num_children=num_children)
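

# The synthesize prompt is expected to make the LLM return markdown-fenced
# JSON of roughly this shape (inferred from the parsing logic above; the exact
# field semantics depend on DEFAULT_SYNTHESIZE_PROMPT_TMPL in
# prompt_templates.py):
#
#   {"summarized_text": "<merged answer>", "choices": [1, 3]}
#
# where `choices` lists the 1-based indices of the batch contexts supporting
# the summary. A quick round-trip through the same parser:
def _example_parse_synthesis_output() -> dict:
    raw = '```json\n{"summarized_text": "Revenue grew 12%.", "choices": [1]}\n```'
    return parse_json_markdown(raw)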


async def asynthesize_results(queries: List[SubQuestion], contexts: Dict[str, Set[str]], llm: LLM, num_children: int) -> Tuple[str, List[str]]:
    """Asynchronously summarize the answers generated by the LLM using tree summarization.

    Args:
        queries (List[SubQuestion]): Generated sub-questions
        contexts (Dict[str, Set[str]]): Maps a context string to its set of source images
        llm (LLM): LLM model
        num_children (int): Number of children per node in the summarization tree
    Returns:
        Tuple[str, List[str]]: Synthesized text and its list of source images.
    """
    qa_prompt = PromptTemplate(DEFAULT_SYNTHESIZE_PROMPT_TMPL)
    fmt_qa_prompts = []
    keys = list(contexts.keys())
    contexts_batches = []
    for idx in range(0, len(keys), num_children):
        contexts_batch = keys[idx: idx + num_children]
        context_str = '\n\n'.join([f"{i + 1}. {text}" for i, text in enumerate(contexts_batch)])
        fmt_qa_prompt = qa_prompt.format(context_str=context_str,
                                         query_str="\n".join([query.sub_question for query in queries]))
        fmt_qa_prompts.append(fmt_qa_prompt)
        contexts_batches.append(contexts_batch)
    tasks = []
    async with asyncio.TaskGroup() as tg:
        for fmt_qa_prompt in fmt_qa_prompts:
            task = tg.create_task(llm.acomplete(fmt_qa_prompt))
            tasks.append(task)
    responses = [str(task.result()) for task in tasks]
    new_contexts = defaultdict(set)
    for idx, response in enumerate(responses):
        # Parse the JSON string into a dictionary
        json_dict = parse_json_markdown(response)
        if len(json_dict["choices"]) > 0:
            for choice in json_dict["choices"]:
                # `choices` holds 1-based indices into the corresponding batch of contexts.
                new_contexts[json_dict["summarized_text"]] |= contexts[contexts_batches[idx][choice - 1]]
        else:
            new_contexts[json_dict["summarized_text"]] = set()
    if len(new_contexts) == 1:
        synthesized_text = list(new_contexts.keys())[0]
        return synthesized_text, list(new_contexts[synthesized_text])
    else:
        # Recurse until a single summarized context remains.
        return await asynthesize_results(queries, new_contexts, llm, num_children=num_children)


class CustomFusionRetriever(BaseRetriever):
    def __init__(self,
                 llm: LLM,
                 retriever_mappings: Dict[str, BaseRetriever],
                 similarity_top_k: int = 3,
                 num_generated_queries: int = 3,
                 ) -> None:
        self._retriever_mappings = retriever_mappings
        self._similarity_top_k = similarity_top_k
        self._num_generated_queries = num_generated_queries
        self._llm = llm
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve the top `self._similarity_top_k` content nodes for the given query.

        Args:
            query_bundle (QueryBundle): QueryBundle whose query string is a JSON-serialized SubQuestion.
        """
        # Unpack the sub-question and its target retriever from the query bundle
        query_dict = json.loads(query_bundle.query_str)
        original_query = query_dict['sub_question']
        tool_name = query_dict['tool_name']
        # Rewrite the original query into n queries
        generated_queries = generate_queries(self._llm, original_query, num_queries=self._num_generated_queries)
        # For each generated query, retrieve relevant nodes
        retrieved_nodes = []
        for query in generated_queries:
            if len(query) == 0:
                continue
            retrieved_nodes.extend(self._retriever_mappings[tool_name].retrieve(query))
        # Fuse the retrieved nodes using reciprocal rank fusion
        fused_results = fuse_results(retrieved_nodes,
                                     similarity_top_k=self._similarity_top_k)
        return fused_results

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Asynchronously retrieve the top `self._similarity_top_k` content nodes for the given query.

        Args:
            query_bundle (QueryBundle): QueryBundle whose query string is a JSON-serialized SubQuestion.
        """
        # Unpack the sub-question and its target retriever from the query bundle
        query_dict = json.loads(query_bundle.query_str)
        original_query = query_dict['sub_question']
        tool_name = query_dict['tool_name']
        # Rewrite the original query into n queries
        generated_queries = await agenerate_queries(llm=self._llm, query=original_query, num_queries=self._num_generated_queries)
        # For each generated query, retrieve relevant nodes concurrently
        tasks = []
        async with asyncio.TaskGroup() as tg:
            for query in generated_queries:
                if len(query) == 0:
                    continue
                task = tg.create_task(self._retriever_mappings[tool_name].aretrieve(query))
                tasks.append(task)
        retrieved_nodes = [node for task in tasks for node in task.result()]
        # Fuse the retrieved nodes using reciprocal rank fusion
        fused_results = fuse_results(retrieved_nodes,
                                     similarity_top_k=self._similarity_top_k)
        return fused_results


@dataclass
class Response:
    response: str
    source_images: Optional[List] = None

    def __str__(self):
        return self.response


class CustomQueryEngine:
    def __init__(self,
                 retriever_tools: List[ToolMetadata],
                 fusion_retriever: BaseRetriever,
                 qa_prompt: Optional[PromptTemplate] = None,
                 llm: Optional[LLM] = None,
                 num_children: int = 3):
        self._qa_prompt = qa_prompt if qa_prompt else PromptTemplate(DEFAULT_FINAL_ANSWER_PROMPT_TMPL)
        self._llm = llm
        self._num_children = num_children
        self._sub_question_generator = LLMQuestionGenerator.from_defaults(llm=self._llm,
                                                                          prompt_template_str=DEFAULT_SUB_QUESTION_PROMPT_TMPL)
        self._fusion_retriever = fusion_retriever
        self._retriever_tools = retriever_tools

    def query(self, query_str: str) -> Response:
        # Generate sub-questions
        sub_queries = self._sub_question_generator.generate(tools=self._retriever_tools,
                                                            query=QueryBundle(query_str=query_str))
        if len(sub_queries) == 0:
            response_template = PromptTemplate("Cannot answer the query: {query_str}")
            return Response(response=response_template.format(query_str=query_str), source_images=[])
        else:
            # Dictionary mapping each response to its source images
            response2images_mapping = defaultdict(set)
            # For each sub-question, retrieve relevant image nodes.
            # With the fusion retriever, each sub-question is rewritten into n queries,
            # relevant nodes are retrieved for each generated query, all retrieved nodes
            # are fused with reciprocal rank fusion, and the top k results are kept.
            for sub_query in sub_queries:
                retrieved_nodes = self._fusion_retriever.retrieve(QueryBundle(query_str=sub_query.model_dump_json()))
                # Use the LLM to answer the sub-question from each retrieved node
                for retrieved_node in retrieved_nodes:
                    answer = str(self._llm.complete([sub_query.sub_question,
                                                     Image.open(retrieved_node.node.resolve_image())]))
                    response2images_mapping[answer].add(retrieved_node.node.image)
            # Synthesize the results
            synthesized_text, source_images = synthesize_results(queries=sub_queries,
                                                                 contexts=response2images_mapping,
                                                                 llm=self._llm,
                                                                 num_children=self._num_children)
            final_answer = self._llm.predict(self._qa_prompt,
                                             context_str=synthesized_text,
                                             query_str=query_str)
            response_template = PromptTemplate("Retrieved Information:\n"
                                               "------------------------\n"
                                               "{retrieved_information}\n"
                                               "-------------------------\n\n"
                                               "Answer:\n"
                                               "{final_answer}")
            return Response(response=response_template.format(retrieved_information=synthesized_text,
                                                              final_answer=final_answer),
                            source_images=source_images)

    async def aquery(self, query_str: str) -> Response:
        # Generate sub-questions
        sub_queries = await self._sub_question_generator.agenerate(tools=self._retriever_tools,
                                                                   query=QueryBundle(query_str=query_str))
        if len(sub_queries) == 0:
            response_template = PromptTemplate("Cannot answer the query: {query_str}")
            return Response(response=response_template.format(query_str=query_str), source_images=[])
        else:
            retrieved_subquestion_nodes = []
            async with asyncio.TaskGroup() as tg:
                for sub_query in sub_queries:
                    task = tg.create_task(self._fusion_retriever.aretrieve(QueryBundle(query_str=sub_query.model_dump_json())))
                    retrieved_subquestion_nodes.append([sub_query.sub_question, task])
            retrieved_subquestion_nodes = [[sub_question, task.result()] for sub_question, task in retrieved_subquestion_nodes]
            answers = []
            # For each sub-question, answer from its retrieved image nodes.
            # With the fusion retriever, each sub-question is rewritten into n queries,
            # relevant nodes are retrieved for each generated query, all retrieved nodes
            # are fused with reciprocal rank fusion, and the top k results are kept.
            async with asyncio.TaskGroup() as tg:
                for sub_question, retrieved_nodes in retrieved_subquestion_nodes:
                    for retrieved_node in retrieved_nodes:
                        task = tg.create_task(self._llm.acomplete([sub_question,
                                                                   Image.open(retrieved_node.node.resolve_image())]))
                        answers.append([task, retrieved_node.node.image])
            # Dictionary mapping each response to its source images
            response2images_mapping = defaultdict(set)
            for task, image in answers:
                response2images_mapping[str(task.result())].add(image)
            # Synthesize the results
            synthesized_text, source_images = await asynthesize_results(queries=sub_queries,
                                                                        contexts=response2images_mapping,
                                                                        llm=self._llm,
                                                                        num_children=self._num_children)
            final_answer = await self._llm.apredict(self._qa_prompt,
                                                    context_str=synthesized_text,
                                                    query_str=query_str)
            response_template = PromptTemplate("Retrieved Information:\n"
                                               "------------------------\n"
                                               "{retrieved_information}\n"
                                               "-------------------------\n\n"
                                               "Answer:\n"
                                               "{final_answer}")
            return Response(response=response_template.format(retrieved_information=synthesized_text,
                                                              final_answer=final_answer),
                            source_images=source_images)
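

# An end-to-end wiring sketch, assuming a multimodal LLM whose `complete` /
# `acomplete` accept a [prompt, PIL.Image] list the way this engine calls them
# (e.g. a Gemini-style wrapper), plus a Qdrant collection of ColPali page
# embeddings. The tool name "pdf_pages" must match a key in the fusion
# retriever's `retriever_mappings`; all names here are illustrative
# assumptions, not taken from this file.
def _example_query_engine(llm: LLM) -> CustomQueryEngine:
    retriever = _example_build_retriever()
    fusion_retriever = CustomFusionRetriever(llm=llm,
                                             retriever_mappings={"pdf_pages": retriever},
                                             similarity_top_k=3)
    tools = [ToolMetadata(name="pdf_pages",
                          description="Answers questions about the indexed PDF pages.")]
    return CustomQueryEngine(retriever_tools=tools,
                             fusion_retriever=fusion_retriever,
                             llm=llm)
    # Usage sketch: result = _example_query_engine(llm).query("What drove revenue growth?")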