# Author: wb-droid
# Commit 48448ec: "Minor adjustment."
# 1. Vector store: LangChain vector store backed by FAISS
#    https://python.langchain.com/v0.1/docs/modules/data_connection/vectorstores/
# 2. Embeddings: HuggingFaceInferenceAPIEmbeddings with "BAAI/bge-base-en-v1.5"
# 3. LLMs: Mistral and Llama, served by the Hugging Face Inference API:
#    "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"
#    "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"
import gradio as gr
import os
from langchain.prompts import ChatPromptTemplate
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain.schema.runnable import RunnablePassthrough
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
# Hugging Face Inference API token taken from the environment (None if unset).
API_TOKEN = os.getenv("HUGGINGFACE_API_KEY")
HF_API_KEY = API_TOKEN

# UI display name -> Inference API endpoint of each selectable instruct model.
llm_urls = dict(
    [
        ("Mistral 7B", "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"),
        ("Llama 8B", "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"),
    ]
)
def initialize_vector_store_retriever(
    file,
    chunk_size=1000,
    chunk_overlap=0,
    embedding_url="https://api-inference.huggingface.co/models/BAAI/bge-base-en-v1.5",
):
    """Build a FAISS-backed retriever over a text file.

    The file is loaded with ``TextLoader``, split into chunks with
    ``CharacterTextSplitter``, embedded through the Hugging Face Inference
    API and indexed in a FAISS vector store.

    Args:
        file: Path of the text file to index (``str`` or ``os.PathLike``).
            Any other object is assumed to expose its path as ``.name``
            (e.g. the temp-file wrapper older Gradio versions pass for uploads).
        chunk_size: Maximum chunk size passed to ``CharacterTextSplitter``.
        chunk_overlap: Overlap between consecutive chunks.
        embedding_url: Inference API endpoint of the embedding model.

    Returns:
        The retriever returned by ``FAISS.as_retriever()``.

    Raises:
        ValueError: If the file yields no text chunks (e.g. it is empty).
    """
    # Newer Gradio passes the upload as a path string; older versions pass a
    # file wrapper, whose path must be read from ``.name``.
    path = file if isinstance(file, (str, os.PathLike)) else file.name
    raw_documents = TextLoader(path).load()

    text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    documents = text_splitter.split_documents(raw_documents)
    # Fail with a clear message instead of letting an empty list reach FAISS.
    if not documents:
        raise ValueError(f"No text found in {path!r}; cannot build a vector store from it.")

    embeddings = HuggingFaceInferenceAPIEmbeddings(
        endpoint_url=embedding_url,
        api_key=HF_API_KEY,
    )
    db = FAISS.from_documents(documents, embeddings)
    return db.as_retriever()
def generate_llm_rag_prompt() -> ChatPromptTemplate:
    """Return the prompt template used by the RAG chain.

    The template uses the ``[INST] <<SYS>> ... <</SYS>> ... [/INST]``
    instruction layout and has three input variables: ``system`` (system
    message), ``context`` (retrieved material) and ``prompt`` (the question).
    """
    return ChatPromptTemplate.from_template(
        "<s>[INST] <<SYS>>{system}<</SYS>>{context} {prompt} [/INST]"
    )
# System message injected into every RAG prompt.
_SYSTEM_PROMPT = "You are a helpful and honest assistant. Please, respond concisely and truthfully."


def _format_docs(docs):
    """Join the page contents of retrieved documents into one context string."""
    return "\n\n".join(doc.page_content for doc in docs)


def create_chain(retriever, llm):
    """Create the Inference API LLM and, when possible, a RAG chain around it.

    Args:
        retriever: Retriever produced by ``initialize_vector_store_retriever``,
            or ``None`` when no vector store has been built yet.
        llm: Display name of the model; must be a key of ``llm_urls``.

    Returns:
        ``(chain, model_endpoint)``. ``chain`` is ``None`` when ``retriever``
        is ``None``; otherwise it maps a question string to the model's answer.

    Raises:
        KeyError: If ``llm`` is not a key of ``llm_urls``.
    """
    model_endpoint = HuggingFaceEndpoint(
        endpoint_url=llm_urls[llm],
        huggingfacehub_api_token=HF_API_KEY,
        task="text2text-generation",
        max_new_tokens=200,
    )
    if retriever is None:
        return None, model_endpoint

    retrieval = {
        # Without formatting, the prompt would receive the repr of a list of
        # Document objects (metadata included) instead of the retrieved text.
        "context": retriever | _format_docs,
        "prompt": RunnablePassthrough(),
        "system": lambda _question: _SYSTEM_PROMPT,
    }
    chain = retrieval | generate_llm_rag_prompt() | model_endpoint
    return chain, model_endpoint
