File size: 4,233 Bytes
0b5c5aa
 
 
839ca71
 
117a821
245d4fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117a821
245d4fa
117a821
245d4fa
 
35b059b
117a821
245d4fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import streamlit as st
import pandas as pd
from io import StringIO
from generation import process_scores
from model import AzureAgent, GPTAgent

# Initialize session state variables if they don't already exist
def initialize_state():
    """Seed ``st.session_state`` with a default for every app setting.

    Only keys that are missing are written, so values already present in the
    session (e.g. from an earlier run of the script) are preserved.
    """
    defaults = (
        ('data_processed', False),
        ('api_key', ""),
        ('endpoint_url', ""),
        ('deployment_name', ""),
        ('temperature', 0.5),
        ('max_tokens', 150),
        ('group_name', ""),
        ('privilege_label', ""),
        ('protect_label', ""),
        ('num_run', 1),
    )
    for setting, default in defaults:
        if setting in st.session_state:
            continue
        setattr(st.session_state, setting, default)

initialize_state()


def _reset_model_info():
    """Restore the sidebar model settings to their defaults.

    Registered as the reset button's ``on_click`` callback. Callbacks run before
    the next script run re-creates the widgets, which is the only point at which
    Streamlit allows a widget-bound session-state key to be overwritten.
    Keep these defaults in sync with ``initialize_state()``.
    """
    st.session_state.api_key = ""
    st.session_state.endpoint_url = ""
    st.session_state.deployment_name = ""
    st.session_state.temperature = 0.5
    st.session_state.max_tokens = 150


# Set up the Streamlit interface
st.title('JobFair: A Benchmark for Fairness in LLM Employment Decision')
st.sidebar.title('Model Settings')

# Model selection and configuration.
# Each widget is bound to its session-state entry via ``key``, so what the user
# types is stored in session state and the reset callback clears what is shown.
# Seeding the widgets with ``value=st.session_state.X`` never wrote edits back,
# which made "Reset Model Info" a no-op.
model_type = st.sidebar.radio("Select the type of agent", ('GPTAgent', 'AzureAgent'))
api_key = st.sidebar.text_input("API Key", type="password", key="api_key")
endpoint_url = st.sidebar.text_input("Endpoint URL", key="endpoint_url")
deployment_name = st.sidebar.text_input("Model Name", key="deployment_name")
# Only the GPTAgent constructor is passed an API version; AzureAgent gets none.
api_version = '2024-02-15-preview' if model_type == 'GPTAgent' else ''
temperature = st.sidebar.slider("Temperature", 0.0, 1.0, step=0.01, key="temperature")
max_tokens = st.sidebar.number_input("Max Tokens", 1, 1000, key="max_tokens")

# Reset button for model information. The click itself triggers a rerun after
# the callback, so no explicit (deprecated) experimental_rerun() is needed.
st.sidebar.button("Reset Model Info", on_click=_reset_model_info)

submit_model_info = st.sidebar.button("Submit Model Info")

# Data upload and processing with reset option


def _reset_experiment_settings():
    """Clear the experiment inputs and any processed result.

    Registered as an ``on_click`` callback. Callbacks run before the widgets are
    re-created on the next script run, which is when Streamlit permits
    overwriting widget-bound session-state keys.
    """
    st.session_state.group_name = ""
    st.session_state.privilege_label = ""
    st.session_state.protect_label = ""
    st.session_state.num_run = 1
    st.session_state.data_processed = False
    st.session_state.pop('processed_df', None)


# ``st.button`` returns True only on the single run that follows the click.
# Nesting the experiment section directly under that value made it vanish on
# any interaction with it (typing a label, uploading a file, pressing
# "Process Data"), because each interaction reruns the script with the button
# back to False. Persist the submission in session state instead.
if submit_model_info:
    st.session_state.model_submitted = True

if st.session_state.get('model_submitted', False):
    # Generation parameters come from the sidebar's current values.
    parameters = {"temperature": temperature, "max_tokens": max_tokens}

    # Bound to session state via ``key`` so the values persist across reruns
    # and can be cleared by the reset callback.
    group_name = st.text_input("Group Name", key="group_name")
    privilege_label = st.text_input("Privilege Name", key="privilege_label")
    protect_label = st.text_input("Protect Name", key="protect_label")
    num_run = st.number_input("Number of runs", min_value=1, key="num_run")
    uploaded_file = st.file_uploader("Choose a file")

    # Reset button for experiment settings
    st.button("Reset Experiment Settings", on_click=_reset_experiment_settings)

    if uploaded_file is None:
        # Forget the previous file so that uploading it again counts as new data.
        st.session_state.pop('uploaded_file_signature', None)
    else:
        # A newly uploaded file means new data: allow it to be processed, as the
        # "re-upload to process new data" warning below promises.
        file_signature = (uploaded_file.name, uploaded_file.size)
        if st.session_state.get('uploaded_file_signature') != file_signature:
            st.session_state.uploaded_file_signature = file_signature
            st.session_state.data_processed = False
            st.session_state.pop('processed_df', None)

        df = None
        try:
            # "utf-8-sig" decodes plain UTF-8 identically. It also drops a
            # leading byte-order mark, which would otherwise end up in the
            # first column's name.
            data = StringIO(uploaded_file.getvalue().decode("utf-8-sig"))
            df = pd.read_csv(data)
        except (UnicodeDecodeError, pd.errors.EmptyDataError, pd.errors.ParserError) as exc:
            st.error(f"Could not read the uploaded file as a UTF-8 CSV: {exc}")

        if df is not None:
            process_button = st.button('Process Data')

            if process_button and not st.session_state.data_processed:
                # Initialize the correct agent based on model type
                if model_type == 'AzureAgent':
                    agent = AzureAgent(api_key, endpoint_url, deployment_name)
                else:
                    agent = GPTAgent(api_key, endpoint_url, deployment_name, api_version)

                with st.spinner('Processing data...'):
                    processed = process_scores(df, num_run, parameters, privilege_label,
                                               protect_label, agent, group_name)
                # Keep the result so it is still displayed after later reruns.
                st.session_state.processed_df = processed
                st.session_state.data_processed = True  # Mark as processed
            elif process_button and st.session_state.data_processed:
                st.warning("Data already processed for this session. Reset or re-upload to process new data.")

            if 'processed_df' in st.session_state:
                st.write('Processed Data:', st.session_state.processed_df)