# aitGPT — Streamlit QA app (Hugging Face Spaces page-header residue removed)
import streamlit as st | |
import pickle | |
import os | |
import torch | |
from tqdm.auto import tqdm | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
# from langchain.vectorstores import Chroma | |
from langchain.vectorstores import FAISS | |
from langchain.embeddings import HuggingFaceInstructEmbeddings | |
from langchain import HuggingFacePipeline | |
from langchain.chains import RetrievalQA | |
# Basic page chrome for the Streamlit app.
st.set_page_config(page_title='aitGPT', page_icon='✅')
st.markdown("# Hello")
def load_scraped_web_info():
    """Load the pickled AIT web documents, chunk them, and report counts.

    Reads the pre-scraped documents from the local ``ait-web-document``
    pickle file, splits them into overlapping chunks, and writes the
    document/chunk counts to the Streamlit page.

    Returns:
        list: the chunked LangChain ``Document`` objects (the original
        computed this result but discarded it).
    """
    with open("ait-web-document", "rb") as fp:
        # NOTE(review): pickle.load is only safe because this artifact is a
        # trusted local file — never load pickles from untrusted sources.
        ait_web_documents = pickle.load(fp)

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,      # small chunks keep retrieval granular
        chunk_overlap=100,   # overlap preserves context across chunk edges
        length_function=len,
    )
    # tqdm shows console progress while the documents are materialized;
    # the original wrapped this in a no-op [doc for doc in ...] comprehension.
    chunked_text = text_splitter.create_documents(list(tqdm(ait_web_documents)))

    st.markdown(f"Number of Documents: {len(ait_web_documents)}")
    st.markdown(f"Number of chunked texts: {len(chunked_text)}")
    return chunked_text
def load_embedding_model():
    """Instantiate the instructor-base embedding model, on GPU when available."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return HuggingFaceInstructEmbeddings(
        model_name='hkunlp/instructor-base',
        model_kwargs={'device': device},
    )
def load_faiss_index(embedding=None):
    """Load the persisted FAISS index from the local ``faiss_index`` directory.

    Args:
        embedding: embedding model used to deserialize the index. Defaults to
            the module-level ``embedding_model``, preserving the original
            implicit-global behavior for existing callers.

    Returns:
        The loaded FAISS vector store.
    """
    if embedding is None:
        # Backward-compatible fallback to the global set during app start-up.
        embedding = embedding_model
    return FAISS.load_local("faiss_index", embedding)
def load_llm_model():
    """Build the fastchat-t5 text2text-generation pipeline used as the LLM.

    Runs in float32 (CPU-friendly); ``temperature`` 0 makes generation
    deterministic and ``repetition_penalty`` discourages looping output.

    Returns:
        A ``HuggingFacePipeline`` wrapping lmsys/fastchat-t5-3b-v1.0.
    """
    # A previous 8-bit / device_map="auto" variant was left here commented
    # out; removed as dead code.
    llm = HuggingFacePipeline.from_model_id(
        model_id='lmsys/fastchat-t5-3b-v1.0',
        task='text2text-generation',
        model_kwargs={
            "max_length": 256,
            "temperature": 0,
            "torch_dtype": torch.float32,
            "repetition_penalty": 1.3,
        },
    )
    return llm
def load_retriever(llm, db):
    """Wrap *db* as a retriever behind a "stuff"-type RetrievalQA chain."""
    retriever = db.as_retriever()
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
    )
    return qa_chain
# -------------- application start-up --------------
# Build every heavyweight component once at module import time, then
# render the question input.
load_scraped_web_info()
embedding_model = load_embedding_model()
vector_database = load_faiss_index()
llm_model = load_llm_model()
qa_retriever = load_retriever(llm=llm_model, db=vector_database)
print("all load done")

query_input = st.text_input(label='your question')
def retrieve_document(query_input):
    """Return the vector-store documents most similar to *query_input*."""
    return vector_database.similarity_search(query_input)
def retrieve_answer(query_input):
    """Run the RetrievalQA chain on *query_input* and return its answer."""
    return qa_retriever.run(query_input)
# Only hit the vector store / LLM once the user has typed something:
# the original ran both retrieval calls on every Streamlit rerun, even
# with an empty query, wasting a model invocation each time.
if query_input:
    related_docs = retrieve_document(query_input)
    answer = retrieve_answer(query_input)
else:
    related_docs = ""
    answer = ""

output_1 = st.text_area(label="Here is the relevant documents",
                        value=related_docs)
output_2 = st.text_area(label="Here is the answer",
                        value=answer)