# resume-rag / app.py
import os
import re
import json
import time
import wandb
import torch
import spaces
import psutil
import pymupdf
import gradio as gr
from qdrant_client import QdrantClient
from utils import download_pdf_from_gdrive, merge_strings_with_prefix
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, QuantoConfig
def rag_query(query: str):
"""
Allows searching the vector database, which contains
information about a man named Suvaditya, for a given query
by performing semantic search. Returns results by
looking at his resume, which contains a plethora of
information about him.
Args:
query: The query against which the search will be run,
in the form of a single string phrase of no more than
10 words.
Returns:
search_results: A list of results that come closest
to the given query semantically,
determined by Cosine Similarity.
"""
return client.query(
collection_name="resume",
query_text=query
)
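# Example (hypothetical query): rag_query("Suvaditya current employer") returns
# QueryResponse objects whose metadata["document"] holds the matching resume line.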
def generate_answer(chat_history):
# Build a tool-aware prompt from the chat history and generate a response
tool_prompt = tokenizer.apply_chat_template(
chat_history,
tools=[rag_query],
return_tensors="pt",
return_dict=True,
add_generation_prompt=True,
)
tool_prompt = tool_prompt.to(model.device)
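# Beam search (4 beams) with nucleus sampling, capped at 512 new tokens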
out = model.generate(
**tool_prompt,
max_new_tokens=512,
do_sample=True,
top_p=0.95,
num_beams=4
)
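# Keep only the newly generated tokens, dropping the prompt portion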
generated_text = out[0, tool_prompt['input_ids'].shape[1]:]
generated_text = tokenizer.decode(generated_text)
torch.cuda.empty_cache()
return generated_text
def parse_tool_request(tool_call, top_k=5):
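# The model emits Qwen-style tool calls, e.g.
#   <tool_call>
#   {"name": "rag_query", "arguments": {"query": "Suvaditya work experience"}}
#   </tool_call>
# Only the "query" argument is used; returns (None, None) when no tool call is present.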
pattern = r"<tool_call>(.*?)</tool_call>"
match_result = re.search(pattern, tool_call, re.DOTALL)
if match_result:
result = match_result.group(1).strip()
else:
return None, None
query = json.loads(result)["arguments"]["query"]
query_results = [
query_piece.metadata["document"] for query_piece in rag_query(query)
]
return query_results[:top_k], query
def update_chat_history(chat_history, tool_query, query_results):
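# Record the assistant's tool call and the tool's output as separate messages so the
# follow-up generation pass can ground its answer in the retrieved resume snippets.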
assistant_tool_message = {
"role": "assistant",
"metadata": "🛠️ Using Qdrant Engine to search for the query 🛠️",
"tool_calls": [{
"type": "function",
"function": {
"name": "rag_query",
"arguments": {"query": f"{tool_query}"}
}
}]
}
result_tool_message = {
"role": "tool",
"name": "rag_query",
"content": "\n".join(query_results)
}
chat_history.append(assistant_tool_message)
chat_history.append(result_tool_message)
return chat_history
if __name__ == "__main__":
RESUME_PATH = os.path.join(os.getcwd(), "Resume.pdf")
RESUME_URL = "https://drive.google.com/file/d/1YMF9NNTG5gubwJ7ipI5JfxAJKhlD9h2v/"
# Download the resume PDF from Google Drive
download_pdf_from_gdrive(RESUME_URL, RESUME_PATH)
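# Read the first page and merge its lines into chunks for embedding
# (merge_strings_with_prefix is a utils helper; it presumably joins wrapped lines)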
doc = pymupdf.open(RESUME_PATH)
fulltext = doc[0].get_text().split("\n")
fulltext = merge_strings_with_prefix(fulltext)
# Set up an in-memory Qdrant instance; fastembed handles the sentence embeddings
client = QdrantClient(":memory:", optimize_for_ram_usage=True)
client.set_model("sentence-transformers/all-MiniLM-L6-v2")
if not client.collection_exists(collection_name="resume"):
client.create_collection(
collection_name="resume",
vectors_config=client.get_fastembed_vector_params(),
)
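# Index each resume chunk; embeddings are computed locally with all-MiniLM-L6-v2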
_ = client.add(
collection_name="resume",
documents=fulltext,
ids=range(len(fulltext)),
batch_size=100,
parallel=0,
)
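# Track conversation and performance metrics in Weights & Biases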
wandb.init(project="resume-rag", name="zerogpu-run")
model_name = "Qwen/Qwen2.5-3B-Instruct"
@spaces.GPU
def rag_process(message, chat_history):
# Append current user message to chat history
current_message = {
"role": "user",
"content": message
}
chat_history.append(current_message)
start_time = time.time()
# Generate LLM answer
generated_text = generate_answer(chat_history)
# Detect whether the LLM requested a tool call; if so, run the RAG query,
# otherwise both values come back as None
query_results, tool_query = parse_tool_request(generated_text)
# If tool call was requested
if query_results is not None and tool_query is not None:
# Update chat history with result of tool call
chat_history = update_chat_history(
chat_history, tool_query, query_results
)
# Re-generate the answer, now grounded in the retrieved resume snippets
generated_text = generate_answer(chat_history)
metrics = {
"conversation": {
"turn": len(chat_history) // 2,
"history": chat_history,
"current_question": message,
"current_answer": generated_text[:-10],
"tool_query": tool_query,
"rag_results": query_results
},
"performance": {
"response_time": time.time() - start_time,
"gpu_memory_used": torch.cuda.memory_allocated() if torch.cuda.is_available() else 0,
"cpu_memory": psutil.Process().memory_info().rss,
"gpu_utilization": torch.cuda.utilization() if torch.cuda.is_available() else 0
}
}
wandb.log(metrics)
wandb.finish()
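# Strip the trailing end-of-turn marker ("<|im_end|>", 10 characters) from the decoded text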
return generated_text[:-10]
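# Load Qwen2.5-3B-Instruct in bfloat16 with int8 weight quantization (Quanto),
# presumably to fit within the ZeroGPU memory budget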
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
quantization_config=QuantoConfig(
weights="int8"
)
# quantization_config = BitsAndBytesConfig(
# load_in_4bit=True,
# # bnb_4bit_compute_dtype=torch.float16,
# # bnb_4bit_quant_type="nf4"
# )
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
demo = gr.ChatInterface(
fn=rag_process,
type="messages",
title="Resume RAG, a personal space on ZeroGPU!",
examples=["Where did Suvaditya complete his Bachelor's Degree?", "Where is Suvaditya currently working?"],
description="Ask any question about Suvaditya's resume and get an answer!",
theme="ocean"
)
demo.launch()