import gradio as gr
import random
import time
from langchain import PromptTemplate
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Pinecone
from langchain.chains import LLMChain
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.chains.question_answering import load_qa_chain
import pinecone
import os

# Silence the HuggingFace tokenizers fork warning raised under Gradio workers
os.environ["TOKENIZERS_PARALLELISM"] = "false"
#OPENAI_API_KEY = ""
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_TEMP = 1

PINECONE_KEY = os.environ.get("PINECONE_KEY", "")
PINECONE_ENV = os.environ.get("PINECONE_ENV", "asia-northeast1-gcp")
PINECONE_INDEX = os.environ.get("PINECONE_INDEX", "3gpp")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "sentence-transformers/all-mpnet-base-v2")

# return top-k text chunks from the vector store
TOP_K_DEFAULT = 10
TOP_K_MAX = 25

BUTTON_MIN_WIDTH = 210

# shields.io badge fragments in "label-message-color" form
STATUS_NOK = "404-MODEL UNREADY-critical"
STATUS_OK = "200-MODEL LOADED-9cf"
FORK_BADGE = "Fork-HuggingFace Space-9cf"
def get_logo(inputs, logo) -> str:
    # Build a shields.io badge URL from a "label-message-color" string and a logo name
    return f"""https://img.shields.io/badge/{inputs}?style=flat&logo={logo}&logoColor=white"""

def get_status(inputs) -> str:
    # Wrap the badge in an <img> tag for display in a gr.HTML component
    return f"""<img
    src = "{get_logo(inputs, "openai")}"
    style = "margin: 0 auto;"
    >"""
KEY_INIT = "Initialize Model"
KEY_SUBMIT = "Submit"
KEY_CLEAR = "Clear"

MODEL_NULL = get_status(STATUS_NOK)
MODEL_DONE = get_status(STATUS_OK)

MODEL_WARNING = f"Please paste your OpenAI API Key from \
[openai.com](https://platform.openai.com/account/api-keys) and then click **{KEY_INIT}**"

TAB_1 = "Chatbot"
FAVICON = './icon.svg'

LLM_LIST = ["gpt-3.5-turbo", "text-davinci-003"]

DOC_1 = '3GPP'
DOC_2 = 'HTTP2'
DOC_SUPPORTED = [DOC_1, DOC_2]
DOC_DEFAULT = [DOC_1]
webui_title = """
# OpenAI Chatbot Based on Vector Database
## Example of 3GPP
"""

dup_link = f'''<a href="https://huggingface.co/spaces/ShawnAI/3GPP-ChatBot?duplicate=true">
<img src="{get_logo(FORK_BADGE, "addthis")}"></a> '''

init_message = f"""Welcome to the 3GPP Chatbot. This demo is built on OpenAI with LangChain and Pinecone.
1. Paste your OpenAI API key and click `{KEY_INIT}`
2. Enter your question and click `{KEY_SUBMIT}`
"""
#----------------------------------------------------------------------------------------------------------
#----------------------------------------------------------------------------------------------------------

def init_model(api_key, emb_name, db_api_key, db_env, db_index):
    try:
        if (api_key and api_key.startswith("sk-") and len(api_key) > 50) and \
           (emb_name and db_api_key and db_env and db_index):
            # Load the sentence-transformers embedding model and connect to Pinecone
            embeddings = HuggingFaceEmbeddings(model_name=emb_name)
            pinecone.init(api_key=db_api_key,
                          environment=db_env)

            #llm = OpenAI(temperature=OPENAI_TEMP, model_name="gpt-3.5-turbo-0301")
            # Build one LLM instance per supported model; chat models need ChatOpenAI
            llm_dict = {}
            for llm_name in LLM_LIST:
                if llm_name == "gpt-3.5-turbo":
                    llm_dict[llm_name] = ChatOpenAI(model_name=llm_name,
                                                    temperature=OPENAI_TEMP,
                                                    openai_api_key=api_key)
                else:
                    llm_dict[llm_name] = OpenAI(model_name=llm_name,
                                                temperature=OPENAI_TEMP,
                                                openai_api_key=api_key)
            '''
            ChatOpenAI(model_name="gpt-3.5-turbo",
                       temperature = OPENAI_TEMP,
                       openai_api_key = api_key)
            chain_1 = load_qa_chain(llm, chain_type="stuff")
            #LLMChain(llm=llm, prompt=condense_question_prompt)
            chain_2 = LLMChain(llm = llm,
                               prompt = PromptTemplate(template='{question}',
                                                       input_variables=['question']),
                               output_key = 'output_text')
            '''
            db = Pinecone.from_existing_index(index_name=db_index,
                                              embedding=embeddings)

            return api_key, MODEL_DONE, llm_dict, None, db, None
        else:
            return None, MODEL_NULL, None, None, None, None
    except Exception as e:
        print(e)
        return None, MODEL_NULL, None, None, None, None
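# The six return values above map positionally onto init_output defined below:
# (llm_api_textbox, model_statusbox, llm, chain_2, vector_db, chatbot).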
def get_chat_history(inputs) -> str:
    # Flatten Gradio's [[user, bot], ...] history into a plain-text transcript
    res = []
    for human, ai in inputs:
        res.append(f"Human: {human}\nAI: {ai}")
    return "\n".join(res)
def remove_duplicates(documents):
    # Drop retrieved chunks whose page_content was already seen, keeping the first occurrence
    seen_content = set()
    unique_documents = []
    for doc in documents:
        if doc.page_content not in seen_content:
            seen_content.add(doc.page_content)
            unique_documents.append(doc)
    return unique_documents
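# Note: deduplication keys on page_content only, so two chunks with identical text
# but different metadata["source"] collapse into the first one retrieved.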
def doc_similarity(query, db, top_k):
    # Vector-similarity search, then dedupe, so the result may hold fewer than top_k docs
    docsearch = db.as_retriever(search_kwargs={'k': top_k})
    docs = docsearch.get_relevant_documents(query)
    udocs = remove_duplicates(docs)
    return udocs
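# Usage sketch (illustrative; assumes a vector_db already returned by init_model):
#   docs = doc_similarity("NAS registration procedure", db, TOP_K_DEFAULT)
#   print([d.metadata["source"] for d in docs])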
def user(user_message, history):
    # Append the new question to the chat history and clear the input box
    return "", history + [[user_message, None]]
def bot(box_message, ref_message,
        llm_dropdown, llm_dict, doc_list,
        db, top_k):
    # bot_message = random.choice(["Yes", "No"])
    # box_message[i][0] is the user question, box_message[i][1] is the bot response
    question = box_message[-1][0]
    history = box_message[:-1]

    if (not llm_dict) or (not doc_list) or (not db):
        box_message[-1][1] = MODEL_WARNING
        return box_message, "", ""

    # Use the optional reference text for retrieval; fall back to the question itself
    if not ref_message:
        ref_message = question
        details = f"Q: {question}"
    else:
        details = f"Q: {question}\nR: {ref_message}"

    llm = llm_dict[llm_dropdown]
    print(llm)
    print(doc_list)

    if DOC_1 in doc_list:
        # Retrieval-augmented path: stuff the retrieved chunks into the prompt
        chain = load_qa_chain(llm, chain_type="stuff")
        docs = doc_similarity(ref_message, db, top_k)
        # Deduplication may shrink the result set; retry once with a larger k
        delta_top_k = top_k - len(docs)
        if delta_top_k > 0:
            docs = doc_similarity(ref_message, db, top_k + delta_top_k)
    else:
        # Plain LLM path: pass the question through without any retrieved context
        chain = LLMChain(llm=llm,
                         prompt=PromptTemplate(template='{question}',
                                               input_variables=['question']),
                         output_key='output_text')
        docs = []

    all_output = chain({"input_documents": docs,
                        "question": question,
                        "chat_history": get_chat_history(history)})
    bot_message = all_output['output_text']

    # Render each source chunk as a collapsible <details> block for the detail panel
    source = "".join([f"""<details> <summary>{doc.metadata["source"]}</summary>
{doc.page_content}
</details>""" for doc in docs])
    #print(source)

    box_message[-1][1] = bot_message
    return box_message, "", [[details, bot_message + source]]
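# Note: with chain_type="stuff", load_qa_chain concatenates every retrieved chunk
# into a single prompt, so large top_k values can overflow the model's context window.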
#----------------------------------------------------------------------------------------------------------
#----------------------------------------------------------------------------------------------------------

with gr.Blocks(
        title=TAB_1,
        theme="Base",
        css=""".bigbox {
        min-height:250px;
        }
        """) as demo:

    llm = gr.State()
    chain_2 = gr.State()  # not in use
    vector_db = gr.State()

    gr.Markdown(webui_title)
    gr.HTML(dup_link)
    gr.Markdown(init_message)
    with gr.Row():
        with gr.Column(scale=10):
            llm_api_textbox = gr.Textbox(
                label="OpenAI API Key",
                # show_label = False,
                value=OPENAI_API_KEY,
                placeholder="Paste Your OpenAI API Key (sk-...) and Hit ENTER",
                lines=1,
                type='password')
        with gr.Column(scale=1, min_width=BUTTON_MIN_WIDTH):
            init = gr.Button(KEY_INIT)  # .style(full_width=False)
            model_statusbox = gr.HTML(MODEL_NULL)

    with gr.Tab(TAB_1):
        with gr.Row():
            with gr.Column(scale=10):
                chatbot = gr.Chatbot(elem_classes="bigbox")
            #with gr.Column(scale=1):
            with gr.Column(scale=1, min_width=BUTTON_MIN_WIDTH):
                doc_check = gr.CheckboxGroup(choices=DOC_SUPPORTED,
                                             value=DOC_DEFAULT,
                                             label="Reference Docs",
                                             interactive=True)
                llm_dropdown = gr.Dropdown(LLM_LIST,
                                           value=LLM_LIST[0],
                                           multiselect=False,
                                           interactive=True,
                                           label="LLM Selection")
        with gr.Row():
            with gr.Column(scale=10):
                query = gr.Textbox(label="Question:",
                                   lines=2)
                ref = gr.Textbox(label="Reference (optional):")
            with gr.Column(scale=1, min_width=BUTTON_MIN_WIDTH):
                clear = gr.Button(KEY_CLEAR)
                submit = gr.Button(KEY_SUBMIT, variant="primary")
with gr.Tab("Details"): | |
top_k = gr.Slider(1, | |
TOP_K_MAX, | |
value=TOP_K_DEFAULT, | |
step=1, | |
label="Vector similarity top_k", | |
interactive=True) | |
detail_panel = gr.Chatbot(label="Related Docs") | |
with gr.Tab("Database"): | |
with gr.Row(): | |
emb_textbox = gr.Textbox( | |
label = "Embedding Model", | |
# show_label = False, | |
value = EMBEDDING_MODEL, | |
placeholder = "Paste Your Embedding Model Repo on HuggingFace", | |
lines=1, | |
interactive=True, | |
type='email') | |
with gr.Accordion("Pinecone Database for "+DOC_1): | |
with gr.Row(): | |
db_api_textbox = gr.Textbox( | |
label = "Pinecone API Key", | |
# show_label = False, | |
value = PINECONE_KEY, | |
placeholder = "Paste Your Pinecone API Key (xx-xx-xx-xx-xx) and Hit ENTER", | |
lines=1, | |
interactive=True, | |
type='password') | |
with gr.Row(): | |
db_env_textbox = gr.Textbox( | |
label = "Pinecone Environment", | |
# show_label = False, | |
value = PINECONE_ENV, | |
placeholder = "Paste Your Pinecone Environment (xx-xx-xx) and Hit ENTER", | |
lines=1, | |
interactive=True, | |
type='email') | |
db_index_textbox = gr.Textbox( | |
label = "Pinecone Index", | |
# show_label = False, | |
value = PINECONE_INDEX, | |
placeholder = "Paste Your Pinecone Index (xxxx) and Hit ENTER", | |
lines=1, | |
interactive=True, | |
type='email') | |
    init_input = [llm_api_textbox, emb_textbox, db_api_textbox, db_env_textbox, db_index_textbox]
    init_output = [llm_api_textbox, model_statusbox,
                   llm, chain_2,
                   vector_db, chatbot]

    llm_api_textbox.submit(init_model, init_input, init_output)
    init.click(init_model, init_input, init_output)

    submit.click(user,
                 [query, chatbot],
                 [query, chatbot],
                 queue=False).then(
        bot,
        [chatbot, ref,
         llm_dropdown, llm, doc_check,
         vector_db, top_k],
        [chatbot, ref, detail_panel]
    )

    clear.click(lambda: (None, None, None), None, [query, ref, chatbot], queue=False)
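    # Note: queue=False lets user() echo the question into the chat immediately;
    # the chained .then() call then runs bot() to fill in the answer and detail panel.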
#----------------------------------------------------------------------------------------------------------
#----------------------------------------------------------------------------------------------------------

if __name__ == "__main__":
    demo.launch(share=False,
                inbrowser=True,
                favicon_path=FAVICON)