pvyas96's picture
Upload 2 files
7e55c3b verified
raw
history blame
No virus
2.89 kB
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, pipeline
from langchain.llms import HuggingFaceHub, HuggingFacePipeline
from dotenv import load_dotenv
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
import textwrap
import torch
import os
import streamlit as st
# Torch device used by the embedding model: the GPU when torch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def load_vector_store(persist_directory='vector stores/textdb', k=10):
    """Open the persisted Chroma vector store and return a retriever over it.

    Query embeddings are produced by the "BAAI/bge-small-en" BGE model on the
    module-level ``device``, with normalized vectors.
    NOTE(review): presumably this matches how the store was built; confirm
    against the indexing script.

    Args:
        persist_directory: Directory holding the persisted Chroma data. The
            default is the path this app has always used.
        k: Number of chunks the retriever returns per query.

    Returns:
        A LangChain retriever backed by the Chroma store.

    Raises:
        FileNotFoundError: If ``persist_directory`` does not exist. Without this
            check Chroma would start from an empty collection, and every answer
            would be generated with no retrieved context.
    """
    # Fail fast, before the embedding model is downloaded/loaded.
    if not os.path.isdir(persist_directory):
        raise FileNotFoundError(
            f"Chroma persist directory not found: {persist_directory!r}"
        )

    embeddings = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-small-en",
        model_kwargs={"device": device},
        encode_kwargs={"normalize_embeddings": True},
    )
    print('Embeddings loaded!')

    # Named `vector_store` so it no longer shadows this function's own name.
    vector_store = Chroma(
        persist_directory=persist_directory,
        embedding_function=embeddings,
    )
    print('Vector store loaded!')

    return vector_store.as_retriever(search_kwargs={"k": k})
# Language model: a remote Hugging Face Hub inference endpoint
def load_model(repo_id='llmware/dragon-mistral-7b-v0', max_new_tokens=100,
               smoke_test=False):
    """Create the Hugging Face Hub LLM that answers questions.

    Generation runs remotely through LangChain's ``HuggingFaceHub`` wrapper.
    NOTE(review): the API token is presumably picked up from the
    HUGGINGFACEHUB_API_TOKEN environment variable, which the app's entry point
    sets; confirm for the installed LangChain version.

    Args:
        repo_id: Hugging Face model repository to query.
        max_new_tokens: Upper bound on tokens generated per answer.
        smoke_test: If True, send a one-off "HI!" prompt and print the reply to
            check connectivity. This used to run unconditionally, which cost one
            extra remote inference call (and one extra failure point) every time
            the model was loaded.

    Returns:
        The configured ``HuggingFaceHub`` LLM.
    """
    llm = HuggingFaceHub(
        repo_id=repo_id,
        model_kwargs={'max_new_tokens': max_new_tokens},
    )
    if smoke_test:
        print(llm('HI!'))
    return llm
def qa_chain():
    """Assemble the retrieval-augmented question-answering chain.

    Connects the Chroma retriever from ``load_vector_store()`` to the LLM from
    ``load_model()`` using LangChain's "stuff" chain type, which places all
    retrieved chunks into a single prompt. The chain returns the source
    documents alongside the answer and runs with verbose logging enabled.
    """
    # Order matters for the side effects: the vector store is loaded first,
    # then the LLM.
    doc_retriever = load_vector_store()
    language_model = load_model()

    options = {
        'chain_type': 'stuff',
        'return_source_documents': True,
        'verbose': True,
    }
    return RetrievalQA.from_chain_type(
        llm=language_model,
        retriever=doc_retriever,
        **options,
    )
def wrap_text_preserve_newlines(text, width=110):
    """Word-wrap ``text`` to ``width`` columns without merging its existing lines.

    Each newline-separated line is filled on its own with ``textwrap.fill``, so
    blank lines and paragraph breaks in the input survive. The wrapped pieces
    are joined back together with newline characters.
    """
    return "\n".join(
        textwrap.fill(segment, width=width) for segment in text.split("\n")
    )
def process_llm_response(llm_response):
    """Print a QA chain answer followed by the sources it was built from.

    Args:
        llm_response: The dict returned by the chain from ``qa_chain()``. It
            must contain "result" (answer text) and "source_documents"
            (documents exposing a ``metadata`` dict).
    """
    print(wrap_text_preserve_newlines(llm_response['result']))
    print('\n\nSources:')
    for source in llm_response["source_documents"]:
        # Not every document carries a 'source' metadata key; print a
        # placeholder instead of aborting the whole report with KeyError.
        print(source.metadata.get('source', '<unknown source>'))
def _get_qa_chain():
    """Return this browser session's QA chain, building it on first use.

    Streamlit re-executes the whole script on every widget interaction.
    Building the chain directly in ``main()`` therefore reloaded the embedding
    model, the Chroma store and the LLM on every click or keystroke. The chain
    is now kept in ``st.session_state`` (per session). It is rebuilt when the
    HUGGINGFACEHUB_API_TOKEN environment value changes, because the LLM
    presumably captures the token when it is constructed.
    """
    token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    cached = "_qa_chain" in st.session_state
    if not cached or st.session_state.get("_qa_token") != token:
        st.session_state["_qa_chain"] = qa_chain()
        st.session_state["_qa_token"] = token
    return st.session_state["_qa_chain"]


def main():
    """Render the DOCUMENT-GPT page and answer the user's question on demand.

    The page is drawn first: title, question box, "Run RAG" button and
    "Response" header. The QA chain is fetched or built only after the button is
    pressed with a non-blank question. The question is wrapped in
    "<human>:" / "<bot>:" markers before being sent to the chain, and only the
    chain's "result" field is displayed.
    """
    st.title('DOCUMENT-GPT')
    text_query = st.text_area('Ask any question from your documents!')
    generate_response_btn = st.button('Run RAG')
    st.subheader('Response')

    if not generate_response_btn:
        return
    # st.text_area yields an empty string (not None) when nothing is typed,
    # so check for blank text rather than None.
    if not text_query or not text_query.strip():
        st.warning('Please enter a question first.')
        return

    with st.spinner('Generating Response. Please wait...'):
        try:
            qa = _get_qa_chain()
            text_response = qa(f"<human>:{text_query}\n<bot>:")
        except Exception as err:  # UI boundary: report instead of crashing the page
            st.error(f'Failed to get response: {err}')
            return

    if text_response:
        st.write(text_response["result"])
    else:
        st.error('Failed to get response')
if __name__ == "__main__":
    # type="password" masks the key in the browser instead of echoing it in
    # plain text; strip() drops stray whitespace picked up when pasting.
    hf_token = st.text_input("Paste Huggingface read api key", type="password").strip()
    if hf_token:
        # NOTE(review): os.environ is process-wide, so with several concurrent
        # Streamlit sessions in one server process the most recently entered
        # token applies to everyone. Passing the token directly to the LLM
        # constructor would avoid this.
        os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
        main()
    else:
        st.info("Enter a Hugging Face read API key to start asking questions.")