|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gradio as gr |
|
import os |
|
from langchain.prompts import ChatPromptTemplate |
|
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint |
|
from langchain.schema.runnable import RunnablePassthrough |
|
from langchain_community.document_loaders import TextLoader |
|
from langchain_text_splitters import CharacterTextSplitter |
|
from langchain_community.vectorstores import FAISS |
|
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings |
|
|
|
# Hugging Face Inference API token taken from the environment (None if unset).
API_TOKEN = os.getenv("HUGGINGFACE_API_KEY")
HF_API_KEY = API_TOKEN

# Common prefix of the hosted Inference API model endpoints used below.
_HF_MODELS_BASE_URL = "https://api-inference.huggingface.co/models/"

# Model display name (as selected in the UI) -> Inference API endpoint URL.
llm_urls = {
    "Mistral 7B": _HF_MODELS_BASE_URL + "mistralai/Mistral-7B-Instruct-v0.2",
    "Llama 8B": _HF_MODELS_BASE_URL + "meta-llama/Meta-Llama-3-8B-Instruct",
}
|
|
|
def initialize_vector_store_retriever(file, chunk_size=1000, chunk_overlap=0):
    """Index a text file into an in-memory FAISS store and return a retriever.

    The file is loaded with ``TextLoader`` and split with
    ``CharacterTextSplitter``. The chunks are embedded through the Hugging
    Face Inference API with the ``BAAI/bge-base-en-v1.5`` model and stored
    in FAISS.

    Args:
        file: Path of the uploaded text file. An object exposing its path as
            a ``.name`` attribute (e.g. a temp-file wrapper) is also accepted.
        chunk_size: Maximum chunk size passed to ``CharacterTextSplitter``.
        chunk_overlap: Overlap between consecutive chunks.

    Returns:
        The retriever produced by ``FAISS.as_retriever()``.

    Raises:
        ValueError: If the file yields no text chunks to index.
    """
    # Depending on Gradio version/config the upload arrives either as a path
    # or as a temp-file wrapper; TextLoader needs the path itself.
    if isinstance(file, (str, os.PathLike)):
        path = file
    else:
        path = getattr(file, "name", file)

    raw_documents = TextLoader(path).load()
    text_splitter = CharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    documents = text_splitter.split_documents(raw_documents)
    if not documents:
        # FAISS.from_documents would fail on an empty document list.
        raise ValueError("The uploaded file contains no text to index.")

    # `model_name` is the documented way to select the embedding model.
    # The previous `endpoint_url=` keyword is not a declared field of this
    # class and appears to be ignored, so the default model was used instead.
    embeddings = HuggingFaceInferenceAPIEmbeddings(
        api_key=HF_API_KEY,
        model_name="BAAI/bge-base-en-v1.5",
    )
    db = FAISS.from_documents(documents, embeddings)
    return db.as_retriever()
|
|
|
def generate_llm_rag_prompt() -> ChatPromptTemplate:
    """Return the chat prompt template used for retrieval-augmented queries.

    The template has three input variables:
    ``system`` (the system instruction), ``context`` (retrieved text) and
    ``prompt`` (the user's question). They are wrapped in ``[INST]`` /
    ``<<SYS>>`` instruction markers.
    """
    rag_template_text = "<s>[INST] <<SYS>>{system}<</SYS>>{context} {prompt} [/INST]"
    return ChatPromptTemplate.from_template(rag_template_text)
|
|
|
|
|
|
|
def _format_docs(docs):
    """Join the page contents of retrieved documents into one context string."""
    return "\n\n".join(doc.page_content for doc in docs)


def create_chain(retriever, llm):
    """Build the Hugging Face endpoint for *llm* and, when possible, a RAG chain.

    Args:
        retriever: Retriever over the uploaded documents, or ``None`` when no
            vector store has been loaded yet.
        llm: Display name of the model; must be a key of ``llm_urls``.

    Returns:
        A ``(chain, model_endpoint)`` pair. ``chain`` is ``None`` when
        *retriever* is ``None``. Otherwise it maps a question string through
        retrieval, the RAG prompt and the model endpoint.

    Raises:
        KeyError: If *llm* is not a key of ``llm_urls``.
    """
    model_endpoint = HuggingFaceEndpoint(
        endpoint_url=llm_urls[llm],
        huggingfacehub_api_token=HF_API_KEY,
        task="text2text-generation",
        max_new_tokens=200,
    )

    if retriever is None:
        return None, model_endpoint

    system_message = (
        "You are a helpful and honest assistant. "
        "Please, respond concisely and truthfully."
    )
    retrieval = {
        # Render the retrieved Documents as plain text. Without this, the
        # prompt receives the string form of the Document list, metadata
        # included.
        "context": retriever | _format_docs,
        "prompt": RunnablePassthrough(),
        "system": lambda _question: system_message,
    }
    chain = retrieval | generate_llm_rag_prompt() | model_endpoint
    return chain, model_endpoint
