# 1. Vector store: LangChain FAISS
#    https://python.langchain.com/v0.1/docs/modules/data_connection/vectorstores/
# 2. Embeddings: HuggingFaceInferenceAPIEmbeddings with "BAAI/bge-base-en-v1.5"
# 3. LLMs: Mistral and Llama, served via the HuggingFace Inference API
#    "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"
#    "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"
import gradio as gr
import os
from langchain.prompts import ChatPromptTemplate
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain.schema.runnable import RunnablePassthrough
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
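
# HuggingFace API token, read from the environment (export HUGGINGFACE_API_KEY before launching).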
API_TOKEN = os.environ.get('HUGGINGFACE_API_KEY')
HF_API_KEY = API_TOKEN
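
# Inference API endpoints for the two selectable instruction-tuned models.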
llm_urls = {
    "Mistral 7B": "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2",
    "Llama 8B": "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct",
}

def initialize_vector_store_retriever(file):
    # Load the document, split it into chunks, embed each chunk and load it into the vector store.
    #raw_documents = TextLoader('./llm.txt').load()
    raw_documents = TextLoader(file).load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    documents = text_splitter.split_documents(raw_documents)
    embeddings = HuggingFaceInferenceAPIEmbeddings(
        api_key=HF_API_KEY,
        model_name="BAAI/bge-base-en-v1.5",
    )
    db = FAISS.from_documents(documents, embeddings)
    retriever = db.as_retriever()
    return retriever

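# The template follows the Llama-2/Mistral instruction format: the system message is wrapped in
# <<SYS>> ... <</SYS>> and the whole turn (system message, retrieved context, and question) is
# wrapped in [INST] ... [/INST].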
def generate_llm_rag_prompt() -> ChatPromptTemplate:
    #template = "<s>[INST] {context} {prompt} [/INST]"
    template = "<s>[INST] <<SYS>>{system}<</SYS>>{context} {prompt} [/INST]"
    prompt_template = ChatPromptTemplate.from_template(template)
    return prompt_template

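# Build the LCEL chain: the retriever fills {context}, RunnablePassthrough() forwards the raw
# question into {prompt}, get_system() supplies {system}, and the formatted prompt is sent to
# the selected HuggingFace Inference endpoint. Without a retriever, only the bare endpoint is
# returned.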
def create_chain(retriever, llm):
    url = llm_urls[llm]
    model_endpoint = HuggingFaceEndpoint(
        endpoint_url=url,
        huggingfacehub_api_token=HF_API_KEY,
        task="text-generation",
        max_new_tokens=200,
    )
    if retriever is not None:
        def get_system(input):
            return "You are a helpful and honest assistant. Please, respond concisely and truthfully."
        retrieval = {"context": retriever, "prompt": RunnablePassthrough(), "system": get_system}
        chain = retrieval | generate_llm_rag_prompt() | model_endpoint
        return chain, model_endpoint
    else:
        return None, model_endpoint

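# Answer the question twice: once by calling the endpoint directly (no retrieval) and once
# through the RAG chain, so both results can be compared side by side in the UI.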
def query(question_text, llm, session_data):
    if question_text == "":
        without_rag_text = "Query result without RAG is not available. Enter a question first."
        rag_text = "Query result with RAG is not available. Enter a question first."
        return without_rag_text, rag_text
    if len(session_data) > 0:
        retriever = session_data[0]
    else:
        retriever = None
    chain, model_endpoint = create_chain(retriever, llm)
    without_rag_text = "Query result without RAG:\n\n" + model_endpoint.invoke(question_text).strip()
    if retriever is None:
        rag_text = "Query result with RAG is not available. Load the Vector Store first."
    else:
        ans = chain.invoke(question_text).strip()
        # Strip any echoed prompt markers from the raw model output so only the answer remains.
        parts = [p.split("[INST] <<SYS>>")[1] for p in ans.split("[/SYS]>[/INST]") if p.find("[INST] <<SYS>>") >= 0]
        if len(parts) >= 2:
            parts = parts[1:-1]
        else:
            parts = ans
        rag_text = "Query result with RAG:\n\n" + "".join(parts).split("[/INST]")[0]
    return without_rag_text, rag_text

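# Upload callback: (re)build the FAISS retriever from the uploaded text file and keep it in
# the per-session state so later queries can use it.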
def upload_file(file, session_data):
    #file_paths = [file.name for file in files]
    #file = files[0]
    session_data = [initialize_vector_store_retriever(file)]
    return gr.File(value=file, visible=True), session_data

def initialize_vector_store(session_data):
    session_data = [initialize_vector_store_retriever()]
    return session_data

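# UI layout: upload button, question box, model selector, and two output panes
# (answer without RAG vs. answer with RAG).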
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Retrieval Augmented Generation</h1>""")
    session_data = gr.State([])
    file_output = gr.File(visible=False)
    upload_button = gr.UploadButton("Click to Upload a text File to Vector Store", file_types=["text"], file_count="single")
    upload_button.upload(upload_file, [upload_button, session_data], [file_output, session_data])
    #initialize_VS_button = gr.Button("Load text file to Vector Store")
    with gr.Row():
        with gr.Column(scale=4):
            question_text = gr.Textbox(show_label=False, placeholder="Ask a question", lines=2)
        with gr.Column(scale=1):
            llm_Choice = gr.Radio(["Llama 8B", "Mistral 7B"], value="Mistral 7B", label="Select language model:", info="")
            query_Button = gr.Button("Query")
    with gr.Row():
        with gr.Column(scale=1):
            without_rag_text = gr.Textbox(show_label=False, placeholder="Query result without using RAG", lines=15)
        with gr.Column(scale=1):
            rag_text = gr.Textbox(show_label=False, placeholder="Query result with RAG", lines=15)
    #initialize_VS_button.click(
    #    initialize_vector_store,
    #    [session_data],
    #    [session_data],
    #    #show_progress=True,
    #)
    query_Button.click(
        query,
        [question_text, llm_Choice, session_data],
        [without_rag_text, rag_text],
        #show_progress=True,
    )

demo.queue().launch(share=False, inbrowser=True)
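
# Typical local run (assuming this script is saved as app.py and the token is set):
#   export HUGGINGFACE_API_KEY=hf_...
#   python app.py
# Then upload a plain-text file, pick a model, and press "Query" to compare the two answers.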