"""Streamlit front end for the JobFair LLM employment-fairness benchmark.

The sidebar collects model/connection settings; once they are submitted the
main panel lets the user pick a CSV, fill in the experiment settings and run
``process_scores`` on the data.
"""

from io import StringIO

import pandas as pd
import streamlit as st

from generation import process_scores
from model import AzureAgent, GPTAgent

# Set up the Streamlit interface
st.title('JobFair: A Benchmark for Fairness in LLM Employment Decision')
st.sidebar.title('Model Settings')

# Session-state keys used by this page, mapped to their default values.
# A single mapping keeps each key next to its default instead of relying on
# two parallel lists staying aligned.
DEFAULT_STATE = {
    "model_submitted": False,
    "api_key": "",
    "endpoint_url": "",
    "deployment_name": "",
    "temperature": 0.5,
    "max_tokens": 150,
    "data_processed": False,
    "group_name": "",
    "occupation": "",
    "privilege_label": "",
    "protect_label": "",
    "num_run": 1,
    "uploaded_file": None,
}


def initialize_state(*, force=False):
    """Populate ``st.session_state`` with the page defaults.

    Args:
        force: When False (the default) only keys that are missing are
            created, so values from earlier script runs survive. When True
            every key is overwritten with its default, which is what a
            "reset" needs.
    """
    for key, default in DEFAULT_STATE.items():
        if force or key not in st.session_state:
            st.session_state[key] = default


def _rerun():
    """Restart the script run with whichever rerun API this Streamlit offers."""
    # Prefer the current name; fall back to the older experimental one so the
    # page keeps working on installations that predate ``st.rerun``.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


initialize_state()

# Model selection and configuration
model_type = st.sidebar.radio("Select the type of agent", ('GPTAgent', 'AzureAgent'))
st.session_state.api_key = st.sidebar.text_input(
    "API Key", type="password", value=st.session_state.api_key)
st.session_state.endpoint_url = st.sidebar.text_input(
    "Endpoint URL", value=st.session_state.endpoint_url)
st.session_state.deployment_name = st.sidebar.text_input(
    "Model Name", value=st.session_state.deployment_name)
# Only GPTAgent is constructed with an api_version further down; AzureAgent
# never receives this value.
api_version = '2024-02-15-preview' if model_type == 'GPTAgent' else ''
st.session_state.temperature = st.sidebar.slider(
    "Temperature", 0.0, 1.0, st.session_state.temperature, 0.01)
st.session_state.max_tokens = st.sidebar.number_input(
    "Max Tokens", 1, 1000, st.session_state.max_tokens)

if st.sidebar.button("Reset Model Info"):
    # force=True is required: every key already exists by this point, so a
    # plain initialize_state() would leave all values untouched.
    initialize_state(force=True)  # Reset all state to defaults
    _rerun()

if st.sidebar.button("Submit Model Info"):
    st.session_state.model_submitted = True

# Ensure experiment settings are only shown if model info is submitted
if st.session_state.model_submitted:
    df = None
    file_options = st.radio("Choose file source:", ["Upload", "Example"])
    if file_options == "Example":
        df = pd.read_csv("prompt_test.csv")
    else:
        st.session_state.uploaded_file = st.file_uploader("Choose a file")
        if st.session_state.uploaded_file is not None:
            data = StringIO(st.session_state.uploaded_file.getvalue().decode("utf-8"))
            df = pd.read_csv(data)

    if df is not None:
        st.write('Data:', df)

        st.session_state.occupation = st.text_input(
            "Occupation", value=st.session_state.occupation)
        st.session_state.group_name = st.text_input(
            "Group Name", value=st.session_state.group_name)
        st.session_state.privilege_label = st.text_input(
            "Privilege Label", value=st.session_state.privilege_label)
        st.session_state.protect_label = st.text_input(
            "Protect Label", value=st.session_state.protect_label)
        st.session_state.num_run = st.number_input(
            "Number of Runs", 1, 10, st.session_state.num_run)

        if st.button('Process Data'):
            if st.session_state.data_processed:
                # Previously this click was ignored without any feedback.
                st.warning(
                    'Data has already been processed. Use "Reset Experiment '
                    'Settings" to run again.')
            else:
                # Initialize the correct agent based on model type
                if model_type == 'AzureAgent':
                    agent = AzureAgent(
                        st.session_state.api_key,
                        st.session_state.endpoint_url,
                        st.session_state.deployment_name,
                    )
                else:
                    agent = GPTAgent(
                        st.session_state.api_key,
                        st.session_state.endpoint_url,
                        st.session_state.deployment_name,
                        api_version,
                    )

                # Process data and display results
                with st.spinner('Processing data...'):
                    parameters = {
                        "temperature": st.session_state.temperature,
                        "max_tokens": st.session_state.max_tokens,
                    }
                    df = process_scores(
                        df,
                        st.session_state.num_run,
                        parameters,
                        st.session_state.privilege_label,
                        st.session_state.protect_label,
                        agent,
                        st.session_state.group_name,
                    )
                    st.session_state.data_processed = True  # Mark as processed

                st.write('Processed Data:', df)

        if st.button("Reset Experiment Settings"):
            # occupation is reset together with the other experiment inputs.
            st.session_state.occupation = ""
            st.session_state.group_name = ""
            st.session_state.privilege_label = ""
            st.session_state.protect_label = ""
            st.session_state.num_run = 1
            st.session_state.data_processed = False
            # NOTE(review): this clears only the stored reference; the
            # uploader widget above reassigns it on the next run — confirm
            # whether the uploader itself should also be cleared.
            st.session_state.uploaded_file = None
            # The inputs above were already drawn with the old values during
            # this run, so rerun to rebuild them from the reset state.
            _rerun()