|
|
|
|
|
def query(question_text, llm, session_data):
    """Answer *question_text* with the selected model, both without and with RAG.

    Args:
        question_text: The user's question from the textbox.
        llm: Display name of the selected model (a key of ``llm_urls``).
        session_data: Gradio session state. When non-empty, its first element
            is the retriever stored by ``upload_file``.

    Returns:
        A ``(without_rag_text, rag_text)`` pair of strings for the two
        result boxes.
    """
    # Blank or whitespace-only input is not worth a round-trip to the model.
    if not question_text or not question_text.strip():
        without_rag_text = "Query result without RAG is not available. Enter a question first."
        rag_text = "Query result with RAG is not available. Enter a question first."
        return without_rag_text, rag_text

    retriever = session_data[0] if session_data else None
    chain, model_endpoint = create_chain(retriever, llm)

    # `.invoke` replaces the deprecated `llm(prompt)` call; both return a str.
    without_rag_text = "Query result without RAG:\n\n" + model_endpoint.invoke(question_text).strip()

    if chain is None:
        rag_text = "Query result With RAG is not available. Load Vector Store first."
    else:
        answer = chain.invoke(question_text).strip()
        # Keep only the text before the first "[/INST]" marker, in case the
        # model keeps generating past its answer. The earlier split on
        # "[/SYS]>[/INST]" never matched the prompt template and had no effect.
        rag_text = "Query result With RAG:\n\n" + answer.split("[/INST]")[0]
    return without_rag_text, rag_text
|
|
|
def upload_file(file, session_data):
    """Gradio upload handler: index *file* and replace the session state.

    Args:
        file: The value delivered by the ``UploadButton``.
        session_data: Previous session state. It is not read, because a new
            state replaces it.

    Returns:
        A visible ``gr.File`` for the upload, and the new session state
        ``[retriever]``.
    """
    new_state = [initialize_vector_store_retriever(file)]
    shown_file = gr.File(value=file, visible=True)
    return shown_file, new_state
|
|
|
def initialize_vector_store(session_data, file=None):
    """Build a retriever for *file* and return it as fresh session state.

    The previous version called ``initialize_vector_store_retriever()`` without
    its required file argument and therefore always failed with ``TypeError``.

    Args:
        session_data: Previous session state. It is not read, because a new
            state replaces it.
        file: Path (or file object with a ``.name``) of the text file to index.

    Returns:
        The new session state ``[retriever]``.

    Raises:
        ValueError: If no *file* is given.
    """
    if file is None:
        raise ValueError("initialize_vector_store() needs a text file to index; none was given.")
    return [initialize_vector_store_retriever(file)]
|
|
|
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Retrieval Augmented Generation</h1>""")
    # Per-session state: empty until a file is uploaded, then [retriever].
    session_data = gr.State([])

    # Hidden until upload_file returns a visible gr.File for the upload.
    file_output = gr.File(visible=False)
    upload_button = gr.UploadButton(
        "Click to Upload a text File to Vector Store",
        file_types=["text"],
        file_count="single",
    )
    upload_button.upload(upload_file, [upload_button, session_data], [file_output, session_data])

    with gr.Row():
        with gr.Column(scale=4):
            question_text = gr.Textbox(show_label=False, placeholder="Ask a question", lines=2)
        with gr.Column(scale=1):
            # Choices come from llm_urls so the radio can never offer a model
            # that create_chain cannot look up.
            llm_Choice = gr.Radio(
                sorted(llm_urls),
                value="Mistral 7B",
                label="Select language model:",
                info="",
            )
            query_Button = gr.Button("Query")

    with gr.Row():
        with gr.Column(scale=1):
            without_rag_text = gr.Textbox(show_label=False, placeholder="Query result without using RAG", lines=15)
        with gr.Column(scale=1):
            rag_text = gr.Textbox(show_label=False, placeholder="Query result with RAG", lines=15)

    query_Button.click(
        query,
        [question_text, llm_Choice, session_data],
        [without_rag_text, rag_text],
    )

demo.queue().launch(share=False, inbrowser=True)