|
import os |
|
|
|
|
|
# Load variables from a local .env file *before* installing the placeholder
# key below. python-dotenv's load_dotenv() does not override variables that
# already exist in the environment (override=False by default), so setting the
# placeholder first would silently discard an OPENAI_API_KEY defined in .env.
from dotenv import load_dotenv

load_dotenv()

# Install a placeholder key only if no real key came from the shell or .env.
# NOTE(review): presumably needed because the imports below expect the
# variable to exist at import time -- confirm. Users can still supply a real
# key through the sidebar input.
os.environ.setdefault('OPENAI_API_KEY', 'none')
|
|
|
import openai |
|
import pandas as pd |
|
import streamlit as st |
|
from IPython.core.display import HTML |
|
from PIL import Image |
|
from langchain.callbacks import wandb_tracing_enabled |
|
from chemcrow.agents import ChemCrow, make_tools |
|
from chemcrow.frontend.streamlit_callback_handler import \ |
|
StreamlitCallbackHandlerChem |
|
from utils import oai_key_isvalid |
|
|
|
from dotenv import load_dotenv |
|
|
|
load_dotenv()

# Short alias for Streamlit's per-session state store; reset the stored
# prompt at the start of every script run.
ss = st.session_state
ss['prompt'] = None

# Browser-tab title and favicon for the app.
icon = Image.open('assets/logo0.png')
st.set_page_config(page_title="ChemCrow", page_icon=icon)
|
|
|
|
|
# Fix the sidebar width at 450px while it is expanded. The <style> element is
# explicitly closed so the injected HTML is well-formed (the original left it
# open). The CSS is kept flush-left so no line can be read as an indented
# Markdown code block.
st.markdown(
    """
<style>
[data-testid="stSidebar"][aria-expanded="true"]{
    min-width: 450px;
    max-width: 450px;
}
</style>
""",
    unsafe_allow_html=True,
)
|
|
|
# Build the ChemCrow agent from the OpenAI key currently held in session state
# (ss.get returns None when the key has not been set) and the RXN4Chem key from
# Streamlit secrets.
chem_crow = ChemCrow(
    model='gpt-4',
    temp=0.1,
    openai_api_key=ss.get('api_key'),
    api_keys={'rxn4chem': st.secrets['RXN4CHEM_API_KEY']},
)
agent = chem_crow.agent_executor

# Tools exposed by the agent executor (summarised in the sidebar table).
tools = agent.tools
|
|
|
# Two-column table of the agent's tools. The dict keyed by display name is
# kept so tools sharing a name collapse to a single row (last description
# wins), exactly as the Series-based construction did.
_tool_rows = {f"✅ {tool.name}": tool.description for tool in tools}
tool_list = pd.DataFrame(
    list(_tool_rows.items()),
    columns=['Tool', 'Description'],
)
|
|
|
def on_api_key_change():
    """Validate the OpenAI key after the sidebar input changes.

    Uses the key typed into the sidebar; if that is empty, falls back to the
    ``OPENAI_API_KEY`` environment variable. When ``oai_key_isvalid`` rejects
    the candidate, a prompt asking for a valid key is written to the page.
    """
    candidate_key = ss.get('api_key') or os.getenv('OPENAI_API_KEY')
    if oai_key_isvalid(candidate_key):
        return
    st.write("Please input a valid OpenAI API key.")
|
|
|
|
|
def run_prompt(prompt):
    """Show *prompt* in the chat and answer it with the ChemCrow agent.

    The agent's intermediate steps are streamed into the assistant chat
    message through ``StreamlitCallbackHandlerChem``; the final response is
    written after them. OpenAI authentication and API errors are reported in
    the chat rather than propagated.

    Args:
        prompt: The user's question, passed verbatim to ``agent.run``.
    """
    st.chat_message("user").write(prompt)
    with st.chat_message("assistant"):
        st_callback = StreamlitCallbackHandlerChem(
            st.container(),
            max_thought_containers=2,
            collapse_completed_thoughts=False,
            output_placeholder=ss,
        )
        # Keep the try body to the call that can raise OpenAI errors.
        try:
            with wandb_tracing_enabled():
                response = agent.run(prompt, callbacks=[st_callback])
        except openai.error.AuthenticationError:
            st.write("Please input a valid OpenAI API key")
        except openai.error.APIError:
            # Report in the chat, not on the server console, so the user
            # actually sees that the request failed and can retry.
            st.write("OpenAI API error, please try again!")
        else:
            st.write(response)
|
|
|
|
|
# Example questions offered as one-click buttons in the sidebar; the buttons
# refer to these entries by index (0-3), so keep the order stable.
pre_prompts = [
    'How can I synthesize safinamide?',
    (
        'Predict the product of a mixture of Ethylidenecyclohexane and HBr. '
        'Then predict the same reaction, adding methyl peroxide into the '
        'mixture. Compare the two products and explain the reaction mechanism.'
    ),
    (
        'What is the boiling point of the reaction product between '
        'isoamyl alcohol and acetic acid?'
    ),
    (
        'Find 3 alkaloids in the cannabis plant, '
        'and calculate their similarity with caffeine'
    ),
]
|
|
|
|
|
with st.sidebar:
    # Branding at the top of the sidebar.
    chemcrow_logo = Image.open('assets/chemcrow-logo-bold-new.png')
    st.image(chemcrow_logo)

    # Password-style key input; on_api_key_change validates each new value.
    st.text_input(
        'Input your OpenAI API key.',
        placeholder='Input your OpenAI API key.',
        type='password',
        key='api_key',
        on_change=on_api_key_change,
        label_visibility="collapsed",
    )

    # Example-prompt buttons: two columns, two buttons each. Each entry pairs
    # a button label with the index of the prompt it runs from pre_prompts.
    st.markdown('# What can I ask?')
    cols = st.columns(2)
    example_layout = (
        (
            ("How can I synthesize safinamide?", 0),
            ("Explain mechanism of bromoaddition reaction", 1),
        ),
        (
            ('Predict properties of a reaction product', 2),
            ('Similarities between alkaloids', 3),
        ),
    )
    for column, column_buttons in zip(cols, example_layout):
        with column:
            for button_label, prompt_index in column_buttons:
                st.button(
                    button_label,
                    on_click=run_prompt,
                    args=(pre_prompts[prompt_index],),
                )

    st.markdown('---')

    # Table of the tools the agent can call.
    st.markdown(f"# {len(tool_list)} available tools")
    st.dataframe(
        tool_list,
        use_container_width=True,
        hide_index=True,
        height=200,
    )
|
|
|
|
|
# Free-form questions typed into the chat box go through the same path as the
# example buttons; chat_input yields nothing until the user submits.
user_input = st.chat_input()
if user_input:
    run_prompt(user_input)
|
|