convosim-ui / hidden_pages /manual_comparisor.py
ivnban27-ctl's picture
feat/MVP_GCT_SP (#2)
9ff00d4 verified
raw
history blame
4.02 kB
import os
import random
import datetime as dt
import streamlit as st
from streamlit.logger import get_logger
from langchain.schema.messages import HumanMessage
from utils.mongo_utils import get_db_client, new_battle_result, get_non_assesed_comparison, new_completion_error
from app_config import ISSUES, SOURCES
logger = get_logger(__name__)

# Read eagerly: a missing OPENAI_API_KEY raises KeyError as soon as the page
# script executes.
openai_api_key = os.environ["OPENAI_API_KEY"]

# Only build the Mongo client when this session does not already hold one,
# so reruns of the page reuse the client kept in st.session_state.
if "db_client" not in st.session_state:
    st.session_state["db_client"] = get_db_client()
def disable_buttons():
    """Return True when there is no pending comparison to assess.

    Reads the module-level ``comparison`` sequence; the sidebar buttons use
    this value for their ``disabled`` flag.
    """
    return not len(comparison)
def _record_battle(winner):
    """Persist the current user's verdict for the comparison on screen.

    Uses the comparison/conversation ids that the render step stores in
    ``st.session_state`` together with the module-level ``username``,
    ``sourceA`` and ``sourceB`` values.

    Args:
        winner: outcome label forwarded to ``new_battle_result``; this page
            uses 'model_one', 'model_two', 'both_bad' and 'tie'.
    """
    # NOTE(review): Simulator A is randomly either DB model_one or model_two,
    # yet "A is better" always sends winner='model_one'. That is only correct
    # if new_battle_result interprets the label positionally (model_one ==
    # sourceA, model_two == sourceB) -- confirm in utils.mongo_utils.
    new_battle_result(
        st.session_state['db_client'],
        st.session_state['comparison_id'],
        st.session_state['convo_id'],
        username, sourceA, sourceB,
        winner=winner,
    )


def replaceA():
    """Callback for the "B is better" button: record 'model_two' as winner."""
    _record_battle('model_two')


def replaceB():
    """Callback for the "A is better" button: record 'model_one' as winner."""
    _record_battle('model_one')


def regenerateBoth():
    """Callback for the "Both are bad" button: record 'both_bad'."""
    _record_battle('both_bad')


def bothGood():
    """Callback for the "Tie" button: record 'tie'."""
    _record_battle('tie')
def error2db(model):
    """Log and persist a completion error the user flagged for ``model``.

    Args:
        model: identifier of the flagged model (``sourceA`` or ``sourceB``).
    """
    logger.info(f"error logged for {model}")
    db = st.session_state['db_client']
    comparison_id = st.session_state['comparison_id']
    new_completion_error(db, comparison_id, username, model)


def error2dbA():
    """Callback for "Error in A": report the model shown as Simulator A."""
    return error2db(sourceA)


def error2dbB():
    """Callback for "Error in B": report the model shown as Simulator B."""
    return error2db(sourceB)
with st.sidebar:
    username = st.text_input(
        "Username",
        value='ivnban-ctl',
        max_chars=30,
    )

# Fetch what this user still has to assess; the page only checks its length
# and reads its first element.
comparison = get_non_assesed_comparison(
    st.session_state["db_client"],
    username,
)
with st.sidebar:
    # Every vote/report button shares the same enabled state, so evaluate it once.
    buttons_disabled = disable_buttons()
    sbcol1, sbcol2 = st.columns(2)
    # Widgets are created in the original order; left column gets the A-side
    # actions, right column the B-side ones.
    beta = sbcol1.button("A is better", on_click=replaceB, disabled=buttons_disabled)
    betb = sbcol2.button("B is better", on_click=replaceA, disabled=buttons_disabled)
    same = sbcol1.button("Tie", on_click=bothGood, disabled=buttons_disabled)
    bbad = sbcol2.button("Both are bad", on_click=regenerateBoth, disabled=buttons_disabled)
    errorA = sbcol1.button("Error in A", on_click=error2dbA, disabled=buttons_disabled)
    errorB = sbcol2.button("Error in B", on_click=error2dbB, disabled=buttons_disabled)
if len(comparison) > 0:
    current = comparison[0]
    convo_info = current['convo_info'][0]

    # Remember which record is on screen; the sidebar callbacks read these ids.
    st.session_state['comparison_id'] = current["_id"]
    st.session_state['convo_id'] = current["convo_id"]
    st.session_state["disabled_buttons"] = False
    st.sidebar.text_input("Issue", value=convo_info['issue'], disabled=True)

    st.title("💬 History")
    # chat_history is newline-separated "speaker:text" lines; 'helper' turns
    # render as the user, every other speaker as the assistant.
    for line in current['chat_history'].split("\n"):
        # Split on the FIRST colon only so message text containing ':' is
        # shown in full rather than truncated at its first colon.
        speaker, sep, text = line.partition(":")
        if sep:  # lines without any colon are skipped
            role = "user" if speaker == 'helper' else "assistant"
            st.chat_message(role).write(text)

    col1, col2 = st.columns(2)
    col1.title("💬 Simulator A")
    col2.title("💬 Simulator B")

    # Randomly decide which stored model is shown on the left, presumably to
    # blind the assessor. sourceA/sourceB stay module-level because the
    # sidebar callbacks read them.
    selectedA = random.choice(['model_one', 'model_two'])
    selectedB = "model_two" if selectedA == "model_one" else "model_one"
    sourceA = convo_info[selectedA]
    sourceB = convo_info[selectedB]
    logger.info(f"selected A is {sourceA} and B is {sourceB}")

    col1.chat_message("user").write(current["prompt"])
    col2.chat_message("user").write(current["prompt"])
    # NOTE: "compeltion_" presumably matches the stored document keys; do not
    # correct the spelling here without migrating the data.
    col1.chat_message("assistant").write(current[f"compeltion_{selectedA}"])
    col2.chat_message("assistant").write(current[f"compeltion_{selectedB}"])
else:
    st.write("No Comparisons left to Check")