"""Gradio app: retrieval-augmented answers over PyImageSearch blog posts.

Pipeline for each user query:
  1. Embed the query with a SentenceTransformer model.
  2. Look up the closest rows (cosine metric) in the LanceDB table
     ``my_table`` of the ``embedding_dataset`` database. Each row is read
     for its ``text`` and ``url`` fields.
  3. Build a prompt from the query and the best hit's text and let
     ``google/gemma-2b-it`` generate a response.
"""

import os

import gradio as gr
import lancedb
import spaces
import torch
from datasets import load_dataset  # noqa: F401  (not referenced in this file)
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

# Number of nearest neighbours requested from the vector table; only the
# top-ranked hit is used to build the prompt.
TOP_K = 5
# Upper bound on the number of tokens generated per answer.
MAX_NEW_TOKENS = 1000

# Expose the access token stored in the "auth" variable under the name
# "HF_TOKEN". os.environ only accepts strings, so assigning the result of
# os.getenv() directly would raise TypeError whenever "auth" is not set.
hf_token = os.getenv("auth")
if hf_token:
    os.environ["HF_TOKEN"] = hf_token

db = lancedb.connect("embedding_dataset")
tbl = db.open_table("my_table")

embedding_model = SentenceTransformer(
    model_name_or_path="all-mpnet-base-v2", device="cuda"
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it", torch_dtype=torch.bfloat16, device_map="auto"
)


@spaces.GPU()
def process_query(query: str) -> tuple[str, str]:
    """Answer *query* using the best-matching row of the vector table.

    Args:
        query: Free-text question typed by the user.

    Returns:
        A ``(url, response)`` pair: the ``url`` field of the best search hit
        and the text generated by the language model. When the query is
        blank or the search returns no rows, ``url`` is an empty string and
        ``response`` is a short explanatory message.
    """
    if not query or not query.strip():
        return "", "Please enter a query."

    query_embedding = embedding_model.encode(query)
    search_hits = (
        tbl.search(query_embedding).metric("cosine").limit(TOP_K).to_list()
    )
    if not search_hits:
        # Avoid an IndexError on search_hits[0] when nothing is returned.
        return "", "No relevant context was found for this query."

    best_hit = search_hits[0]
    context = best_hit["text"]
    url = best_hit["url"]
    print(url)

    # Each instruction goes on its own line so the user's query does not run
    # straight into the following sentence.
    input_text = (
        f"You are being provided a query: {query}\n"
        f"You are being provided context to the query: {context}\n"
        "Please provide a detailed and contextually relevant response."
    )

    # Send the inputs to whichever device the model was placed on rather
    # than assuming a fixed device name.
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generated = model.generate(
            **inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False
        )

    # generate() returns the prompt tokens followed by the new tokens. Drop
    # the prompt by token count: slicing the decoded string by the prompt's
    # character length is unreliable because decoding is not guaranteed to
    # reproduce the prompt character-for-character.
    new_tokens = generated[0][prompt_length:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return url, response


# Simpler alternative UI, kept for reference:
# demo = gr.Interface(
#     fn=process_query,
#     inputs=gr.Textbox(label="User Query"),
#     outputs=[gr.Textbox(label="URL"), gr.Textbox(label="Generated Response")]
# )
# demo.launch()

with gr.Blocks() as demo:
    gr.Markdown("# RAG on PyImageSearch blog posts")
    gr.Markdown(
        "This interface processes a user query by finding the most relevant "
        "context from PyImageSearch and generating a detailed response."
    )

    with gr.Row():
        with gr.Column():
            user_query = gr.Textbox(
                label="User Query",
                placeholder="Enter your query here...",
                lines=2,
            )
        with gr.Column():
            search_url = gr.Textbox(label="URL", interactive=False)
            generated_response = gr.Textbox(
                label="Generated Response", interactive=False
            )

    submit_button = gr.Button("Submit")
    # Event listeners must be registered inside the Blocks context.
    submit_button.click(
        fn=process_query,
        inputs=user_query,
        outputs=[search_url, generated_response],
    )

demo.launch()