job-fair / pages /1_Generation_Demo.py
Zekun Wu
update
225f4f1
raw
history blame
No virus
6.22 kB
import streamlit as st
import pandas as pd
from io import StringIO
from util.generation import process_scores
from util.model import AzureAgent, GPTAgent
from util.analysis import statistical_tests, result_evaluation
# Set up the Streamlit interface: main page heading plus the sidebar heading.
st.title('JobFair: A Benchmark for Fairness in LLM Employment Decision')
with st.sidebar:
    st.title('Model Settings')
# Define a function to manage state initialization
def initialize_state():
    """Seed ``st.session_state`` with a default for every key not yet present.

    Keys that already hold a value are left alone, so calling this on every
    Streamlit rerun preserves whatever the user has entered so far.
    """
    state_defaults = {
        "model_submitted": False,
        "api_key": "",
        "endpoint_url": "https://safeguard-monitor.openai.azure.com/",
        "deployment_name": "gpt35-1106",
        "temperature": 0.5,
        "max_tokens": 150,
        "data_processed": False,
        "group_name": "Gender",
        "occupation": "Programmer",
        "privilege_label": "Male",
        "protect_label": "Female",
        "num_run": 1,
        "uploaded_file": None,
    }
    for state_key, fallback in state_defaults.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = fallback


initialize_state()
# Model selection and configuration
model_type = st.sidebar.radio("Select the type of agent", ('GPTAgent', 'AzureAgent'))
st.session_state.api_key = st.sidebar.text_input("API Key", type="password", value=st.session_state.api_key)
st.session_state.endpoint_url = st.sidebar.text_input("Endpoint URL", value=st.session_state.endpoint_url)
st.session_state.deployment_name = st.sidebar.text_input("Model Name", value=st.session_state.deployment_name)
# api_version is non-empty only when GPTAgent is selected.
api_version = '2024-02-15-preview' if model_type == 'GPTAgent' else ''
st.session_state.temperature = st.sidebar.slider("Temperature", 0.0, 1.0, st.session_state.temperature, 0.01)
st.session_state.max_tokens = st.sidebar.number_input("Max Tokens", 1, 1000, st.session_state.max_tokens)

if st.sidebar.button("Reset Model Info"):
    # initialize_state() only fills keys that are *missing*, so calling it on
    # its own would keep the user's current values. Remove the model-settings
    # keys first, then let initialize_state() re-seed them with its defaults.
    for state_key in ("model_submitted", "api_key", "endpoint_url",
                      "deployment_name", "temperature", "max_tokens"):
        if state_key in st.session_state:
            del st.session_state[state_key]
    initialize_state()
    # Rerun so the sidebar widgets are rebuilt from the reset values.
    st.experimental_rerun()

if st.sidebar.button("Submit Model Info"):
    st.session_state.model_submitted = True
def add_row(df):
    """Return a new DataFrame equal to *df* with one blank row appended.

    Every cell of the appended row is the empty string, and the result is
    given a fresh 0..n index (``ignore_index=True``). *df* itself is not
    modified. NOTE: appending strings can upcast numeric columns to object.
    """
    blank_record = dict.fromkeys(df.columns, "")
    appended = pd.concat([df, pd.DataFrame([blank_record])], ignore_index=True)
    return appended
def remove_row(df, index):
    """Return *df* without the row labelled *index*, renumbered from 0.

    *index* is matched against row labels (not positions). A label that is
    not present is ignored (``errors="ignore"``), so the rows come back
    unchanged apart from the renumbering. *df* itself is not modified.
    """
    remaining = df.drop(index=index, errors="ignore")
    return remaining.reset_index(drop=True)
