# Retrieval Augmented Generation demo (Gradio + LangChain)
#
# 1. Vector store: FAISS, via the LangChain vector store interface
#    https://python.langchain.com/v0.1/docs/modules/data_connection/vectorstores/
# 2. Embeddings: HuggingFaceInferenceAPIEmbeddings with "BAAI/bge-base-en-v1.5"
# 3. LLMs: Mistral and Llama, queried through the HuggingFace Inference API
#    "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"
#    "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"

import gradio as gr
import os
from langchain.prompts import ChatPromptTemplate
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain.schema.runnable import RunnablePassthrough
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings

HF_API_KEY = os.environ.get('HUGGINGFACE_API_KEY')

llm_urls = { 
    "Mistral 7B": "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2", 
    "Llama 8B": "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"
}

def initialize_vector_store_retriever(file):
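    """Return a FAISS retriever over the given text file, using
    BAAI/bge-base-en-v1.5 embeddings served by the HuggingFace Inference API."""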
    # Load the document, split it into chunks, embed each chunk and load it into the vector store.
    #raw_documents = TextLoader('./llm.txt').load()
    raw_documents = TextLoader(file).load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    documents = text_splitter.split_documents(raw_documents)

    # Embeddings are computed remotely by the HF Inference API:
    # https://api-inference.huggingface.co/models/BAAI/bge-base-en-v1.5
    embeddings = HuggingFaceInferenceAPIEmbeddings(
        api_key=HF_API_KEY,
        model_name="BAAI/bge-base-en-v1.5",
    )
    db = FAISS.from_documents(documents, embeddings)
    retriever = db.as_retriever()
    return retriever

def generate_llm_rag_prompt() -> ChatPromptTemplate:
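    """Build the instruct-style prompt: a single [INST] block whose <<SYS>>
    section carries the system message, followed by the retrieved context and
    the user prompt."""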
    #template = "<s>[INST] {context} {prompt} [/INST]"
    template = "<s>[INST] <<SYS>>{system}<</SYS>>{context} {prompt} [/INST]"

    prompt_template = ChatPromptTemplate.from_template(template)
    return prompt_template


    
def create_chain(retriever, llm):
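    """Wire up the LCEL chain: {context, prompt, system} -> prompt template -> LLM endpoint.

    Returns (chain, model_endpoint); chain is None when no retriever is available,
    so callers can still query the bare endpoint without RAG.
    """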

    url = llm_urls[llm]
    model_endpoint = HuggingFaceEndpoint(
        endpoint_url=url,
        huggingfacehub_api_token=HF_API_KEY,
        task="text2text-generation",
        max_new_tokens=200
    )

    if retriever is not None:
        def get_system(input):
            return "You are a helpful and honest assistant. Please respond concisely and truthfully."

        # LCEL composition: the retriever supplies {context}, the raw question passes
        # through as {prompt}, and get_system provides the {system} instruction.
        retrieval = {"context": retriever, "prompt": RunnablePassthrough(), "system": get_system}
        chain = retrieval | generate_llm_rag_prompt() | model_endpoint
        return chain, model_endpoint
    else:
        return None, model_endpoint


def query(question_text, llm, session_data):
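    """Answer the question with the bare LLM and, when a retriever is stored in
    session_data, with the RAG chain as well; returns both result strings."""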
    if question_text == "":
        without_rag_text = "Query result without RAG is not available. Enter a question first."
        rag_text = "Query result with RAG is not available. Enter a question first."
        return without_rag_text, rag_text
    
    if len(session_data) > 0:
        retriever = session_data[0]
    else:
        retriever = None
    chain, model_endpoint = create_chain(retriever, llm)
    without_rag_text = "Query result without RAG:\n\n" + model_endpoint.invoke(question_text).strip()
    if retriever is None:
        rag_text = "Query result with RAG is not available. Load the Vector Store first."
    else:
        ans = chain.invoke(question_text).strip()
        # Some endpoints echo the formatted prompt back; when that happens, keep
        # only the text after the final [/INST] tag.
        if "[/INST]" in ans:
            ans = ans.split("[/INST]")[-1].strip()
        rag_text = "Query result with RAG:\n\n" + ans
    return without_rag_text, rag_text

def upload_file(file, session_data): 
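    """Build a retriever from the uploaded text file and keep it in the Gradio session state."""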
    #file_paths = [file.name for file in files]
    #file = files[0]
    session_data = [initialize_vector_store_retriever(file)]
    return gr.File(value=file, visible=True), session_data

def initialize_vector_store(session_data):
    # Assumes a local ./llm.txt sample document (see the commented-out loader above).
    session_data = [initialize_vector_store_retriever('./llm.txt')]
    return session_data
    
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Retrieval Augmented Generation</h1>""")
    session_data = gr.State([])

    file_output = gr.File(visible=False)
    upload_button = gr.UploadButton("Click to upload a text file to the Vector Store", file_types=["text"], file_count="single")
    upload_button.upload(upload_file, [upload_button, session_data], [file_output, session_data])

    #initialize_VS_button = gr.Button("Load text file to Vector Store")
    with gr.Row():
        with gr.Column(scale=4):    
            question_text = gr.Textbox(show_label=False, placeholder="Ask a question", lines=2)    
        with gr.Column(scale=1):    
            llm_Choice = gr.Radio(["Llama 8B", "Mistral 7B"], value="Mistral 7B", label="Select language model:", info="")
    query_Button = gr.Button("Query")

    with gr.Row():
        with gr.Column(scale=1):
            without_rag_text = gr.Textbox(show_label=False, placeholder="Query result without using RAG", lines=15)
        with gr.Column(scale=1):
            rag_text = gr.Textbox(show_label=False, placeholder="Query result with RAG", lines=15)

    #initialize_VS_button.click(
    #    initialize_vector_store,
    #    [session_data],
    #    [session_data],
    #    #show_progress=True,
    #)     
    query_Button.click(
        query,
        [question_text, llm_Choice, session_data],
        [without_rag_text, rag_text],
        #show_progress=True,
    )

demo.queue().launch(share=False, inbrowser=True)
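
# Running locally (a minimal sketch, assuming this script is saved as app.py):
#   export HUGGINGFACE_API_KEY=hf_...   # token for the HuggingFace Inference API
#   python app.py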