# chat-with-pdf / rag.py
# This file contains the full RAG pipeline, from retrieval to generation.
import torch
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForCausalLM
class RAG:
    """Retrieval-augmented generation over pre-computed PDF chunk embeddings."""

    def __init__(self):
        self.model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.embedding_model_name = "all-mpnet-base-v2"
        self.embeddings_filename = "embeddings.csv"
        # Load the pre-computed chunks and embeddings once and reuse the DataFrame
        self.data_pd = pd.read_csv(self.embeddings_filename)
        self.data_dict = self.data_pd.to_dict(orient="records")
        self.data_embeddings = self.get_embeddings()
        # Embedding model used to encode user queries
        self.embedding_model = SentenceTransformer(model_name_or_path=self.embedding_model_name,
                                                   device=self.device)
        # Tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        # LLM
        self.llm_model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=self.model_id,
                                                              torch_dtype=torch.float16).to(self.device)
    def get_embeddings(self) -> torch.Tensor:
        """Parses the embedding strings stored in the CSV file into a single tensor."""
        data_embeddings = []
        for tensor_str in self.data_pd["embeddings"]:
            # Each cell holds a string such as "tensor([0.1, 0.2, ...])";
            # extract the bracketed values and convert them back to floats.
            values_str = tensor_str.split("[")[1].split("]")[0]
            values_list = [float(val) for val in values_str.split(",")]
            data_embeddings.append(torch.tensor(values_list))
        return torch.stack(data_embeddings).to(self.device)
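    # A minimal sketch of the CSV layout this parser assumes (column names are taken
    # from the code; the example row values are illustrative only):
    #
    #   sentence_chunk,embeddings
    #   "The cat sat on the mat.","tensor([0.0123, -0.0456, ...])"
    #
    # Only the "embeddings" column is parsed here; "sentence_chunk" is used later
    # when building the prompt context.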
    def retrieve_relevant_resource(self, user_query: str, k: int = 5):
        """Embeds the query and returns the top-k similarity scores and chunk indices."""
        query_embedding = self.embedding_model.encode(user_query, convert_to_tensor=True).to(self.device)
        # Dot-product similarity between the query and every stored chunk embedding
        dot_scores = util.dot_score(a=query_embedding, b=self.data_embeddings)[0]
        scores, indices = torch.topk(dot_scores, k=k)
        return scores, indices
    def prompt_formatter(self, query: str, context_items: list[dict]) -> str:
        """Augments the query with text-based context from context_items."""
        # Join context items into a bulleted list
        context = "- " + "\n- ".join([item["sentence_chunk"] for item in context_items])
        base_prompt = """Based on the following context items, please answer the query.
\nNow use the following context items to answer the user query:
{context}
\nRelevant passages: <extract relevant passages from the context here>
User query: {query}
Answer:"""
        # Fill the template with the retrieved context and the user query
        base_prompt = base_prompt.format(context=context, query=query)
        # Wrap the prompt in a single-turn dialogue for the instruction-tuned model
        dialogue_template = [
            {"role": "user",
             "content": base_prompt}
        ]
        # Apply the model's chat template so the prompt matches its training format
        prompt = self.tokenizer.apply_chat_template(conversation=dialogue_template,
                                                    tokenize=False,
                                                    add_generation_prompt=True)
        return prompt
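    # For reference, the chat template renders the dialogue roughly as
    # "<|user|>\n{base_prompt}</s>\n<|assistant|>\n" -- an approximation inferred
    # from the "<|assistant|>" / "</s>" markers that query() splits on below.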
    def query(self, user_text: str):
        """Runs the full pipeline: retrieve context, build the prompt, and generate an answer."""
        scores, indices = self.retrieve_relevant_resource(user_text)
        context_items = [self.data_dict[int(i)] for i in indices]
        prompt = self.prompt_formatter(query=user_text, context_items=context_items)
        input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        outputs = self.llm_model.generate(**input_ids, max_new_tokens=256)
        # Decode and keep only the assistant's reply (between "<|assistant|>" and "</s>")
        output_text = self.tokenizer.decode(outputs[0])
        output_text = output_text.split("<|assistant|>")[1].split("</s>")[0]
        return output_text
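

# A minimal usage sketch (assumes an "embeddings.csv" file with "sentence_chunk"
# and "embeddings" columns is present in the working directory; the question
# below is illustrative only):
if __name__ == "__main__":
    rag = RAG()
    answer = rag.query("What is the main topic of the document?")
    print(answer)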