import os
import re
import json
import time
import wandb
import torch
import spaces
import psutil
import pymupdf
import gradio as gr

from qdrant_client import QdrantClient
from utils import download_pdf_from_gdrive, merge_strings_with_prefix
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, QuantoConfig
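
# `utils` here is a local helper module that ships alongside the app (not a
# PyPI package); it provides the Google Drive download and line-merging
# helpers used in the __main__ block below.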


def rag_query(query: str):
    """
    Allows searching the vector database, which contains
    information about a man named Suvaditya, for a given query
    by performing semantic search. Returns results by
    looking at his resume, which contains a plethora of
    information about him.

    Args:
        query: The query against which the search will be run,
            in the form of a single string phrase of no more than
            10 words.

    Returns:
        search_results: A list of results that come closest
            to the given query semantically,
            determined by Cosine Similarity.
    """
    return client.query(
        collection_name="resume",
        query_text=query
    )
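
# Note: the docstring above doubles as the tool description that
# `tokenizer.apply_chat_template(..., tools=[rag_query])` exposes to the model,
# so keeping it accurate directly affects when the model chooses to call the
# tool. The global `client` it uses is created in the __main__ block below.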


# On a ZeroGPU Space, GPU-bound work needs the `spaces.GPU` decorator so a
# GPU slice is allocated for the duration of each call.
@spaces.GPU
def generate_answer(chat_history):
    # Render the chat history plus the tool definition into a prompt
    tool_prompt = tokenizer.apply_chat_template(
        chat_history,
        tools=[rag_query],
        return_tensors="pt",
        return_dict=True,
        add_generation_prompt=True,
    )
    tool_prompt = tool_prompt.to(model.device)

    # Generate result
    out = model.generate(
        **tool_prompt,
        max_new_tokens=512,
        do_sample=True,
        top_p=0.95,
        num_beams=4
    )

    # Keep only the newly generated tokens and decode them
    generated_text = out[0, tool_prompt['input_ids'].shape[1]:]
    generated_text = tokenizer.decode(generated_text)

    torch.cuda.empty_cache()
    return generated_text
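
# Decoding note: do_sample=True together with num_beams=4 selects
# transformers' beam-sample decoding, with top_p=0.95 nucleus filtering at
# each step; these values trade answer quality against ZeroGPU time and can
# be tuned.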


def parse_tool_request(tool_call, top_k=5):
    # Extract the JSON payload of a <tool_call>...</tool_call> block, if any
    pattern = r"<tool_call>(.*?)</tool_call>"
    match_result = re.search(pattern, tool_call, re.DOTALL)

    if match_result:
        result = match_result.group(1).strip()
    else:
        return None, None

    # Run the requested query against the vector store and keep the top_k hits
    query = json.loads(result)["arguments"]["query"]
    query_results = [
        query_piece.metadata["document"] for query_piece in rag_query(query)
    ]

    return query_results[:top_k], query
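
# For reference, a Qwen2.5-style tool call arrives embedded in the generated
# text roughly like this (illustrative example):
#   <tool_call>
#   {"name": "rag_query", "arguments": {"query": "current workplace"}}
#   </tool_call>
# parse_tool_request would then return the closest resume lines together with
# the string "current workplace".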


def update_chat_history(chat_history, tool_query, query_results):
    # Record the tool call the assistant made...
    assistant_tool_message = {
        "role": "assistant",
        "metadata": "🛠️ Using Qdrant Engine to search for the query 🛠️",
        "tool_calls": [{
            "type": "function",
            "function": {
                "name": "rag_query",
                "arguments": {"query": f"{tool_query}"}
            }
        }]
    }

    # ...and the documents it returned, as a "tool" message
    result_tool_message = {
        "role": "tool",
        "name": "rag_query",
        "content": "\n".join(query_results)
    }

    chat_history.append(assistant_tool_message)
    chat_history.append(result_tool_message)

    return chat_history
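
# After this update the history holds the user question, the assistant's
# rag_query call, and the tool results; apply_chat_template renders all of
# them so the second generation pass can ground its answer in the retrieved
# resume lines.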


if __name__ == "__main__":
    RESUME_PATH = os.path.join(os.getcwd(), "Resume.pdf")
    RESUME_URL = "https://drive.google.com/file/d/1YMF9NNTG5gubwJ7ipI5JfxAJKhlD9h2v/"

    # Download file
    download_pdf_from_gdrive(RESUME_URL, RESUME_PATH)

    doc = pymupdf.open(RESUME_PATH)
    fulltext = doc[0].get_text().split("\n")
    fulltext = merge_strings_with_prefix(fulltext)

    # Embed the sentences
    client = QdrantClient(":memory:", optimize_for_ram_usage=True)
    client.set_model("sentence-transformers/all-MiniLM-L6-v2")

    if not client.collection_exists(collection_name="resume"):
        client.create_collection(
            collection_name="resume",
            vectors_config=client.get_fastembed_vector_params(),
        )

    _ = client.add(
        collection_name="resume",
        documents=fulltext,
        ids=range(len(fulltext)),
        batch_size=100,
        parallel=0,
    )
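
    # Optional sanity check (illustrative, not required by the app): query the
    # freshly built collection directly, e.g.
    #   hits = client.query(collection_name="resume", query_text="education")
    #   print(hits[0].metadata["document"])
    # to confirm that semantically relevant resume lines come back.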

    wandb.init(project="resume-rag", name="zerogpu-run")
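    # Assumption: W&B credentials (e.g. a WANDB_API_KEY secret on the Space)
    # are available so this run can actually be logged.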

    model_name = "Qwen/Qwen2.5-3B-Instruct"

    def rag_process(message, chat_history):
        # Append current user message to chat history
        current_message = {
            "role": "user",
            "content": message
        }
        chat_history.append(current_message)

        start_time = time.time()

        # Generate LLM answer
        generated_text = generate_answer(chat_history)

        # Detect whether the LLM requested a tool call; if not, both
        # values come back as None
        query_results, tool_query = parse_tool_request(generated_text)

        # If a tool call was requested
        if query_results is not None and tool_query is not None:

            # Update chat history with the result of the tool call
            chat_history = update_chat_history(
                chat_history, tool_query, query_results
            )

            # Generate the final answer from the retrieved context
            generated_text = generate_answer(chat_history)

        metrics = {
            "conversation": {
                "turn": len(chat_history) // 2,
                "history": chat_history,
                "current_question": message,
                # [:-10] strips the trailing <|im_end|> token (10 characters)
                "current_answer": generated_text[:-10],
                "tool_query": tool_query,
                "rag_results": query_results
            },
            "performance": {
                "response_time": time.time() - start_time,
                "gpu_memory_used": torch.cuda.memory_allocated() if torch.cuda.is_available() else 0,
                "cpu_memory": psutil.Process().memory_info().rss,
                "gpu_utilization": torch.cuda.utilization() if torch.cuda.is_available() else 0
            }
        }
        wandb.log(metrics)
        wandb.finish()

        return generated_text[:-10]
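
    # Per-turn flow: (1) generate once so the model can decide whether to call
    # rag_query, (2) if it did, run the search and append both the tool call
    # and its results to the history, (3) generate again to produce the
    # grounded answer returned to the Gradio UI.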

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        quantization_config=QuantoConfig(
            weights="int8"
        )
        # quantization_config = BitsAndBytesConfig(
        #     load_in_4bit=True,
        #     # bnb_4bit_compute_dtype=torch.float16,
        #     # bnb_4bit_quant_type="nf4"
        # )
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
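
    # Quantization note: QuantoConfig(weights="int8") quantizes the weights to
    # 8-bit through the Quanto backend (which must be installed in the Space);
    # the commented-out BitsAndBytesConfig block above is a 4-bit alternative
    # that relies on bitsandbytes instead.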

    demo = gr.ChatInterface(
        fn=rag_process,
        type="messages",
        title="Resume RAG, a personal space on ZeroGPU!",
        examples=["Where did Suvaditya complete his Bachelor's Degree?", "Where is Suvaditya currently working?"],
        description="Ask any question about Suvaditya's resume and get an answer!",
        theme="ocean"
    )

    demo.launch()