import os
import uuid

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import fitz  # PyMuPDF
import easyocr
from PIL import Image
from sentence_transformers import SentenceTransformer
from chromadb import Client, Settings

# Load Zephyr 7B (fine-tuned for chat)
zephyr_tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-alpha")
zephyr_model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceH4/zephyr-7b-alpha",
    torch_dtype=torch.float16,  # half-precision for faster inference and lower memory use
    device_map="auto",          # place the model on GPU automatically when one is available
)

# Load a sentence transformer model for embeddings
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# Initialize a Chroma client for RAG (the default Settings() is in-memory,
# so the knowledge base resets whenever the app restarts)
chroma_client = Client(Settings())
collection = chroma_client.get_or_create_collection(name="knowledge_base")
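
# If the knowledge base needed to survive restarts, one option (an assumption,
# not something this app uses) is chromadb's on-disk client:
#   chroma_client = chromadb.PersistentClient(path="./chroma_db")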

# Function to extract text from a PDF
def extract_text_from_pdf(pdf_path):
    text = ""
    with fitz.open(pdf_path) as doc:  # context manager closes the document automatically
        for page in doc:
            text += page.get_text()
    return text

# Create the EasyOCR reader once at import time; constructing it per call
# would reload the detection/recognition models on every request
ocr_reader = easyocr.Reader(['en'])

# Function to extract text from an image via OCR
def extract_text_from_image(image_path):
    results = ocr_reader.readtext(image_path)
    return " ".join(res[1] for res in results)

# Function to generate a chat response
def generate_response(prompt):
    # Structure the input prompt in Zephyr's chat format
    formatted_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"
    # Tokenize the input prompt
    inputs = zephyr_tokenizer(formatted_prompt, return_tensors="pt").to(zephyr_model.device)
    # max_new_tokens bounds only the generated text; max_length would also count
    # the prompt, which can swallow the whole budget once RAG context is included
    outputs = zephyr_model.generate(**inputs, max_new_tokens=200)
    # Decode only the newly generated tokens; skip_special_tokens strips the
    # <|user|>/<|assistant|> markers, so splitting on them would not work
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return zephyr_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
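
# Note: Zephyr's tokenizer also ships a chat template that can build this kind
# of prompt instead of the hand-written string, e.g. (hypothetical usage):
#   zephyr_tokenizer.apply_chat_template(
#       [{"role": "user", "content": "What is RAG?"}],
#       tokenize=False, add_generation_prompt=True)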

# Function to add text chunks (and their embeddings) to the knowledge base
def add_to_knowledge_base(text_chunks):
    embeddings = embedding_model.encode(text_chunks)
    for chunk, embedding in zip(text_chunks, embeddings):
        collection.add(
            documents=[chunk],
            embeddings=[embedding.tolist()],
            ids=[str(uuid.uuid4())],  # unique ids; a sequential index would collide across calls
        )
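
# The app never calls add_to_knowledge_base itself, so retrieval returns nothing
# until the collection is seeded. A minimal seeding sketch (chunk_text and
# "seed.pdf" are illustrative assumptions, not part of the original app):
def chunk_text(text, chunk_size=500):
    # Fixed-size character windows; 500 is an arbitrary choice
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

if os.path.exists("seed.pdf"):  # hypothetical seed document
    add_to_knowledge_base(chunk_text(extract_text_from_pdf("seed.pdf")))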

# Function to retrieve the most relevant chunks for a query
def retrieve_relevant_chunks(query, top_k=3):
    query_embedding = embedding_model.encode(query)
    results = collection.query(
        query_embeddings=[query_embedding.tolist()],
        n_results=top_k,
    )
    # query() nests results per query embedding; one query was sent, so take index 0
    return results["documents"][0]
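
# Example (hypothetical): with a seeded collection,
#   retrieve_relevant_chunks("refund policy", top_k=2)
# returns a list of up to 2 document strings ranked by embedding similarity.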

# Chatbot function to handle text, PDF, and image inputs
def chatbot(input_type, text_input, pdf_input, image_input):
    if input_type == "Text":
        if not text_input:
            return "Please enter some text."
        query = text_input
    elif input_type == "PDF":
        if pdf_input is None:
            return "Please upload a PDF file."
        pdf_text = extract_text_from_pdf(pdf_input)
        query = f"Extracted text from PDF:\n{pdf_text}\n\nQuestion: {text_input}"
    elif input_type == "Image":
        if image_input is None:
            return "Please upload an image file."
        image_text = extract_text_from_image(image_input)
        query = f"Extracted text from image:\n{image_text}\n\nQuestion: {text_input}"
    else:
        return "Invalid input type."
    # Retrieve relevant chunks from the knowledge base
    relevant_chunks = retrieve_relevant_chunks(query)
    context = "\n\n".join(relevant_chunks)
    # Build the RAG prompt and generate the response
    prompt = f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
    return generate_response(prompt)

# Gradio input components
input_components = [
    gr.Dropdown(choices=["Text", "PDF", "Image"], label="Input Type"),
    gr.Textbox(lines=2, placeholder="Enter text...", label="Text Input"),
    gr.File(label="Upload PDF", file_types=[".pdf"], type="filepath"),  # hand fitz.open a path string
    gr.Image(label="Upload Image", type="filepath"),
]

# Create the Gradio interface
interface = gr.Interface(
    fn=chatbot,
    inputs=input_components,
    outputs="text",
    title="RAG Chatbot with PDF and Image Support",
    description="Select the input type (Text, PDF, or Image) and provide your input.",
)

# Launch the app
interface.launch()
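
# launch() also accepts options such as share=True (temporary public URL) or
# debug=True for local troubleshooting; on Hugging Face Spaces the bare call suffices.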