from langchain_community.document_loaders import TextLoader, PyPDFLoader  # PyPDFLoader: only needed if the source is a PDF
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr
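
# The imports above assume these packages are installed, e.g.:
#   pip install langchain langchain-community transformers sentence-transformers faiss-cpu gradio
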
# Load the source document and split it into chunks for indexing.
loader = TextLoader("info.txt")
docs = loader.load()

# No arguments means the splitter's defaults are used
# (4,000-character chunks with a 200-character overlap).
text_splitter = RecursiveCharacterTextSplitter()
documents = text_splitter.split_documents(docs)
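
# NOTE: 4,000-character chunks are fairly large for a small QA model; a tighter
# setting such as RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
# is a common adjustment, though the best values depend on the document.
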
# Embed each chunk with a small BGE model on CPU and index the vectors in FAISS.
huggingface_embeddings = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-small-en-v1.5",
    model_kwargs={'device': 'cpu'},
    encode_kwargs={'normalize_embeddings': True},  # unit vectors -> cosine similarity
)

vector = FAISS.from_documents(documents, huggingface_embeddings)
retriever = vector.as_retriever()
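
# NOTE: as_retriever() typically returns the 4 most similar chunks per query;
# if answers drag in irrelevant context, a tighter setting such as
# vector.as_retriever(search_kwargs={"k": 2}) may help.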

# Load a BART model fine-tuned on SQuAD v2 and wrap it in a transformers
# pipeline that LangChain can drive as an LLM.
model_name = "aware-ai/bart-squadv2"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

pipe = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=300,           # hard cap on generated length
    temperature=0.9,          # sampling settings; active because do_sample=True
    top_p=0.9,
    repetition_penalty=1.15,
    do_sample=True,
)

local_llm = HuggingFacePipeline(pipeline=pipe)
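
# Optional sanity check that the model downloaded and generates text (output
# quality depends on the input format this particular QA model expects):
# print(pipe("What is retrieval-augmented generation?")[0]["generated_text"])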

# Prompt that keeps the model grounded in the retrieved context and concise.
template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Use three sentences maximum and keep the answer as concise as possible.
{context}
Question: {question}

Helpful Answer:"""

QA_CHAIN_PROMPT = PromptTemplate(
    input_variables=["context", "question"],
    template=template,
)

# Build the retrieval-QA chain once at startup rather than on every request.
qa_chain = RetrievalQA.from_chain_type(
    local_llm,
    retriever=retriever,
    chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
)


def gradinterface(query, history):
    # Short-circuit the conversational "exit" command instead of querying the chain.
    if query == "exit":
        return "Goodbye!"
    result = qa_chain({"query": query})
    # The model may echo the "Helpful Answer:" label; keep only the text after it.
    return result['result'].split(': ')[-1].strip()
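
# Optional end-to-end smoke test before launching the UI; the sample question
# is a placeholder and assumes "info.txt" actually covers it:
# print(gradinterface("What is this document about?", history=[]))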

demo = gr.ChatInterface(fn=gradinterface, title='OUR_OWN_BOT')

if __name__ == "__main__":
    demo.launch(share=True)