import gradio as gr
import random
import time
from langchain import PromptTemplate
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings, OpenAIEmbeddings
from langchain.vectorstores import Pinecone
from langchain.chains import LLMChain
from langchain.chains.question_answering import load_qa_chain
import pinecone
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_TEMP = 1
OPENAI_API_LINK = "[OpenAI API Key](https://platform.openai.com/account/api-keys)"
OPENAI_LINK = "[OpenAI](https://openai.com)"

PINECONE_KEY = os.environ.get("PINECONE_KEY", "")
PINECONE_ENV = os.environ.get("PINECONE_ENV", "asia-northeast1-gcp")
PINECONE_INDEX = os.environ.get("PINECONE_INDEX", '3gpp-r16')
PINECONE_LINK = "[Pinecone](https://www.pinecone.io)"
LANGCHAIN_LINK = "[LangChain](https://python.langchain.com/en/latest/index.html)"

EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "hkunlp/instructor-large")
EMBEDDING_LOADER = os.environ.get("EMBEDDING_LOADER", "HuggingFaceInstructEmbeddings")
EMBEDDING_LIST = ["HuggingFaceInstructEmbeddings", "HuggingFaceEmbeddings", "OpenAIEmbeddings"]

# Return top-k text chunks from the vector store
TOP_K_DEFAULT = 15
TOP_K_MAX = 30
SCORE_DEFAULT = 0.33

BUTTON_MIN_WIDTH = 215

# shields.io badge fragments in "<label>-<message>-<color>" form
LLM_NULL = "LLM-UNLOAD-critical"
LLM_DONE = "LLM-LOADED-9cf"
DB_NULL = "DB-UNLOAD-critical"
DB_DONE = "DB-LOADED-9cf"
FORK_BADGE = "Fork-HuggingFace Space-9cf"
def get_logo(inputs, logo) -> str:
    return f"https://img.shields.io/badge/{inputs}?style=flat&logo={logo}&logoColor=white"

def get_status(inputs, logo, pos) -> str:
    return f"""<img
  src="{get_logo(inputs, logo)}"
  style="margin: 0 auto; float: {pos}; border: 2px solid transparent;"
>"""
KEY_INIT = "Initialize Model"
KEY_SUBMIT = "Submit"
KEY_CLEAR = "Clear"

MODEL_NULL = get_status(LLM_NULL, "openai", "right")
MODEL_DONE = get_status(LLM_DONE, "openai", "right")
DOCS_NULL = get_status(DB_NULL, "processingfoundation", "right")
DOCS_DONE = get_status(DB_DONE, "processingfoundation", "right")

TAB_1 = "Chatbot"
TAB_2 = "Details"
TAB_3 = "Database"
TAB_4 = "TODO"

FAVICON = './icon.svg'

LLM_LIST = ["gpt-3.5-turbo", "text-davinci-003"]

DOC_1 = '3GPP'
DOC_2 = 'HTTP2'
DOC_SUPPORTED = [DOC_1]
DOC_DEFAULT = [DOC_1]
DOC_LABEL = "Reference Docs"

MODEL_WARNING = f"Please paste your **{OPENAI_API_LINK}** and then click **{KEY_INIT}**"

DOCS_WARNING = f"""Database Unloaded
Please check your **{TAB_3}** config and then click **{KEY_INIT}**,
or uncheck **{DOC_LABEL}** to ask the LLM directly"""

webui_title = """
# OpenAI Chatbot Based on Vector Database
"""

dup_link = f'''<a href="https://huggingface.co/spaces/ShawnAI/VectorDB-ChatBot?duplicate=true"
style="display:grid; width: 200px;">
<img src="{get_logo(FORK_BADGE, "addthis")}"></a>'''

init_message = f"""This demonstration website is based on \
**{OPENAI_LINK}** with **{LANGCHAIN_LINK}** and **{PINECONE_LINK}**
1. Insert your **{OPENAI_API_LINK}** and click `{KEY_INIT}`
2. Insert your **Question** and click `{KEY_SUBMIT}`
"""
PROMPT_DOC = PromptTemplate(
    input_variables=["context", "chat_history", "question"],
    template="""Context:
##
{context}
##
Chat History:
##
{chat_history}
##
Question:
{question}
Answer:"""
)

PROMPT_BASE = PromptTemplate(
    input_variables=["question", "chat_history"],
    template="""Chat History:
##
{chat_history}
##
Question:
##
{question}
##
Answer:"""
)
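# PROMPT_DOC is used when retrieved chunks are attached as {context};
# PROMPT_BASE is the fallback when "Reference Docs" is unchecked and the
# question goes to the LLM directly. Both templates carry the running chat
# history so follow-up questions stay grounded.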
#----------------------------------------------------------------------------------------------------------
#----------------------------------------------------------------------------------------------------------

def init_rwkv():
    try:
        import rwkv
        return True
    except Exception:
        print("RWKV not found, skipping local LLM")
        return False
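# init_rwkv only probes for the optional local RWKV backend; nothing else in
# this script uses it, so a missing package is tolerated.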
def init_model(api_key, emb_name, emb_loader, db_api_key, db_env, db_index):
    init_rwkv()
    try:
        if not (api_key and api_key.startswith("sk-") and len(api_key) > 50):
            return None, MODEL_NULL + DOCS_NULL, None, None, None, None

        # Instantiate one LLM wrapper per supported model name
        llm_dict = {}
        for llm_name in LLM_LIST:
            if llm_name == "gpt-3.5-turbo":
                llm_dict[llm_name] = ChatOpenAI(model_name=llm_name,
                                                temperature=OPENAI_TEMP,
                                                openai_api_key=api_key)
            else:
                llm_dict[llm_name] = OpenAI(model_name=llm_name,
                                            temperature=OPENAI_TEMP,
                                            openai_api_key=api_key)

        if not (emb_name and db_api_key and db_env and db_index):
            return api_key, MODEL_DONE + DOCS_NULL, llm_dict, None, None, None

        # Resolve the loader name through an explicit mapping rather than eval()
        emb_classes = {"HuggingFaceInstructEmbeddings": HuggingFaceInstructEmbeddings,
                       "HuggingFaceEmbeddings": HuggingFaceEmbeddings,
                       "OpenAIEmbeddings": OpenAIEmbeddings}
        if emb_loader == "OpenAIEmbeddings":
            embeddings = emb_classes[emb_loader](openai_api_key=api_key)
        else:
            embeddings = emb_classes[emb_loader](model_name=emb_name)

        pinecone.init(api_key=db_api_key,
                      environment=db_env)
        db = Pinecone.from_existing_index(index_name=db_index,
                                          embedding=embeddings)

        return api_key, MODEL_DONE + DOCS_DONE, llm_dict, None, db, None
    except Exception as e:
        print(e)
        return None, MODEL_NULL + DOCS_NULL, None, None, None, None
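# The 6-tuple returned above maps onto init_output below:
# (api-key textbox, status badges, llm state, unused chain state,
#  vector-db state, chatbot reset).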
def get_chat_history(inputs) -> str:
    res = []
    for human, ai in inputs:
        res.append(f"Q: {human}\nA: {ai}")
    return "\n".join(res)
def remove_duplicates(documents, score_min):
    seen_content = set()
    unique_documents = []
    for (doc, score) in documents:
        if (doc.page_content not in seen_content) and (score >= score_min):
            seen_content.add(doc.page_content)
            unique_documents.append(doc)
    return unique_documents
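# NOTE: similarity_search_with_score yields (Document, score) pairs. The
# score >= score_min cutoff assumes higher scores mean closer matches (true
# for cosine similarity; a distance-based index would need the comparison
# flipped).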
def doc_similarity(query, db, top_k, score):
    docs = db.similarity_search_with_score(query=query,
                                           k=top_k)
    #docsearch = db.as_retriever(search_kwargs={'k': top_k})
    #docs = docsearch.get_relevant_documents(query)
    udocs = remove_duplicates(docs, score)
    return udocs

def user(user_message, history):
    # Echo the question into the chat box; the answer slot is filled by bot()
    return "", history + [[user_message, None]]
def bot(box_message, ref_message,
        llm_dropdown, llm_dict, doc_list,
        db, top_k, score):
    # bot_message = random.choice(["Yes", "No"])
    # 0 is the user question, 1 is the bot response
    question = box_message[-1][0]
    history = box_message[:-1]

    if not llm_dict:
        box_message[-1][1] = MODEL_WARNING
        return box_message, "", ""

    if not ref_message:
        ref_message = question
        details = f"Q: {question}"
    else:
        details = f"Q: {question}\nR: {ref_message}"

    llm = llm_dict[llm_dropdown]

    if DOC_1 in doc_list:
        if not db:
            box_message[-1][1] = DOCS_WARNING
            return box_message, "", ""
        docs = doc_similarity(ref_message, db, top_k, score)
        # If deduplication dropped some hits, re-query with a larger k
        delta_top_k = top_k - len(docs)
        if delta_top_k > 0:
            docs = doc_similarity(ref_message, db, top_k + delta_top_k, score)
        prompt = PROMPT_DOC
        #chain = load_qa_chain(llm, chain_type="stuff")
    else:
        prompt = PROMPT_BASE
        docs = []

    chain = LLMChain(llm=llm,
                     prompt=prompt,
                     output_key='output_text')
    # Join the retrieved chunks into plain text for the {context} slot;
    # LLMChain only forwards the variables the selected prompt declares
    all_output = chain({"question": question,
                        "context": "\n".join(doc.page_content for doc in docs),
                        "chat_history": get_chat_history(history)})
    bot_message = all_output['output_text']

    source = "".join([f"""<details> <summary>{doc.metadata['source']}</summary>
{doc.page_content}
</details>""" for doc in docs])
    #print(source)

    box_message[-1][1] = bot_message
    return box_message, "", [[details, bot_message + '\n\nMetadata:\n' + source]]
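# bot() returns (updated chat history, cleared Reference textbox, and the
# rows shown in the "Related Docs" panel on the Details tab).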
#----------------------------------------------------------------------------------------------------------
#----------------------------------------------------------------------------------------------------------

with gr.Blocks(
        title=TAB_1,
        theme="Base",
        css=""".bigbox {
            min-height: 250px;
        }
        """) as demo:
    llm = gr.State()
    chain_2 = gr.State()  # not in use
    vector_db = gr.State()

    gr.Markdown(webui_title)
    gr.Markdown(dup_link)
    gr.Markdown(init_message)

    with gr.Row():
        with gr.Column(scale=10):
            llm_api_textbox = gr.Textbox(
                label="OpenAI API Key",
                # show_label = False,
                value=OPENAI_API_KEY,
                placeholder="Paste Your OpenAI API Key (sk-...) and Hit ENTER",
                lines=1,
                type='password')
        with gr.Column(scale=1, min_width=BUTTON_MIN_WIDTH):
            init = gr.Button(KEY_INIT)  # .style(full_width=False)
            model_statusbox = gr.HTML(MODEL_NULL + DOCS_NULL)
    with gr.Tab(TAB_1):
        with gr.Row():
            with gr.Column(scale=10):
                chatbot = gr.Chatbot(elem_classes="bigbox")
            #with gr.Column(scale=1):
            with gr.Column(scale=1, min_width=BUTTON_MIN_WIDTH):
                doc_check = gr.CheckboxGroup(choices=DOC_SUPPORTED,
                                             value=DOC_DEFAULT,
                                             label=DOC_LABEL,
                                             interactive=True)
                llm_dropdown = gr.Dropdown(LLM_LIST,
                                           value=LLM_LIST[0],
                                           multiselect=False,
                                           interactive=True,
                                           label="LLM Selection")
        with gr.Row():
            with gr.Column(scale=10):
                query = gr.Textbox(label="Question:",
                                   lines=2)
                ref = gr.Textbox(label="Reference (optional):")
            with gr.Column(scale=1, min_width=BUTTON_MIN_WIDTH):
                clear = gr.Button(KEY_CLEAR)
                submit = gr.Button(KEY_SUBMIT, variant="primary")
    with gr.Tab(TAB_2):
        with gr.Row():
            with gr.Column():
                top_k = gr.Slider(1,
                                  TOP_K_MAX,
                                  value=TOP_K_DEFAULT,
                                  step=1,
                                  label="Vector similarity top_k",
                                  interactive=True)
            with gr.Column():
                score = gr.Slider(0.01,
                                  0.99,
                                  value=SCORE_DEFAULT,
                                  step=0.01,
                                  label="Vector similarity score",
                                  interactive=True)
        detail_panel = gr.Chatbot(label="Related Docs")
    with gr.Tab(TAB_3):
        with gr.Row():
            with gr.Column():
                emb_textbox = gr.Textbox(
                    label="Embedding Model",
                    # show_label = False,
                    value=EMBEDDING_MODEL,
                    placeholder="Paste Your Embedding Model Repo on HuggingFace",
                    lines=1,
                    interactive=True,
                    type='email')
            with gr.Column():
                emb_dropdown = gr.Dropdown(
                    EMBEDDING_LIST,
                    value=EMBEDDING_LOADER,
                    multiselect=False,
                    interactive=True,
                    label="Embedding Loader")
        with gr.Accordion("Pinecone Database for " + DOC_1):
            with gr.Row():
                db_api_textbox = gr.Textbox(
                    label="Pinecone API Key",
                    # show_label = False,
                    value=PINECONE_KEY,
                    placeholder="Paste Your Pinecone API Key (xx-xx-xx-xx-xx) and Hit ENTER",
                    lines=1,
                    interactive=True,
                    type='password')
            with gr.Row():
                db_env_textbox = gr.Textbox(
                    label="Pinecone Environment",
                    # show_label = False,
                    value=PINECONE_ENV,
                    placeholder="Paste Your Pinecone Environment (xx-xx-xx) and Hit ENTER",
                    lines=1,
                    interactive=True,
                    type='email')
                db_index_textbox = gr.Textbox(
                    label="Pinecone Index",
                    # show_label = False,
                    value=PINECONE_INDEX,
                    placeholder="Paste Your Pinecone Index (xxxx) and Hit ENTER",
                    lines=1,
                    interactive=True,
                    type='email')
    with gr.Tab(TAB_4):
        gr.Markdown("TODO")
    init_input = [llm_api_textbox, emb_textbox, emb_dropdown,
                  db_api_textbox, db_env_textbox, db_index_textbox]
    init_output = [llm_api_textbox, model_statusbox,
                   llm, chain_2,
                   vector_db, chatbot]

    llm_api_textbox.submit(init_model, init_input, init_output)
    init.click(init_model, init_input, init_output)
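    # Submitting chains two steps: user() first echoes the question into the
    # chat box (queue=False keeps the echo instantaneous), then bot() fills in
    # the answer and the related-docs details.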
    submit.click(user,
                 [query, chatbot],
                 [query, chatbot],
                 queue=False).then(
        bot,
        [chatbot, ref,
         llm_dropdown, llm, doc_check,
         vector_db, top_k, score],
        [chatbot, ref, detail_panel]
    )

    clear.click(lambda: (None, None, None), None, [query, ref, chatbot], queue=False)
#----------------------------------------------------------------------------------------------------------
#----------------------------------------------------------------------------------------------------------

if __name__ == "__main__":
    demo.launch(share=False,
                inbrowser=True,
                favicon_path=FAVICON)
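# A minimal environment sketch for running this Space locally (assumed
# package set; the script targets the pre-0.1 LangChain API and the classic
# pinecone-client):
#   pip install gradio openai "langchain<0.1" pinecone-client \
#               sentence_transformers InstructorEmbedding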