import gradio as gr import os from pathlib import Path import autogen import chromadb import multiprocessing as mp from autogen.retrieve_utils import TEXT_FORMATS, get_file_from_url, is_url from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent from autogen.agentchat.contrib.retrieve_user_proxy_agent import ( RetrieveUserProxyAgent, PROMPT_CODE, ) TIMEOUT = 60 def initialize_agents(config_list, docs_path=None): if isinstance(config_list, gr.State): _config_list = config_list.value else: _config_list = config_list if docs_path is None: docs_path = "https://raw.githubusercontent.com/microsoft/autogen/main/README.md" assistant = RetrieveAssistantAgent( name="assistant", system_message="You are a helpful assistant.", ) ragproxyagent = RetrieveUserProxyAgent( name="ragproxyagent", human_input_mode="NEVER", max_consecutive_auto_reply=5, retrieve_config={ "task": "code", "docs_path": docs_path, "chunk_token_size": 2000, "model": _config_list[0]["model"], "client": chromadb.PersistentClient(path="/tmp/chromadb"), "embedding_model": "all-mpnet-base-v2", "customized_prompt": PROMPT_CODE, "get_or_create": True, "collection_name": "autogen_rag", }, ) return assistant, ragproxyagent def initiate_chat(config_list, problem, queue, n_results=3): global assistant, ragproxyagent if isinstance(config_list, gr.State): _config_list = config_list.value else: _config_list = config_list if len(_config_list[0].get("api_key", "")) < 2: queue.put( ["Hi, nice to meet you! Please enter your API keys in below text boxs."] ) return else: llm_config = ( { "request_timeout": TIMEOUT, # "seed": 42, "config_list": _config_list, "use_cache": False, }, ) assistant.llm_config.update(llm_config[0]) assistant.reset() try: ragproxyagent.initiate_chat( assistant, problem=problem, silent=False, n_results=n_results ) messages = ragproxyagent.chat_messages messages = [messages[k] for k in messages.keys()][0] messages = [m["content"] for m in messages if m["role"] == "user"] print("messages: ", messages) except Exception as e: messages = [str(e)] queue.put(messages) def chatbot_reply(input_text): """Chat with the agent through terminal.""" queue = mp.Queue() process = mp.Process( target=initiate_chat, args=(config_list, input_text, queue), ) process.start() try: # process.join(TIMEOUT+2) messages = queue.get(timeout=TIMEOUT) except Exception as e: messages = [ str(e) if len(str(e)) > 0 else "Invalid Request to OpenAI, please check your API keys." ] finally: try: process.terminate() except: pass return messages def get_description_text(): return """ # Microsoft AutoGen: Retrieve Chat Demo This demo shows how to use the RetrieveUserProxyAgent and RetrieveAssistantAgent to build a chatbot. #### [GitHub](https://github.com/microsoft/autogen) [Discord](https://discord.gg/pAbnFJrkgZ) [Blog](https://microsoft.github.io/autogen/blog/2023/10/18/RetrieveChat) [Paper](https://arxiv.org/abs/2308.08155) """ global assistant, ragproxyagent with gr.Blocks() as demo: config_list, assistant, ragproxyagent = ( gr.State( [ { "api_key": "", "api_base": "", "api_type": "azure", "api_version": "2023-07-01-preview", "model": "gpt-35-turbo", } ] ), None, None, ) assistant, ragproxyagent = initialize_agents(config_list) gr.Markdown(get_description_text()) chatbot = gr.Chatbot( [], elem_id="chatbot", bubble_full_width=False, avatar_images=(None, (os.path.join(os.path.dirname(__file__), "autogen.png"))), # height=600, ) txt_input = gr.Textbox( scale=4, show_label=False, placeholder="Enter text and press enter", container=False, ) with gr.Row(): def update_config(config_list): global assistant, ragproxyagent config_list = autogen.config_list_from_models( model_list=[os.environ.get("MODEL", "gpt-35-turbo")], ) if not config_list: config_list = [ { "api_key": "", "api_base": "", "api_type": "azure", "api_version": "2023-07-01-preview", "model": "gpt-35-turbo", } ] llm_config = ( { "request_timeout": TIMEOUT, # "seed": 42, "config_list": config_list, }, ) assistant.llm_config.update(llm_config[0]) ragproxyagent._model = config_list[0]["model"] return config_list def set_params(model, oai_key, aoai_key, aoai_base): os.environ["MODEL"] = model os.environ["OPENAI_API_KEY"] = oai_key os.environ["AZURE_OPENAI_API_KEY"] = aoai_key os.environ["AZURE_OPENAI_API_BASE"] = aoai_base return model, oai_key, aoai_key, aoai_base txt_model = gr.Dropdown( label="Model", choices=[ "gpt-4", "gpt-35-turbo", "gpt-3.5-turbo", ], allow_custom_value=True, value="gpt-35-turbo", container=True, ) txt_oai_key = gr.Textbox( label="OpenAI API Key", placeholder="Enter key and press enter", max_lines=1, show_label=True, value=os.environ.get("OPENAI_API_KEY", ""), container=True, type="password", ) txt_aoai_key = gr.Textbox( label="Azure OpenAI API Key", placeholder="Enter key and press enter", max_lines=1, show_label=True, value=os.environ.get("AZURE_OPENAI_API_KEY", ""), container=True, type="password", ) txt_aoai_base_url = gr.Textbox( label="Azure OpenAI API Base", placeholder="Enter base url and press enter", max_lines=1, show_label=True, value=os.environ.get("AZURE_OPENAI_API_BASE", ""), container=True, type="password", ) clear = gr.ClearButton([txt_input, chatbot]) with gr.Row(): def upload_file(file): return update_context_url(file.name) upload_button = gr.UploadButton( "Click to upload a context file or enter a url in the right textbox", file_types=[f".{i}" for i in TEXT_FORMATS], file_count="single", ) txt_context_url = gr.Textbox( label="Enter the url to your context file and chat on the context", info=f"File must be in the format of [{', '.join(TEXT_FORMATS)}]", max_lines=1, show_label=True, value="https://raw.githubusercontent.com/microsoft/autogen/main/README.md", container=True, ) txt_prompt = gr.Textbox( label="Enter your prompt for Retrieve Agent and press enter to replace the default prompt", max_lines=40, show_label=True, value=PROMPT_CODE, container=True, show_copy_button=True, ) def respond(message, chat_history, model, oai_key, aoai_key, aoai_base): global config_list set_params(model, oai_key, aoai_key, aoai_base) config_list = update_config(config_list) messages = chatbot_reply(message) _msg = ( messages[-1] if len(messages) > 0 and messages[-1] != "TERMINATE" else messages[-2] if len(messages) > 1 else "Context is not enough for answering the question. Please press `enter` in the context url textbox to make sure the context is activated for the chat." ) chat_history.append((message, _msg)) return "", chat_history def update_prompt(prompt): ragproxyagent.customized_prompt = prompt return prompt def update_context_url(context_url): global assistant, ragproxyagent file_extension = Path(context_url).suffix print("file_extension: ", file_extension) if file_extension.lower() not in [f".{i}" for i in TEXT_FORMATS]: return f"File must be in the format of {TEXT_FORMATS}" if is_url(context_url): try: file_path = get_file_from_url( context_url, save_path=os.path.join("/tmp", os.path.basename(context_url)), ) except Exception as e: return str(e) else: file_path = context_url context_url = os.path.basename(context_url) try: chromadb.PersistentClient(path="/tmp/chromadb").delete_collection( name="autogen_rag" ) except: pass assistant, ragproxyagent = initialize_agents(config_list, docs_path=file_path) return context_url txt_input.submit( respond, [txt_input, chatbot, txt_model, txt_oai_key, txt_aoai_key, txt_aoai_base_url], [txt_input, chatbot], ) txt_prompt.submit(update_prompt, [txt_prompt], [txt_prompt]) txt_context_url.submit(update_context_url, [txt_context_url], [txt_context_url]) upload_button.upload(upload_file, upload_button, [txt_context_url]) if __name__ == "__main__": demo.launch(share=True, server_name="0.0.0.0")