def _clean_rag_answer(answer):
    """Post-process raw RAG chain output into the text shown to the user.

    Keeps the original heuristic unchanged: pieces of ``answer`` that echo the
    ``[INST] <<SYS>>`` prompt header are extracted, and the result is cut at
    the first ``[/INST]`` tag.
    """
    pieces = [
        piece.split("[INST] <<SYS>>")[1]
        for piece in answer.split("[/SYS]>[/INST]")
        if "[INST] <<SYS>>" in piece
    ]
    # NOTE(review): "[/SYS]>[/INST]" does not appear in the prompt template, so
    # unless the model emits it, ``pieces`` has at most one element and the
    # whole answer is used. The [1:-1] slice also drops the first and last
    # pieces (giving "" for exactly two) -- confirm this is intended.
    body = "".join(pieces[1:-1]) if len(pieces) >= 2 else answer
    # Drop anything the model generated after a closing [/INST] tag.
    return body.split("[/INST]")[0]


def query(question_text, llm, session_data):
    """Answer a question both with and without retrieval augmentation.

    Args:
        question_text: The user's question.
        llm: Display name of the model (a key of ``llm_urls``).
        session_data: Session state; ``[retriever]`` after a file upload,
            otherwise empty.

    Returns:
        A ``(without_rag_text, rag_text)`` tuple of display strings.
    """
    # Empty, whitespace-only or missing questions never reach the model.
    if not question_text or not question_text.strip():
        without_rag_text = "Query result without RAG is not available. Enter a question first."
        rag_text = "Query result with RAG is not available. Enter a question first."
        return without_rag_text, rag_text

    retriever = session_data[0] if session_data else None
    chain, model_endpoint = create_chain(retriever, llm)

    # invoke() replaces calling the LLM directly, which LangChain deprecates.
    without_rag_text = "Query result without RAG:\n\n" + model_endpoint.invoke(question_text).strip()

    if retriever is None:
        rag_text = "Query result With RAG is not available. Load Vector Store first."
    else:
        answer = chain.invoke(question_text).strip()
        rag_text = "Query result With RAG:\n\n" + _clean_rag_answer(answer)
    return without_rag_text, rag_text
def upload_file(file, session_data):
    """Index an uploaded text file and store its retriever as session state.

    Args:
        file: Value delivered by the upload button.
        session_data: Current session state; it is replaced, not extended.

    Returns:
        A visible ``gr.File`` showing the upload and the new session state,
        a one-element list holding the retriever.
    """
    new_session = [initialize_vector_store_retriever(file)]
    uploaded_view = gr.File(value=file, visible=True)
    return uploaded_view, new_session
def initialize_vector_store(session_data, file="./llm.txt"):
    """Build a vector store from a local text file and return new session state.

    ``initialize_vector_store_retriever`` requires a file argument, so
    calling it without one always raised ``TypeError``. The path is now an
    explicit parameter, defaulting to ``./llm.txt``.

    Args:
        session_data: Current session state; it is replaced, not extended.
        file: Path of the text file to index.

    Returns:
        New session state: a one-element list holding the retriever.
    """
    return [initialize_vector_store_retriever(file)]
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Retrieval Augmented Generation</h1>""")
    # Per-session state: [] until a file is uploaded, then [retriever].
    session_data = gr.State([])

    # Uploading a text file builds the vector store and reveals the file.
    file_output = gr.File(visible=False)
    upload_button = gr.UploadButton(
        "Click to Upload a text File to Vector Store",
        file_types=["text"],
        file_count="single",
    )
    upload_button.upload(upload_file, [upload_button, session_data], [file_output, session_data])

    with gr.Row():
        with gr.Column(scale=4):
            question_text = gr.Textbox(show_label=False, placeholder="Ask a question", lines=2)
        with gr.Column(scale=1):
            # Choices come from llm_urls so the UI only offers models that
            # create_chain can resolve; sorted() gives "Llama 8B", "Mistral 7B".
            llm_Choice = gr.Radio(
                sorted(llm_urls),
                value="Mistral 7B",
                label="Select language model:",
                info="",
            )
            query_Button = gr.Button("Query")

    with gr.Row():
        with gr.Column(scale=1):
            without_rag_text = gr.Textbox(show_label=False, placeholder="Query result without using RAG", lines=15)
        with gr.Column(scale=1):
            rag_text = gr.Textbox(show_label=False, placeholder="Query result with RAG", lines=15)

    query_Button.click(
        query,
        [question_text, llm_Choice, session_data],
        [without_rag_text, rag_text],
    )

# Launch only when run as a script, so importing the module has no side effect.
if __name__ == "__main__":
    demo.queue().launch(share=False, inbrowser=True)