# app.py — scraped from the Hugging Face file viewer.
# Author: ShawnAI · last commit 3f8dc6c ("Update app.py") · 14.4 kB · scan result: "No virus"
import gradio as gr
import random
import time
from langchain import PromptTemplate
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings, OpenAIEmbeddings
from langchain.vectorstores import Pinecone
from langchain.chains import LLMChain
from langchain.chains.question_answering import load_qa_chain
import pinecone
import os
# Set before any HuggingFace tokenizer is loaded (presumably to silence the
# tokenizers parallelism/fork warning — confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# --- OpenAI -----------------------------------------------------------------
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")  # pre-fills the key textbox
OPENAI_TEMP = 1  # temperature handed to every LLM client
OPENAI_API_LINK = "[OpenAI API Key](https://platform.openai.com/account/api-keys)"
OPENAI_LINK = "[OpenAI](https://openai.com)"
LANGCHAIN_LINK = "[LangChain](https://python.langchain.com/en/latest/index.html)"

# --- Pinecone (defaults overridable through the environment) ----------------
PINECONE_KEY = os.environ.get("PINECONE_KEY", "")
PINECONE_ENV = os.environ.get("PINECONE_ENV", "asia-northeast1-gcp")
PINECONE_INDEX = os.environ.get("PINECONE_INDEX", "3gpp-r16")
PINECONE_LINK = "[Pinecone](https://www.pinecone.io)"

# --- Embeddings -------------------------------------------------------------
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "hkunlp/instructor-large")
EMBEDDING_LOADER = os.environ.get("EMBEDDING_LOADER", "HuggingFaceInstructEmbeddings")
EMBEDDING_LIST = [
    "HuggingFaceInstructEmbeddings",
    "HuggingFaceEmbeddings",
    "OpenAIEmbeddings",
]

# --- Retrieval: how many text chunks to pull from the vector store ----------
TOP_K_DEFAULT, TOP_K_MAX = 15, 30
SCORE_DEFAULT = 0.33

# --- Layout -------------------------------------------------------------------
BUTTON_MIN_WIDTH = 215

# --- Badge paths handed to get_logo() (shields.io static badges) -------------
LLM_NULL, LLM_DONE = "LLM-UNLOAD-critical", "LLM-LOADED-9cf"
DB_NULL, DB_DONE = "DB-UNLOAD-critical", "DB-LOADED-9cf"
FORK_BADGE = "Fork-HuggingFace Space-9cf"
def get_logo(inputs, logo) -> str:
    """Build a shields.io badge URL.

    Args:
        inputs: badge path segment appended after ``/badge/``.
        logo: icon name passed as the ``logo`` query parameter.

    Returns:
        The badge URL, rendered in flat style with a white logo.
    """
    params = ("style=flat", f"logo={logo}", "logoColor=white")
    return "https://img.shields.io/badge/{}?{}".format(inputs, "&".join(params))
def get_status(inputs, logo, pos) -> str:
    """Return an HTML ``<img>`` tag that displays a shields.io status badge.

    Args:
        inputs: badge path forwarded to ``get_logo`` (e.g. ``LLM_DONE``).
        logo: icon name forwarded to ``get_logo``.
        pos: CSS ``float`` value used to place the badge (e.g. ``"right"``).

    Returns:
        A single well-formed ``<img>`` element as a string.
    """
    src = get_logo(inputs, logo)
    style = f"margin: 0 auto;float:{pos};border: 2px solid transparent;"
    # HTML attributes are whitespace-separated. A ';' after a quoted value is
    # a parse error, and parsers then invent a bogus ';' attribute.
    return f'<img src="{src}" style="{style}">'
# Button captions.
KEY_INIT, KEY_SUBMIT, KEY_CLEAR = "Initialize Model", "Submit", "Clear"

# Status badges (HTML <img> tags) for the LLM and for the vector database.
MODEL_NULL, MODEL_DONE = (
    get_status(badge, "openai", "right") for badge in (LLM_NULL, LLM_DONE)
)
DOCS_NULL, DOCS_DONE = (
    get_status(badge, "processingfoundation", "right") for badge in (DB_NULL, DB_DONE)
)

