# RAGDocGen / app.py
import os
import subprocess
# Install the Tesseract OCR engine at startup (pytesseract only wraps the binary)
subprocess.run(["apt-get", "update"])
subprocess.run(["apt-get", "install", "-y", "tesseract-ocr"])
import pytesseract
from PIL import Image
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel, T5ForConditionalGeneration, T5Tokenizer, GPT2LMHeadModel, GPT2Tokenizer
import torch
import faiss
import numpy as np
import gradio as gr
# Set the tesseract command path
pytesseract.pytesseract.tesseract_cmd = '/usr/bin/tesseract'
# Define the path to the docs folder
docs_path = "docs"
# Load images from the folder and extract text using pytesseract
def load_documents(docs_path):
    documents = []
    for filename in sorted(os.listdir(docs_path)):
        # OCR every supported image file in the folder
        if filename.lower().endswith((".png", ".jpg", ".jpeg")):
            image_path = os.path.join(docs_path, filename)
            text = pytesseract.image_to_string(Image.open(image_path))
            documents.append(text)
    return documents
documents = load_documents(docs_path)
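# Note: if the docs/ folder contains no readable images, `documents` is empty and the
# encoding step below will fail. A minimal guard (an addition, not in the original)
# could be:
#   if not documents:
#       raise RuntimeError("No OCR-able images found in 'docs/'")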
# Load model and tokenizer for encoding documents
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModel.from_pretrained("distilbert-base-uncased")
# Preprocess and encode documents
def encode_documents(documents):
    inputs = tokenizer(documents, return_tensors='pt', padding=True, truncation=True)
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**inputs)
    # Mean-pool the token embeddings into one vector per document
    embeddings = outputs.last_hidden_state.mean(dim=1).numpy()
    return embeddings
embeddings = encode_documents(documents)
# Index embeddings using FAISS
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
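# Optional sketch (not used below): the FAISS index could be persisted between runs,
# e.g. faiss.write_index(index, "docs.index") to save and
# index = faiss.read_index("docs.index") to reload; "docs.index" is an arbitrary
# filename chosen here for illustration.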
# Load T5 model and tokenizer for question generation
t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
def generate_questions(text):
    input_text = f"generate question: {text}"
    input_ids = t5_tokenizer.encode(input_text, return_tensors="pt")
    # Allow a longer question than the 20-token generation default
    outputs = t5_model.generate(input_ids, max_new_tokens=64)
    question = t5_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return question
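# Illustrative usage (output depends on the checkpoint; t5-small is not fine-tuned
# for this "generate question:" prompt, so results may be rough):
#   generate_questions("The router supports WPA3 encryption")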
def retrieve_documents(query, index, documents):
    query_embedding = encode_documents([query])
    # Search for the closest document embeddings, never asking for more than we have
    k = min(5, len(documents))
    D, I = index.search(query_embedding, k)
    return [documents[i] for i in I[0] if i >= 0]
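# Illustrative usage: retrieve_documents("How do I reset my password?", index, documents)
# returns up to five OCR'd document texts whose embeddings are closest to the query.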
print ("gpt start")
# Load GPT-2 model and tokenizer for answer generation
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2")
print ("gpt done")
def generate_answer(question, context):
    input_text = f"Question: {question}\nContext: {context}\nAnswer:"
    # Truncate the prompt so it fits inside GPT-2's 1024-token context window
    input_ids = gpt2_tokenizer.encode(input_text, return_tensors="pt", truncation=True, max_length=800)
    # max_new_tokens (rather than max_length) so long prompts still get a full answer
    outputs = gpt2_model.generate(input_ids, max_new_tokens=150, pad_token_id=gpt2_tokenizer.eos_token_id)
    answer = gpt2_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer
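# Note: the base gpt2 checkpoint is not instruction-tuned, so this Question/Context/Answer
# prompt tends to produce free-form continuations rather than focused answers; a
# fine-tuned or instruction-tuned model would behave better here.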
# Placeholder: the language model could be fine-tuned on domain data before use
# Generate support documents directly from a prompt (not used by the Gradio UI below)
def generate_support_document(prompt):
    input_ids = gpt2_tokenizer.encode(prompt, return_tensors="pt")
    outputs = gpt2_model.generate(input_ids, max_new_tokens=512, num_return_sequences=1, pad_token_id=gpt2_tokenizer.eos_token_id)
    document = gpt2_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return document
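# Illustrative usage (this helper is not wired into the Gradio interface below):
#   generate_support_document("Write a short troubleshooting guide for Wi-Fi setup:")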
# Gradio Interface
def rag_pipeline(query):
    question = generate_questions(query)
    retrieved_docs = retrieve_documents(question, index, documents)
    context = " ".join(retrieved_docs)
    answer = generate_answer(question, context)
    return question, context, answer
iface = gr.Interface(
    fn=rag_pipeline,
    inputs=gr.Textbox(label="Your question"),
    outputs=[
        gr.Textbox(label="Generated question"),
        gr.Textbox(label="Retrieved context"),
        gr.Textbox(label="Answer"),
    ],
    title="RAG Pipeline for Product Support",
    description="Ask a question about product support and get a detailed answer.",
)
iface.launch()
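# On Hugging Face Spaces the Gradio SDK handles host/port automatically; for local
# testing, iface.launch(share=True) could be used to expose a temporary public link.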