# Ensure experiment settings are only shown if model info is submitted
if st.session_state.model_submitted:
    df = None
    file_options = st.radio("Choose file source:", ["Upload", "Example"])

    # Identify the active data source. The parsed table is cached in session
    # state ("working_df") so that Add Row / Remove Row edits survive
    # Streamlit's rerun-on-every-interaction model. It is re-read only when
    # the source changes.
    source_id = None
    raw_upload = None
    if file_options == "Example":
        source_id = "example"
    else:
        st.session_state.uploaded_file = st.file_uploader("Choose a file")
        if st.session_state.uploaded_file is not None:
            raw_upload = st.session_state.uploaded_file.getvalue()
            # Name plus content hash, so uploading a different file reloads it.
            source_id = ("upload", st.session_state.uploaded_file.name, hash(raw_upload))

    if source_id is not None:
        if st.session_state.get("working_source") != source_id:
            if raw_upload is None:
                loaded_df = pd.read_csv("prompt_test.csv")
            else:
                loaded_df = pd.read_csv(StringIO(raw_upload.decode("utf-8")))
            st.session_state.working_df = loaded_df
            st.session_state.working_source = source_id
        df = st.session_state.working_df

    if df is not None:
        st.write('Data:', df)

        # Button to add a new blank row; rerun so the table above shows it.
        if st.button('Add Row'):
            st.session_state.working_df = add_row(df)
            st.experimental_rerun()

        # Input for row index to remove. st.number_input rejects
        # max_value < min_value, so removal is only offered when rows exist.
        if len(df) > 0:
            row_to_remove = st.number_input('Enter row index to remove', min_value=0, max_value=len(df) - 1, step=1,
                                            format='%d')
            if st.button('Remove Row'):
                st.session_state.working_df = remove_row(df, row_to_remove)
                st.experimental_rerun()

        st.session_state.occupation = st.text_input("Occupation", value=st.session_state.occupation)
        st.session_state.group_name = st.text_input("Group Name", value=st.session_state.group_name)
        st.session_state.privilege_label = st.text_input("Privilege Label", value=st.session_state.privilege_label)
        st.session_state.protect_label = st.text_input("Protect Label", value=st.session_state.protect_label)
        st.session_state.num_run = st.number_input("Number of Runs", 1, 10, st.session_state.num_run)

        if st.button('Process Data'):
            if st.session_state.data_processed:
                # Processing runs once per experiment; the reset button below
                # clears the flag. Tell the user instead of ignoring the click.
                st.info("Data has already been processed. Use 'Reset Experiment Settings' to run again.")
            else:
                # Initialize the correct agent based on model type
                if model_type == 'AzureAgent':
                    agent = AzureAgent(st.session_state.api_key, st.session_state.endpoint_url,
                                       st.session_state.deployment_name)
                else:
                    agent = GPTAgent(st.session_state.api_key, st.session_state.endpoint_url,
                                     st.session_state.deployment_name, api_version)

                # Process data and display results
                with st.spinner('Processing data...'):
                    parameters = {"temperature": st.session_state.temperature,
                                  "max_tokens": st.session_state.max_tokens}
                    # Pass a copy so neither process_scores nor the rank columns
                    # added below can mutate the cached working table.
                    df = process_scores(df.copy(), st.session_state.num_run, parameters,
                                        st.session_state.privilege_label, st.session_state.protect_label,
                                        agent, st.session_state.group_name, st.session_state.occupation)
                    st.session_state.data_processed = True  # Mark as processed

                # Rank the three average scores within each row (ascending=False:
                # the highest score gets rank 1).
                ranks = df[['Privilege_Avg_Score', 'Protect_Avg_Score', 'Neutral_Avg_Score']].rank(
                    axis=1, ascending=False)
                df['Privilege_Rank'] = ranks['Privilege_Avg_Score']
                df['Protect_Rank'] = ranks['Protect_Avg_Score']
                df['Neutral_Rank'] = ranks['Neutral_Avg_Score']
                st.write('Processed Data:', df)

                # Run the statistical tests on the processed scores and show the evaluation.
                st.write("Plotting the data")
                test_results = statistical_tests(df)
                print(test_results)
                evaluation_results = result_evaluation(test_results)
                print(evaluation_results)
                for key, value in evaluation_results.items():
                    st.write(f"{key}: {value}")

    if st.button("Reset Experiment Settings"):
        st.session_state.occupation = "Programmer"
        st.session_state.group_name = "Gender"
        st.session_state.privilege_label = "Male"
        st.session_state.protect_label = "Female"
        st.session_state.num_run = 1
        st.session_state.data_processed = False
        st.session_state.uploaded_file = None
        # Discard cached row edits so the table is re-read from its source.
        for cache_key in ("working_df", "working_source"):
            if cache_key in st.session_state:
                del st.session_state[cache_key]
        # Rerun so the widgets above pick up the reset values immediately.
        st.experimental_rerun()