# Tab titles.
TAB_1, TAB_2, TAB_3, TAB_4 = "Chatbot", "Details", "Database", "TODO"

FAVICON = "./icon.svg"
LLM_LIST = ["gpt-3.5-turbo", "text-davinci-003"]

# Reference document sets; only DOC_1 is offered in the UI.
DOC_1, DOC_2 = "3GPP", "HTTP2"
DOC_SUPPORTED = [DOC_1]
DOC_DEFAULT = [DOC_1]
DOC_LABEL = "Reference Docs"

# Messages shown in the chat when the model / database is not ready.
MODEL_WARNING = f"Please paste your **{OPENAI_API_LINK}** and then **{KEY_INIT}**"
DOCS_WARNING = "\n".join([
    "Database Unloaded",
    f"Please check your **{TAB_3}** config and then **{KEY_INIT}**",
    f"Or you could uncheck **{DOC_LABEL}** to ask LLM directly",
])

# Page header markdown.
webui_title = "\n# OpenAI Chatbot Based on Vector Database\n"
dup_link = (
    '<a href="https://huggingface.co/spaces/ShawnAI/VectorDB-ChatBot?duplicate=true"\n'
    'style="display:grid; width: 200px;">\n'
    f'<img src="{get_logo(FORK_BADGE, "addthis")}"></a>'
)
init_message = (
    "This demonstration website is based on "
    f"**{OPENAI_LINK}** with **{LANGCHAIN_LINK}** and **{PINECONE_LINK}**\n"
    f"1. Insert your **{OPENAI_API_LINK}** and click `{KEY_INIT}`\n"
    f"2. Insert your **Question** and click `{KEY_SUBMIT}`\n"
)
# Prompt used when reference docs are selected: retrieved context, the chat
# transcript so far, then the current question.
PROMPT_DOC = PromptTemplate(
    template="\n".join([
        "Context:", "##", "{context}", "##",
        "Chat History:", "##", "{chat_history}", "##",
        "Question:", "{question}",
        "Answer:",
    ]),
    input_variables=["context", "chat_history", "question"],
)

# Prompt used when asking the LLM directly (no retrieved context).
PROMPT_BASE = PromptTemplate(
    template="\n".join([
        "Chat History:", "##", "{chat_history}", "##",
        "Question:", "##", "{question}", "##",
        "Answer:",
    ]),
    input_variables=["question", "chat_history"],
)
#----------------------------------------------------------------------------------------------------------
#----------------------------------------------------------------------------------------------------------
def init_rwkv():
    """Probe whether the optional ``rwkv`` package can be imported.

    This is best-effort: any exception raised during the import counts as
    "not available". In that case a notice is printed.

    Returns:
        bool: True if ``import rwkv`` succeeded, otherwise False.
    """
    rwkv_available = True
    try:
        import rwkv  # noqa: F401  (imported only to probe availability)
    except Exception:
        print("RWKV not found, skip local llm")
        rwkv_available = False
    return rwkv_available
