|
import streamlit as st |
|
import asyncio |
|
from autogen import AssistantAgent, UserProxyAgent |
|
|
|
|
|
# Page title; st.write renders the triple-quoted string as markdown, so the
# leading ``#`` produces a top-level heading.
st.write("""# AutoGen Chat Agents""")
|
|
|
class TrackableAssistantAgent(AssistantAgent):
    """AssistantAgent that mirrors every message it receives into the Streamlit chat UI."""

    def _process_received_message(self, message, sender, silent):
        # Render the incoming message in a chat bubble labelled with the
        # sender's agent name, then hand off to autogen's normal handling.
        bubble = st.chat_message(sender.name)
        with bubble:
            st.markdown(message)
        return super()._process_received_message(message, sender, silent)
|
|
|
|
|
class TrackableUserProxyAgent(UserProxyAgent):
    """UserProxyAgent that mirrors every message it receives into the Streamlit chat UI."""

    def _process_received_message(self, message, sender, silent):
        # Show the message in the chat pane under the sender's name before
        # delegating to the standard autogen processing.
        bubble = st.chat_message(sender.name)
        with bubble:
            st.markdown(message)
        return super()._process_received_message(message, sender, silent)
|
|
|
|
|
# --- Sidebar: OpenAI model selection and API key entry -----------------------
# A ``with`` block does not introduce a new scope, so the assignments below
# bind the module-level names directly; the former ``selected_model = None`` /
# ``selected_key = None`` pre-initializations were redundant and are removed.
with st.sidebar:
    st.header("OpenAI Configuration")
    # index=1 pre-selects 'gpt-4' as the default choice.
    selected_model = st.selectbox("Model", ['gpt-3.5-turbo', 'gpt-4'], index=1)
    # type="password" masks the key in the input widget.
    selected_key = st.text_input("API Key", type="password")
|
|
|
with st.container():
    # Main chat pane: accept one user prompt per rerun and drive a two-agent
    # AutoGen conversation, with replies echoed into the UI by the trackable
    # agent subclasses defined above.
    user_input = st.chat_input("Type something...")
    if user_input:
        # Guard clause: both a model and an API key must be configured.
        if not selected_key or not selected_model:
            st.warning(
                'You must provide valid OpenAI API key and choose preferred model', icon="⚠️")
            st.stop()

        # LLM configuration shared by both agents.
        # NOTE(review): ``request_timeout`` is the pre-0.2 autogen config key;
        # newer releases renamed it to ``timeout`` — confirm against the
        # pinned autogen version before upgrading.
        llm_config = {
            "request_timeout": 600,
            "config_list": [
                {
                    "model": selected_model,
                    "api_key": selected_key
                }
            ]
        }

        # Agents use the trackable subclasses so every exchanged message is
        # rendered in the chat pane.
        assistant = TrackableAssistantAgent(
            name="assistant", llm_config=llm_config)

        user_proxy = TrackableUserProxyAgent(
            name="user", human_input_mode="NEVER", llm_config=llm_config)

        # Streamlit executes the script in a worker thread that has no event
        # loop, so one must be created explicitly for the async chat call.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        async def initiate_chat():
            await user_proxy.a_initiate_chat(
                assistant,
                message=user_input,
            )

        try:
            loop.run_until_complete(initiate_chat())
        finally:
            # Bug fix: the original created a fresh event loop on every
            # Streamlit rerun and never closed it, leaking a loop (and its
            # selector file descriptor) per interaction. Close and detach it.
            loop.close()
            asyncio.set_event_loop(None)
|
|