import streamlit as st
import pandas as pd
from io import StringIO
from generation import process_scores
from model import AzureAgent, GPTAgent

# Defaults for the model-configuration widgets; also the values restored by
# "Reset Model Info". Each key doubles as that widget's session-state key.
MODEL_DEFAULTS = {
    "api_key": "",
    "endpoint_url": "",
    "deployment_name": "",
    "temperature": 0.5,
    "max_tokens": 150,
}

# Defaults for the experiment-setting widgets; also the values restored by
# "Reset Experiment Settings". Each key doubles as that widget's key.
EXPERIMENT_DEFAULTS = {
    "group_name": "",
    "privilege_label": "",
    "protect_label": "",
    "num_run": 1,
}

# Non-widget state that must survive Streamlit's rerun-on-every-interaction.
APP_DEFAULTS = {
    "model_submitted": False,  # latched by "Submit Model Info"
    "data_processed": False,  # True once process_scores ran for this upload
    "processed_df": None,  # result of the last process_scores call
    "upload_signature": None,  # identifies the upload processed_df came from
}


def initialize_state():
    """Seed ``st.session_state`` with a default for every key not yet set.

    Existing values are left untouched, so this is safe to call on every
    rerun of the script.
    """
    for defaults in (MODEL_DEFAULTS, EXPERIMENT_DEFAULTS, APP_DEFAULTS):
        for key, value in defaults.items():
            if key not in st.session_state:
                st.session_state[key] = value


def _reset_model_info():
    """Restore the model widgets to their defaults and hide the experiment UI.

    Used as an ``on_click`` callback: Streamlit only permits writing a
    widget's session-state key before that widget is created in the current
    run, which holds for callbacks (they execute before the script reruns).
    """
    for key, value in MODEL_DEFAULTS.items():
        st.session_state[key] = value
    st.session_state.model_submitted = False


def _reset_experiment_settings():
    """Restore the experiment widgets to their defaults and drop any result.

    Used as an ``on_click`` callback for the same reason as
    ``_reset_model_info``.
    """
    for key, value in EXPERIMENT_DEFAULTS.items():
        st.session_state[key] = value
    st.session_state.data_processed = False
    st.session_state.processed_df = None


initialize_state()

# Set up the Streamlit interface
st.title('JobFair: A Benchmark for Fairness in LLM Employment Decision')
st.sidebar.title('Model Settings')

# Model selection and configuration. Binding each widget to a session-state
# key (rather than passing value=) is what lets the reset callback clear it.
model_type = st.sidebar.radio("Select the type of agent", ('GPTAgent', 'AzureAgent'))
api_key = st.sidebar.text_input("API Key", type="password", key="api_key")
endpoint_url = st.sidebar.text_input("Endpoint URL", key="endpoint_url")
deployment_name = st.sidebar.text_input("Model Name", key="deployment_name")
# Only GPTAgent is constructed with an API version (see agent creation below).
api_version = '2024-02-15-preview' if model_type == 'GPTAgent' else ''
temperature = st.sidebar.slider("Temperature", 0.0, 1.0, step=0.01, key="temperature")
max_tokens = st.sidebar.number_input("Max Tokens", 1, 1000, key="max_tokens")

# Reset button for model information
st.sidebar.button("Reset Model Info", on_click=_reset_model_info)

# st.button returns True only on the one rerun right after the click, so the
# submission is latched in session state. Otherwise any later interaction
# (typing a label, uploading a file, clicking "Process Data") would rerun the
# script with the button False and hide the whole experiment section again.
if st.sidebar.button("Submit Model Info"):
    st.session_state.model_submitted = True

# Data upload and processing with reset option
if st.session_state.model_submitted:
    parameters = {"temperature": temperature, "max_tokens": max_tokens}
    group_name = st.text_input("Group Name", key="group_name")
    privilege_label = st.text_input("Privilege Name", key="privilege_label")
    protect_label = st.text_input("Protect Name", key="protect_label")
    num_run = st.number_input("Number of runs", min_value=1, key="num_run")
    uploaded_file = st.file_uploader("Choose a file")

    # Reset button for experiment settings
    st.button("Reset Experiment Settings", on_click=_reset_experiment_settings)

    if uploaded_file is not None:
        # A different upload invalidates the previous result, as the warning
        # below promises ("... or re-upload to process new data").
        signature = getattr(uploaded_file, "file_id", None) or (
            uploaded_file.name,
            uploaded_file.size,
        )
        if signature != st.session_state.upload_signature:
            st.session_state.upload_signature = signature
            st.session_state.data_processed = False
            st.session_state.processed_df = None

        # "utf-8-sig" decodes plain UTF-8 identically but also strips a
        # leading byte-order mark, which would otherwise end up in the first
        # column's name.
        data = StringIO(uploaded_file.getvalue().decode("utf-8-sig"))
        df = pd.read_csv(data)

        process_button = st.button('Process Data')
        if process_button and not st.session_state.data_processed:
            # Initialize the correct agent based on model type
            if model_type == 'AzureAgent':
                agent = AzureAgent(api_key, endpoint_url, deployment_name)
            else:
                agent = GPTAgent(api_key, endpoint_url, deployment_name, api_version)

            # Process data and keep the result across reruns
            with st.spinner('Processing data...'):
                df = process_scores(df, num_run, parameters, privilege_label, protect_label, agent, group_name)
            st.session_state.data_processed = True  # Mark as processed
            st.session_state.processed_df = df
        elif process_button and st.session_state.data_processed:
            st.warning("Data already processed for this session. Reset or re-upload to process new data.")

        # Show the stored result on every rerun, not only on the click rerun.
        if st.session_state.processed_df is not None:
            st.write('Processed Data:', st.session_state.processed_df)