Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import gradio as gr | |
| import fitz # PyMuPDF | |
| import easyocr | |
| from PIL import Image | |
| from sentence_transformers import SentenceTransformer | |
| from chromadb import Client, Settings | |
# --- Model and retrieval setup (runs once at import time) ---

# Load Zephyr 7B (chat-tuned). Weights download on first run; device_map="auto"
# places them on GPU when one is available.
zephyr_tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-alpha")
zephyr_model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceH4/zephyr-7b-alpha",
    torch_dtype=torch.float16,  # half precision: faster inference, less VRAM
    device_map="auto",
)

# Sentence-transformer used to embed both stored documents and user queries;
# the same model must be used for both sides of the similarity search.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# In-memory Chroma vector store backing the RAG lookups.
chroma_client = Client(Settings())
# get_or_create_collection instead of create_collection: create_collection
# raises if "knowledge_base" already exists (module re-import / hot-reload),
# while get_or_create is idempotent.
collection = chroma_client.get_or_create_collection(name="knowledge_base")
| # Function to extract text from PDF | |
| def extract_text_from_pdf(pdf_path): | |
| doc = fitz.open(pdf_path) | |
| text = "" | |
| for page in doc: | |
| text += page.get_text() | |
| return text | |
| # Function to extract text from image | |
| def extract_text_from_image(image_path): | |
| reader = easyocr.Reader(['en']) | |
| results = reader.readtext(image_path) | |
| extracted_text = " ".join([res[1] for res in results]) | |
| return extracted_text | |
| # Function to generate a response | |
| def generate_response(prompt): | |
| # Structure the input prompt for chat | |
| formatted_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n" | |
| # Tokenize the input prompt | |
| inputs = zephyr_tokenizer(formatted_prompt, return_tensors="pt").to(zephyr_model.device) | |
| # Generate the response | |
| outputs = zephyr_model.generate(**inputs, max_length=200) | |
| # Decode the response | |
| response = zephyr_tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| # Extract only the assistant's response | |
| response = response.split("<|assistant|>")[-1].strip() | |
| return response | |
| # Function to add documents to the knowledge base | |
| def add_to_knowledge_base(text_chunks): | |
| embeddings = embedding_model.encode(text_chunks) | |
| for idx, (chunk, embedding) in enumerate(zip(text_chunks, embeddings)): | |
| collection.add( | |
| documents=[chunk], | |
| embeddings=[embedding.tolist()], | |
| ids=[str(idx)] | |
| ) | |
| # Function to retrieve relevant chunks | |
| def retrieve_relevant_chunks(query, top_k=3): | |
| query_embedding = embedding_model.encode(query) | |
| results = collection.query( | |
| query_embeddings=[query_embedding.tolist()], | |
| n_results=top_k | |
| ) | |
| return results["documents"][0] | |
| # Chatbot function to handle text, PDF, and image inputs | |
| def chatbot(input_type, text_input, pdf_input, image_input): | |
| if input_type == "Text": | |
| if not text_input: | |
| return "Please enter some text." | |
| query = text_input | |
| elif input_type == "PDF": | |
| if pdf_input is None: | |
| return "Please upload a PDF file." | |
| pdf_text = extract_text_from_pdf(pdf_input) | |
| query = f"Extracted text from PDF:\n{pdf_text}\n\nQuestion: {text_input}" | |
| elif input_type == "Image": | |
| if image_input is None: | |
| return "Please upload an image file." | |
| image_text = extract_text_from_image(image_input) | |
| query = f"Extracted text from image:\n{image_text}\n\nQuestion: {text_input}" | |
| else: | |
| return "Invalid input type." | |
| # Retrieve relevant chunks from the knowledge base | |
| relevant_chunks = retrieve_relevant_chunks(query) | |
| context = "\n\n".join(relevant_chunks) | |
| # Generate response using the model | |
| prompt = f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:" | |
| response = generate_response(prompt) | |
| return response | |
| # Gradio interface | |
| input_components = [ | |
| gr.Dropdown(choices=["Text", "PDF", "Image"], label="Input Type"), | |
| gr.Textbox(lines=2, placeholder="Enter text...", label="Text Input"), | |
| gr.File(label="Upload PDF", file_types=[".pdf"]), | |
| gr.Image(label="Upload Image", type="filepath") | |
| ] | |
| # Create the Gradio interface | |
| interface = gr.Interface( | |
| fn=chatbot, | |
| inputs=input_components, | |
| outputs="text", | |
| title="RAG Chatbot with PDF and Image Support", | |
| description="Select the input type (Text, PDF, or Image) and provide your input." | |
| ) | |
| # Launch the app | |
| interface.launch() | |