def init_model(api_key, emb_name, emb_loader, db_api_key, db_env, db_index):
    """Create the LLM clients and, if configured, connect the Pinecone store.

    Args:
        api_key: OpenAI API key; must start with ``sk-`` and be longer than
            50 characters to be accepted.
        emb_name: embedding model name, used by the HuggingFace loaders.
        emb_loader: name of the embedding class. Only the entries of
            ``EMBEDDING_LIST`` are accepted.
        db_api_key, db_env, db_index: Pinecone credentials and index name.

    Returns:
        A 6-tuple matching the Gradio ``init_output`` components:
        ``(api key echoed back, status-badge HTML, dict of LLM name -> client,
        unused chain slot, vector store, chatbot value)``.
        If the key is rejected or an exception occurs, the key is cleared and
        both badges show "unloaded". If any database field is empty, only the
        LLMs are loaded.
    """
    # Whitelist of selectable embedding classes. Looking the name up here
    # replaces eval(): eval would run arbitrary code sent by any client that
    # calls the Gradio endpoint directly and bypasses the dropdown.
    embedding_loaders = {
        "HuggingFaceInstructEmbeddings": HuggingFaceInstructEmbeddings,
        "HuggingFaceEmbeddings": HuggingFaceEmbeddings,
        "OpenAIEmbeddings": OpenAIEmbeddings,
    }
    init_rwkv()
    try:
        # Cheap shape check on the key before constructing any client.
        if not (api_key and api_key.startswith("sk-") and len(api_key) > 50):
            return None, MODEL_NULL + DOCS_NULL, None, None, None, None

        llm_dict = {}
        for llm_name in LLM_LIST:
            if llm_name == "gpt-3.5-turbo":
                # Chat-completion model needs the chat wrapper.
                llm_dict[llm_name] = ChatOpenAI(model_name=llm_name,
                                                temperature=OPENAI_TEMP,
                                                openai_api_key=api_key)
            else:
                llm_dict[llm_name] = OpenAI(model_name=llm_name,
                                            temperature=OPENAI_TEMP,
                                            openai_api_key=api_key)

        # Without full database settings, run with the LLMs only.
        if not (emb_name and db_api_key and db_env and db_index):
            return api_key, MODEL_DONE + DOCS_NULL, llm_dict, None, None, None

        loader = embedding_loaders.get(emb_loader)
        if loader is None:
            raise ValueError(f"Unsupported embedding loader: {emb_loader!r}")
        if emb_loader == "OpenAIEmbeddings":
            embeddings = loader(openai_api_key=api_key)
        else:
            embeddings = loader(model_name=emb_name)

        pinecone.init(api_key=db_api_key,
                      environment=db_env)
        db = Pinecone.from_existing_index(index_name=db_index,
                                          embedding=embeddings)
        return api_key, MODEL_DONE + DOCS_DONE, llm_dict, None, db, None
    except Exception as e:
        # UI boundary: report the problem and fall back to the unloaded state.
        print(e)
        return None, MODEL_NULL + DOCS_NULL, None, None, None, None
def get_chat_history(inputs) -> str:
    """Render chat turns as a plain-text transcript.

    Args:
        inputs: iterable of ``(question, answer)`` pairs.

    Returns:
        One ``"Q: <question>\\nA: <answer>"`` entry per pair, separated by
        newlines. An empty input gives ``""``.
    """
    return "\n".join(
        "Q: {}\nA: {}".format(question, answer) for question, answer in inputs
    )
def remove_duplicates(documents, score_min):
    """Filter scored search hits by score and drop repeated content.

    Args:
        documents: iterable of ``(doc, score)`` pairs, where ``doc`` has a
            ``page_content`` attribute.
        score_min: hits must have ``score >= score_min`` to be kept.

    Returns:
        The kept documents (without scores), in input order. A hit is dropped
        when its ``page_content`` matches an already-kept document. Rejected
        low-score hits do not block later copies of the same content.
    """
    kept = []
    kept_contents = set()
    for doc, doc_score in documents:
        text = doc.page_content
        if text in kept_contents or not doc_score >= score_min:
            continue
        kept_contents.add(text)
        kept.append(doc)
    return kept
def doc_similarity(query, db, top_k, score):
    """Search ``db`` for ``query`` and return the de-duplicated matches.

    Args:
        query: text to search for.
        db: vector store exposing ``similarity_search_with_score``.
        top_k: number of hits requested from the store.
        score: minimum score passed to ``remove_duplicates``.

    Returns:
        The documents kept by ``remove_duplicates`` (possibly fewer than
        ``top_k``).
    """
    scored_hits = db.similarity_search_with_score(query=query, k=top_k)
    return remove_duplicates(scored_hits, score)
def user(user_message, history):
    """Append the user's message as a pending turn and clear the input box.

    Args:
        user_message: text typed into the question box.
        history: current chatbot value as ``[[question, answer], ...]``. It
            may be ``None``, e.g. after the chat was cleared.

    Returns:
        ``("", new_history)``. ``new_history`` is a new list ending with
        ``[user_message, None]``; the caller's list is not mutated.
    """
    return "", (history or []) + [[user_message, None]]
