File size: 5,372 Bytes
0b5c5aa
 
 
839ca71
 
5defafa
117a821
245d4fa
 
35b059b
117a821
39654c5
 
 
a7883dd
88009b8
39654c5
 
 
 
 
 
245d4fa
 
39654c5
 
 
245d4fa
39654c5
 
245d4fa
 
39654c5
245d4fa
 
39654c5
 
245d4fa
cc7e22e
 
245d4fa
cbbb1a3
cc7e22e
 
 
 
 
 
 
 
 
 
 
4bf4df2
 
 
a7883dd
 
 
 
 
 
4bf4df2
cc7e22e
 
 
 
 
245d4fa
cc7e22e
 
 
bedb44d
cc7e22e
5defafa
 
 
2d29d72
 
5defafa
 
 
 
 
2d29d72
 
 
 
 
 
 
 
5defafa
3e77c82
5defafa
3e77c82
5defafa
 
 
 
 
cc7e22e
29635b7
 
 
 
cc7e22e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import streamlit as st
import pandas as pd
from io import StringIO
from generation import process_scores
from model import AzureAgent, GPTAgent
from analysis import statistical_tests, result_evaluation

# Streamlit page set-up: the main heading of the app, followed by the heading
# shown at the top of the sidebar for the model settings.
_APP_TITLE = 'JobFair: A Benchmark for Fairness in LLM Employment Decision'
st.title(_APP_TITLE)
st.sidebar.title('Model Settings')

# Define a function to manage state initialization
def initialize_state():
    keys = ["model_submitted", "api_key", "endpoint_url", "deployment_name", "temperature", "max_tokens",
            "data_processed", "group_name","occupation", "privilege_label", "protect_label", "num_run", "uploaded_file"]
    defaults = [False, "", "https://safeguard-monitor.openai.azure.com/", "gpt35-1106", 0.5, 150, False,"Gender", "Programmer", "Male", "Female", 1, None]
    for key, default in zip(keys, defaults):
        if key not in st.session_state:
            st.session_state[key] = default

initialize_state()

# Model selection and configuration
model_type = st.sidebar.radio("Select the type of agent", ('GPTAgent', 'AzureAgent'))
st.session_state.api_key = st.sidebar.text_input("API Key", type="password", value=st.session_state.api_key)
st.session_state.endpoint_url = st.sidebar.text_input("Endpoint URL", value=st.session_state.endpoint_url)
st.session_state.deployment_name = st.sidebar.text_input("Model Name", value=st.session_state.deployment_name)
# Only meaningful for the GPTAgent type; the AzureAgent type gets an empty string.
api_version = '2024-02-15-preview' if model_type == 'GPTAgent' else ''
st.session_state.temperature = st.sidebar.slider("Temperature", 0.0, 1.0, st.session_state.temperature, 0.01)
st.session_state.max_tokens = st.sidebar.number_input("Max Tokens", 1, 1000, st.session_state.max_tokens)

if st.sidebar.button("Reset Model Info"):
    # initialize_state() only fills in *missing* keys, so calling it on its own
    # would leave every value unchanged. Drop all stored values first so that
    # it restores the defaults.
    for key in list(st.session_state.keys()):
        del st.session_state[key]
    initialize_state()  # Reset all state to defaults
    # Newer Streamlit releases replace st.experimental_rerun with st.rerun.
    (getattr(st, "rerun", None) or st.experimental_rerun)()

if st.sidebar.button("Submit Model Info"):
    st.session_state.model_submitted = True




# Ensure experiment settings are only shown if model info is submitted
if st.session_state.model_submitted:
    df = None
    file_options = st.radio("Choose file source:", ["Upload", "Example"])
    if file_options == "Example":
        df = pd.read_csv("prompt_test.csv")
    else:
        st.session_state.uploaded_file = st.file_uploader("Choose a file")
        if st.session_state.uploaded_file is not None:
            # "utf-8-sig" also strips a leading byte-order mark (common in CSVs
            # exported from Excel). With plain "utf-8" the BOM would end up
            # prefixed to the first column name.
            data = StringIO(st.session_state.uploaded_file.getvalue().decode("utf-8-sig"))
            df = pd.read_csv(data)
    if df is not None:

        st.write('Data:', df)

        st.session_state.occupation = st.text_input("Occupation", value=st.session_state.occupation)
        st.session_state.group_name = st.text_input("Group Name", value=st.session_state.group_name)
        st.session_state.privilege_label = st.text_input("Privilege Label", value=st.session_state.privilege_label)
        st.session_state.protect_label = st.text_input("Protect Label", value=st.session_state.protect_label)
        st.session_state.num_run = st.number_input("Number of Runs", 1, 10, st.session_state.num_run)

        process_clicked = st.button('Process Data')
        if process_clicked and st.session_state.data_processed:
            # Processing runs at most once per experiment. Previously a second
            # click did nothing visible, so tell the user how to run again.
            st.info("Data has already been processed. Click 'Reset Experiment Settings' to run it again.")
        elif process_clicked:
            # Initialize the correct agent based on model type
            if model_type == 'AzureAgent':
                agent = AzureAgent(st.session_state.api_key, st.session_state.endpoint_url, st.session_state.deployment_name)
            else:
                agent = GPTAgent(st.session_state.api_key, st.session_state.endpoint_url, st.session_state.deployment_name, api_version)

            # Process data and display results
            with st.spinner('Processing data...'):
                parameters = {"temperature": st.session_state.temperature, "max_tokens": st.session_state.max_tokens}
                df = process_scores(df, st.session_state.num_run, parameters, st.session_state.privilege_label, st.session_state.protect_label, agent, st.session_state.group_name, st.session_state.occupation)
                st.session_state.data_processed = True  # Mark as processed

            # Rank the three average scores against each other within each row.
            # This uses pandas defaults: ascending, so the lowest score gets rank 1,
            # and ties share their average rank.
            # NOTE(review): assumes process_scores adds these three columns.
            score_columns = ['Privilege_Avg_Score', 'Protect_Avg_Score', 'Neutral_Avg_Score']
            row_ranks = df[score_columns].rank(axis=1)  # computed once, reused below
            df['Privilege_Rank'] = row_ranks['Privilege_Avg_Score']
            df['Protect_Rank'] = row_ranks['Protect_Avg_Score']
            df['Neutral_Rank'] = row_ranks['Neutral_Avg_Score']

            st.write('Processed Data:', df)

            # NOTE(review): no plot is rendered yet; this message is a placeholder.
            st.write("Plotting the data")

            test_results = statistical_tests(df)
            print(test_results)
            evaluation_results = result_evaluation(test_results)
            print(evaluation_results)

            for key, value in evaluation_results.items():
                st.write(f"{key}: {value}")

        if st.button("Reset Experiment Settings"):
            # Drop the experiment-related keys and let initialize_state() refill
            # them. This keeps the reset values identical to its defaults instead
            # of duplicating them here.
            for key in ("occupation", "group_name", "privilege_label", "protect_label",
                        "num_run", "data_processed", "uploaded_file"):
                if key in st.session_state:
                    del st.session_state[key]
            initialize_state()
            # Rerun so that the inputs above immediately show the restored defaults
            # rather than waiting for the next interaction. Newer Streamlit releases
            # replace st.experimental_rerun with st.rerun.
            (getattr(st, "rerun", None) or st.experimental_rerun)()