# PDFgpt / app.py
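"""Gradio app: upload a PDF and ask questions about its contents.

The PDF is chunked and embedded with all-MiniLM-L6-v2 into a Chroma vector
store; a local LaMini-Flan-T5-783M pipeline then answers questions through a
LangChain RetrievalQA chain.
"""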
import tempfile

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import Chroma
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.document_loaders import PDFMinerLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
# Define Chroma settings; only persist_directory is consumed below.
# "chroma_db_impl" is a legacy (pre-0.4) duckdb+parquet setting kept for
# compatibility with older chromadb releases.
CHROMA_SETTINGS = {
    "chroma_db_impl": "duckdb+parquet",
    "persist_directory": tempfile.mkdtemp(),  # use a temporary directory
    "anonymized_telemetry": False,
}
# Load model and tokenizer (CPU, float32)
checkpoint = "MBZUAI/LaMini-Flan-T5-783M"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
base_model = AutoModelForSeq2SeqLM.from_pretrained(
    checkpoint, device_map="cpu", torch_dtype=torch.float32
)
# Define functions
def data_ingestion(file_path):
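    """Load a PDF, split it into chunks, embed them, and persist a Chroma DB."""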
    loader = PDFMinerLoader(file_path)
    documents = loader.load()
    # Use a modest overlap; the original value of 500 (equal to chunk_size)
    # would make every chunk duplicate its neighbour almost entirely.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    texts = text_splitter.split_documents(documents)
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    db = Chroma.from_documents(
        texts, embeddings, persist_directory=CHROMA_SETTINGS["persist_directory"]
    )
    db.persist()
    print(f"Ingested {len(texts)} chunks")
    return db
def llm_pipeline():
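    """Wrap the local text2text-generation pipeline as a LangChain LLM."""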
    pipe = pipeline(
        "text2text-generation",
        model=base_model,
        tokenizer=tokenizer,
        max_length=256,
        do_sample=True,
        temperature=0.3,
        top_p=0.95,
    )
    local_llm = HuggingFacePipeline(pipeline=pipe)
    return local_llm
def qa_llm():
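    """Build a RetrievalQA chain over the persisted Chroma store."""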
    llm = llm_pipeline()
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    vectordb = Chroma(
        persist_directory=CHROMA_SETTINGS["persist_directory"],
        embedding_function=embeddings,
    )
    retriever = vectordb.as_retriever()
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
    return qa
def process_answer(file, instruction):
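    """Ingest the uploaded PDF, then answer the question against it."""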
    # Ingest the data from the uploaded PDF
    data_ingestion(file.name)
    # Process the question
    qa = qa_llm()
    generated_text = qa(instruction)
    answer = generated_text["result"]
    return answer
# Define Gradio interface (explicit components give the inputs labels)
iface = gr.Interface(
    fn=process_answer,
    inputs=[gr.File(label="PDF file"), gr.Textbox(label="Question")],
    outputs="text",
)
# Launch the interface
iface.launch()