def bot(box_message, ref_message,
        llm_dropdown, llm_dict, doc_list,
        db, top_k, score):
    """Answer the pending chat question, optionally grounded in retrieved docs.

    Args:
        box_message: chatbot value ``[[question, answer], ...]``. The last
            entry is the pending question (answer slot filled in here).
        ref_message: optional retrieval query; if empty, the question is used.
        llm_dropdown: key into ``llm_dict`` selecting the model.
        llm_dict: model name -> LLM client from ``init_model``; falsy if the
            models were not initialised.
        doc_list: selected reference doc sets; retrieval runs only when
            ``DOC_1`` is among them.
        db: vector store from ``init_model``; falsy if not connected.
        top_k: number of chunks to retrieve (may arrive as a float from the slider).
        score: minimum similarity score for a chunk to be kept.

    Returns:
        ``(updated chatbot value, "" to clear the reference box, detail-panel
        value)``. On a missing model/database the detail value is ``""`` and a
        warning is written into the chat instead.
    """
    # Entry [-1] holds the pending turn: [0] is the question, [1] the answer.
    question = box_message[-1][0]
    history = box_message[:-1]
    if not llm_dict:
        box_message[-1][1] = MODEL_WARNING
        return box_message, "", ""

    if not ref_message:
        ref_message = question
        details = f"Q: {question}"
    else:
        details = f"Q: {question}\nR: {ref_message}"

    llm = llm_dict[llm_dropdown]
    if doc_list and DOC_1 in doc_list:
        if not db:
            box_message[-1][1] = DOCS_WARNING
            return box_message, "", ""
        # The vector store expects an integer k.
        top_k = int(top_k)
        docs = doc_similarity(ref_message, db, top_k, score)
        # Duplicates / low scores may leave fewer than top_k; query wider once.
        delta_top_k = top_k - len(docs)
        if delta_top_k > 0:
            docs = doc_similarity(ref_message, db, top_k + delta_top_k, score)
        prompt = PROMPT_DOC
    else:
        prompt = PROMPT_BASE
        docs = []

    # Give the model the chunk text itself, not repr() of Document objects.
    context = "\n\n".join(doc.page_content for doc in docs)
    chain = LLMChain(llm=llm,
                     prompt=prompt,
                     output_key="output_text")
    all_output = chain({"question": question,
                        "context": context,
                        "chat_history": get_chat_history(history)})
    bot_message = all_output["output_text"]

    # Collapsible per-chunk listing for the detail panel; tolerate docs that
    # carry no "source" metadata.
    source = "".join(
        "<details> <summary>{}</summary>\n{}\n</details>".format(
            doc.metadata.get("source", ""), doc.page_content)
        for doc in docs
    )
    box_message[-1][1] = bot_message
    return box_message, "", [[details, bot_message + "\n\nMetadata:\n" + source]]
