# chabi / main.py
# Source page header: last commit c4c7ec5 ("Update main.py") by anasmkh (verified).
import functools

import gradio as gr
import torch
from langchain import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader, PyPDFLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Knowledge base the bot answers from. PyPDFLoader is imported too; swap in
# e.g. PyPDFLoader("bipolar.pdf") to index a PDF instead of a text file.
# The file is decoded explicitly as UTF-8 so loading does not depend on the
# platform's default locale encoding.
loader = TextLoader("info.txt", encoding="utf-8")
docs = loader.load()

# Chunking parameters. These equal RecursiveCharacterTextSplitter's built-in
# defaults and are named here so they are visible and easy to tune
# (e.g. chunk_size=100, chunk_overlap=20 for much finer-grained chunks).
# NOTE(review): bge-small-en-v1.5 accepts at most 512 tokens per input, so text
# near the end of a 4000-character chunk is likely truncated before it is
# embedded. Confirm, and consider a smaller CHUNK_SIZE.
CHUNK_SIZE = 4000
CHUNK_OVERLAP = 200
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
)
documents = text_splitter.split_documents(docs)

# Sentence embeddings computed on CPU; normalize_embeddings=True makes every
# vector unit-length.
huggingface_embeddings = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-small-en-v1.5",
    model_kwargs={'device': 'cpu'},
    encode_kwargs={'normalize_embeddings': True},
)

# In-memory FAISS index over the chunks, exposed as a retriever with default
# search settings.
vector = FAISS.from_documents(documents, huggingface_embeddings)
retriever = vector.as_retriever()
# Answer generator: a seq2seq checkpoint (judging by its name, a BART model
# tuned on SQuAD v2 — confirm on its model card) run locally through a
# transformers "text2text-generation" pipeline.
model_name = "aware-ai/bart-squadv2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Decoding settings forwarded to the pipeline. do_sample=True means the same
# question can yield different answers on different calls.
_generation_settings = dict(
    do_sample=True,
    temperature=0.9,
    top_p=0.9,
    repetition_penalty=1.15,
    max_length=300,
)
pipe = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    **_generation_settings,
)

# LangChain wrapper so the pipeline can serve as the chain's LLM.
local_llm = HuggingFacePipeline(pipeline=pipe)

# RetrievalQA chain using LangChain's default prompt. The chat handler in this
# file does not reference this module-level chain; it builds its own chain
# with a custom prompt.
qa_chain = RetrievalQA.from_llm(llm=local_llm, retriever=retriever)
# Prompt for the answer-generating LLM: retrieved chunks fill {context} and the
# user's message fills {question}.
_QA_TEMPLATE = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Use three sentences maximum and keep the answer as concise as possible.
{context}
Question: {question}
Helpful Answer:"""

# Final label of _QA_TEMPLATE; must stay in sync with the template's last line.
# Used to cut an echoed prompt off the model output.
_ANSWER_MARKER = "Helpful Answer:"


@functools.lru_cache(maxsize=None)
def _get_qa_chain():
    """Build the RetrievalQA chain with the custom prompt, once.

    The chain depends only on module-level objects (``local_llm`` and
    ``retriever``), so it is built on the first chat message and the cached
    instance is reused for every later message.
    """
    prompt = PromptTemplate(
        input_variables=["context", "question"],
        template=_QA_TEMPLATE,
    )
    return RetrievalQA.from_chain_type(
        local_llm,
        retriever=retriever,
        chain_type_kwargs={"prompt": prompt},
    )


def _extract_answer(text):
    """Return the answer portion of the raw LLM output, whitespace-stripped.

    If the output echoes the prompt, only the text after the last
    ``_ANSWER_MARKER`` is kept. Otherwise the whole output is the answer.
    Colons inside the answer itself are preserved.
    """
    _, marker, tail = text.rpartition(_ANSWER_MARKER)
    return (tail if marker else text).strip()


def gradinterface(query, history):
    """Answer one chat message from the indexed documents.

    Used as the ``fn`` of ``gr.ChatInterface``.

    Args:
        query: The user's latest message.
        history: Earlier chat turns passed in by Gradio. Unused: every
            question is answered independently of the conversation so far.

    Returns:
        The answer text to display in the chat window.
    """
    # NOTE(review): Chain.__call__ is deprecated in newer LangChain releases
    # in favour of .invoke(); kept to match the LangChain APIs this file
    # already uses.
    result = _get_qa_chain()({"query": query})
    return _extract_answer(result["result"])
# Chat UI: each submitted message is routed to gradinterface().
demo = gr.ChatInterface(fn=gradinterface, title='OUR_OWN_BOT')


def main():
    """Start the Gradio server; share=True asks Gradio for a public share link."""
    demo.launch(share=True)


if __name__ == "__main__":
    main()