# Streamlit RAG app: retrieves context from a persisted Chroma vector store
# (BGE embeddings) and answers questions with dragon-mistral-7b served
# through the Hugging Face Hub inference API.
from langchain.llms import HuggingFaceHub
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
import textwrap
import torch
import os
import streamlit as st

# Use the GPU for the embedding model when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def load_vector_store():
    # BGE embeddings; normalized vectors let similarity search behave like cosine similarity.
    model_name = "BAAI/bge-small-en"
    model_kwargs = {"device": device}
    encode_kwargs = {"normalize_embeddings": True}
    embeddings = HuggingFaceBgeEmbeddings(
        model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs
    )
    print('Embeddings loaded!')

    # Open the already-persisted Chroma index and expose it as a retriever.
    vector_store = Chroma(persist_directory='vector stores/textdb', embedding_function=embeddings)
    print('Vector store loaded!')

    retriever = vector_store.as_retriever(
        search_kwargs={"k": 10},
    )
    return retriever
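
# The Chroma call above assumes an index was already built and persisted to
# 'vector stores/textdb'. A minimal ingestion sketch (the loader, splitter,
# chunk sizes and the 'documents/' source folder are assumptions, not part
# of this app):
#
#   from langchain.document_loaders import PyPDFDirectoryLoader
#   from langchain.text_splitter import RecursiveCharacterTextSplitter
#
#   docs = PyPDFDirectoryLoader('documents/').load()
#   splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
#   chunks = splitter.split_documents(docs)
#   db = Chroma.from_documents(chunks, embedding=embeddings,  # same BGE embeddings as above
#                              persist_directory='vector stores/textdb')
#   db.persist()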

    
# LLM: dragon-mistral-7b served remotely through the Hugging Face Hub
# inference API, so the weights never have to be downloaded locally.
def load_model():
    repo_id = 'llmware/dragon-mistral-7b-v0'
    llm = HuggingFaceHub(
        repo_id=repo_id,
        model_kwargs={'max_new_tokens': 100}
    )
    # Quick sanity check that the endpoint responds before wiring up the chain.
    print(llm('HI!'))
    return llm
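
# A local alternative to the hosted endpoint, for reference only (a sketch:
# it assumes the transformers/accelerate stack is installed and there is
# enough memory for the 7B weights; nothing below is used by the app):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
#   from langchain.llms import HuggingFacePipeline
#
#   def load_local_model():
#       repo_id = 'llmware/dragon-mistral-7b-v0'
#       tokenizer = AutoTokenizer.from_pretrained(repo_id)
#       model = AutoModelForCausalLM.from_pretrained(repo_id, device_map='auto')
#       pipe = pipeline('text-generation', model=model, tokenizer=tokenizer,
#                       max_new_tokens=100)
#       return HuggingFacePipeline(pipeline=pipe)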


# Cache the chain so the embeddings, vector store and LLM are built once,
# not on every Streamlit rerun.
@st.cache_resource
def qa_chain():
    retriever = load_vector_store()
    llm = load_model()
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=retriever,
        return_source_documents=True,
        verbose=True
    )
    return qa

# Console helpers: pretty-print a RetrievalQA result and its source documents
# (not wired into the Streamlit UI below).
def wrap_text_preserve_newlines(text, width=110):
    # Split the input text into lines based on newline characters
    lines = text.split('\n')

    # Wrap each line individually
    wrapped_lines = [textwrap.fill(line, width=width) for line in lines]

    # Join the wrapped lines back together using newline characters
    wrapped_text = '\n'.join(wrapped_lines)

    return wrapped_text

def process_llm_response(llm_response):
    print(wrap_text_preserve_newlines(llm_response['result']))
    print('\n\nSources:')
    for source in llm_response["source_documents"]:
        print(source.metadata['source'])
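
# Example console usage of the helpers above (hypothetical; the Streamlit UI
# itself only reads qa(...)["result"]):
#
#   qa = qa_chain()
#   response = qa("<human>: What is covered in chapter 2?\n<bot>:")
#   process_llm_response(response)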

def main():
    qa = qa_chain()
    st.title('DOCUMENT-GPT')
    text_query = st.text_area('Ask any question from your documents!')
    generate_response_btn = st.button('Run RAG')

    st.subheader('Response')
    # st.text_area returns an empty string (not None) when nothing is typed,
    # so check for actual content before running the chain.
    if generate_response_btn and text_query.strip():
        with st.spinner('Generating Response. Please wait...'):
            # dragon models use the "<human>: ... <bot>:" prompt format.
            text_response = qa("<human>:" + text_query + "\n" + "<bot>:")
            if text_response:
                st.write(text_response["result"])
            else:
                st.error('Failed to get response')

if __name__ == "__main__":
    # The HuggingFaceHub LLM reads the token from this environment variable.
    hf_token = st.text_input("Paste your Hugging Face read API key", type="password")
    if hf_token:
        os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
        main()
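
# To run (assuming this file is saved as app.py and the Chroma index already
# exists at 'vector stores/textdb'):
#
#   pip install streamlit langchain chromadb sentence-transformers huggingface_hub torch
#   streamlit run app.py
#
# Paste a Hugging Face read token into the input box; it is only kept in the
# HUGGINGFACEHUB_API_TOKEN environment variable for the current session.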