# ibm_project / app.py
# adnaniqbal001's picture
# Rename qabot.py to app.py
# 9f25444 verified
from ibm_watsonx_ai.foundation_models import ModelInference
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames
from ibm_watsonx_ai import Credentials
from langchain_ibm import WatsonxLLM, WatsonxEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain.chains import RetrievalQA
import gradio as gr
import warnings
def warn(*args, **kwargs):
    """Accept any arguments and do nothing.

    Installed just below in place of ``warnings.warn`` so that calls routed
    through that attribute become no-ops.
    """
    return None


# Swap the no-op in for the stdlib function, then add a blanket "ignore"
# filter on top of it.
setattr(warnings, "warn", warn)
warnings.filterwarnings(action='ignore')
## LLM
def get_llm():
    """Build the watsonx.ai text-generation model used to answer questions.

    Returns:
        A ``WatsonxLLM`` for ``mistralai/mixtral-8x7b-instruct-v01`` on the
        us-south endpoint under the "skills-network" project, limited to 256
        new tokens with a temperature of 0.5.
    """
    # NOTE(review): no API key is passed here; presumably credentials come
    # from the runtime environment -- confirm for the deployment target.
    return WatsonxLLM(
        model_id='mistralai/mixtral-8x7b-instruct-v01',
        url="https://us-south.ml.cloud.ibm.com",
        project_id="skills-network",
        params={
            GenParams.MAX_NEW_TOKENS: 256,
            GenParams.TEMPERATURE: 0.5,
        },
    )
## Document loader
def document_loader(file_path):
    """Load the PDF at ``file_path`` with ``PyPDFLoader``.

    Args:
        file_path: Location of the PDF file to read.

    Returns:
        Whatever ``PyPDFLoader(file_path).load()`` produces for that file.
    """
    return PyPDFLoader(file_path).load()
## Text splitter
def text_splitter(data):
    """Split loaded documents into overlapping chunks.

    Args:
        data: Documents to split; handed to ``split_documents`` unchanged.

    Returns:
        The result of ``RecursiveCharacterTextSplitter.split_documents`` using
        a chunk size of 1000 and an overlap of 50, both measured with ``len``.
    """
    splitter_options = {
        "chunk_size": 1000,
        "chunk_overlap": 50,
        "length_function": len,
    }
    splitter = RecursiveCharacterTextSplitter(**splitter_options)
    return splitter.split_documents(data)
## Embedding model
def watsonx_embedding():
    """Build the watsonx.ai embedding model used to index document chunks.

    Returns:
        A ``WatsonxEmbeddings`` for ``ibm/slate-125m-english-rtrvr`` on the
        us-south endpoint under the "skills-network" project.
    """
    # Per-input token budget. The earlier value of 3 cut every chunk down to
    # its first three tokens before embedding, so the stored vectors carried
    # almost none of the chunk's text and retrieval was close to arbitrary.
    # NOTE(review): 512 is taken to be this model's maximum input length --
    # confirm against IBM's supported-encoder-models page.
    max_input_tokens = 512
    embed_params = {
        EmbedTextParamsMetaNames.TRUNCATE_INPUT_TOKENS: max_input_tokens,
        EmbedTextParamsMetaNames.RETURN_OPTIONS: {"input_text": True},
    }
    return WatsonxEmbeddings(
        model_id="ibm/slate-125m-english-rtrvr",
        url="https://us-south.ml.cloud.ibm.com",
        project_id="skills-network",
        params=embed_params,
    )
## Vector DB
def vector_database(chunks):
    """Embed ``chunks`` and store them in a fresh Chroma collection.

    Args:
        chunks: Documents to index.

    Returns:
        The Chroma vector store built from ``chunks``.
    """
    import uuid  # local import: only this function needs it

    # Give every call its own collection. Without an explicit name each call
    # writes to Chroma's default collection, and in-memory clients living in
    # the same process can share it, so chunks from an earlier upload could
    # show up in answers about a later one.
    # NOTE(review): the sharing depends on the installed chromadb version;
    # a unique name is harmless either way.
    collection_name = f"pdf-{uuid.uuid4().hex}"
    return Chroma.from_documents(
        chunks,
        watsonx_embedding(),
        collection_name=collection_name,
    )
## Retriever
def retriever(file_path):
    """Build a retriever over the PDF at ``file_path``.

    The file is loaded with ``document_loader``, chunked with
    ``text_splitter`` and indexed with ``vector_database``.

    Args:
        file_path: Location of the PDF to index.

    Returns:
        The vector store's ``as_retriever()`` result.
    """
    pages = document_loader(file_path)
    store = vector_database(text_splitter(pages))
    return store.as_retriever()
## QA Chain
def retriever_qa(file, query):
    """Answer ``query`` using the content of the uploaded PDF.

    Args:
        file: Path of the uploaded PDF; falsy (e.g. ``None``) when nothing
            has been uploaded.
        query: The user's question.

    Returns:
        The ``'result'`` entry of the RetrievalQA chain's output, or a short
        prompt message when the file or the question is missing.
    """
    # Guard clauses: without them a missing upload or blank question goes
    # straight into the PDF loader / remote model and fails with an opaque
    # error. They also run before any model or index is built.
    if not file:
        return "Please upload a PDF file before asking a question."
    if not query or not query.strip():
        return "Please enter a question about the uploaded document."

    qa = RetrievalQA.from_chain_type(
        llm=get_llm(),
        chain_type="stuff",
        retriever=retriever(file),
        return_source_documents=False,
    )
    return qa.invoke(query)['result']
# Gradio Interface
# Form with two inputs (a single PDF upload requested as a file path, and a
# two-line question box) and one text output, all handled by retriever_qa.
rag_application = gr.Interface(
    title="RAG Chatbot",
    description="Upload a PDF document and ask any question. The chatbot will try to answer using the provided document.",
    fn=retriever_qa,
    inputs=[
        gr.File(
            type="filepath",
            label="Upload PDF File",
            file_types=['.pdf'],
            file_count="single",
        ),
        gr.Textbox(
            placeholder="Type your question here...",
            label="Input Query",
            lines=2,
        ),
    ],
    outputs=gr.Textbox(label="Output"),
    allow_flagging="never",
)

# Start the web UI only when this file is run as a script.
if __name__ == "__main__":
    rag_application.launch()