#----------------------------------------------------------------------------------------------------------
#----------------------------------------------------------------------------------------------------------
with gr.Blocks(
    title=TAB_1,
    theme="Base",
    css=""".bigbox {
min-height:250px;
}
""",
) as demo:
    # Per-session state passed between the init and chat callbacks.
    llm = gr.State()        # dict of LLM name -> client (from init_model)
    chain_2 = gr.State()    # not in use
    vector_db = gr.State()  # vector store (from init_model)

    gr.Markdown(webui_title)
    gr.Markdown(dup_link)
    gr.Markdown(init_message)

    # OpenAI key entry, initialise button and status badges.
    with gr.Row():
        with gr.Column(scale=10):
            llm_api_textbox = gr.Textbox(
                label="OpenAI API Key",
                value=OPENAI_API_KEY,
                placeholder="Paste Your OpenAI API Key (sk-...) and Hit ENTER",
                lines=1,
                type="password")
        with gr.Column(scale=1, min_width=BUTTON_MIN_WIDTH):
            init = gr.Button(KEY_INIT)
            model_statusbox = gr.HTML(MODEL_NULL + DOCS_NULL)

    with gr.Tab(TAB_1):
        with gr.Row():
            with gr.Column(scale=10):
                chatbot = gr.Chatbot(elem_classes="bigbox")
            with gr.Column(scale=1, min_width=BUTTON_MIN_WIDTH):
                doc_check = gr.CheckboxGroup(choices=DOC_SUPPORTED,
                                             value=DOC_DEFAULT,
                                             label=DOC_LABEL,
                                             interactive=True)
                llm_dropdown = gr.Dropdown(LLM_LIST,
                                           value=LLM_LIST[0],
                                           multiselect=False,
                                           interactive=True,
                                           label="LLM Selection")
        with gr.Row():
            with gr.Column(scale=10):
                query = gr.Textbox(label="Question:",
                                   lines=2)
                ref = gr.Textbox(label="Reference(optional):")
            with gr.Column(scale=1, min_width=BUTTON_MIN_WIDTH):
                clear = gr.Button(KEY_CLEAR)
                submit = gr.Button(KEY_SUBMIT, variant="primary")

    with gr.Tab(TAB_2):
        with gr.Row():
            with gr.Column():
                top_k = gr.Slider(1,
                                  TOP_K_MAX,
                                  value=TOP_K_DEFAULT,
                                  step=1,
                                  label="Vector similarity top_k",
                                  interactive=True)
            with gr.Column():
                score = gr.Slider(0.01,
                                  0.99,
                                  value=SCORE_DEFAULT,
                                  step=0.01,
                                  label="Vector similarity score",
                                  interactive=True)
        detail_panel = gr.Chatbot(label="Related Docs")

    with gr.Tab(TAB_3):
        with gr.Row():
            with gr.Column():
                # Plain text input: a model repo id is not an e-mail address.
                emb_textbox = gr.Textbox(
                    label="Embedding Model",
                    value=EMBEDDING_MODEL,
                    placeholder="Paste Your Embedding Model Repo on HuggingFace",
                    lines=1,
                    interactive=True,
                    type="text")
            with gr.Column():
                emb_dropdown = gr.Dropdown(
                    EMBEDDING_LIST,
                    value=EMBEDDING_LOADER,
                    multiselect=False,
                    interactive=True,
                    label="Embedding Loader")
        with gr.Accordion("Pinecone Database for " + DOC_1):
            with gr.Row():
                db_api_textbox = gr.Textbox(
                    label="Pinecone API Key",
                    value=PINECONE_KEY,
                    placeholder="Paste Your Pinecone API Key (xx-xx-xx-xx-xx) and Hit ENTER",
                    lines=1,
                    interactive=True,
                    type="password")
            with gr.Row():
                db_env_textbox = gr.Textbox(
                    label="Pinecone Environment",
                    value=PINECONE_ENV,
                    placeholder="Paste Your Pinecone Environment (xx-xx-xx) and Hit ENTER",
                    lines=1,
                    interactive=True,
                    type="text")
                db_index_textbox = gr.Textbox(
                    label="Pinecone Index",
                    value=PINECONE_INDEX,
                    placeholder="Paste Your Pinecone Index (xxxx) and Hit ENTER",
                    lines=1,
                    interactive=True,
                    type="text")

    with gr.Tab(TAB_4):
        pass  # placeholder tab, no content yet

    init_input = [llm_api_textbox, emb_textbox, emb_dropdown,
                  db_api_textbox, db_env_textbox, db_index_textbox]
    init_output = [llm_api_textbox, model_statusbox,
                   llm, chain_2,
                   vector_db, chatbot]
    # Every textbox whose placeholder says "Hit ENTER" initialises the models,
    # the same as the Initialize button.
    for _enter_box in (llm_api_textbox, db_api_textbox, db_env_textbox, db_index_textbox):
        _enter_box.submit(init_model, init_input, init_output)
    init.click(init_model, init_input, init_output)

    submit.click(user,
                 [query, chatbot],
                 [query, chatbot],
                 queue=False).then(
        bot,
        [chatbot, ref,
         llm_dropdown, llm, doc_check,
         vector_db, top_k, score],
        [chatbot, ref, detail_panel],
    )
    # Clearing the conversation also clears the related-docs panel.
    clear.click(lambda: (None, None, None, None), None,
                [query, ref, chatbot, detail_panel], queue=False)
#----------------------------------------------------------------------------------------------------------
#----------------------------------------------------------------------------------------------------------
if __name__ == "__main__":
    # Serve locally without a public share link and open a browser tab.
    demo.launch(
        favicon_path=FAVICON,
        inbrowser=True,
        share=False,
    )