File size: 3,479 Bytes
0b5c5aa
 
 
839ca71
 
117a821
245d4fa
 
35b059b
117a821
39654c5
 
 
d7128fd
 
39654c5
 
 
 
 
 
245d4fa
 
39654c5
 
 
245d4fa
39654c5
 
245d4fa
 
39654c5
245d4fa
 
39654c5
 
245d4fa
cbbb1a3
 
 
297ee52
cbbb1a3
d7128fd
cbbb1a3
d7128fd
245d4fa
 
cbbb1a3
 
 
 
 
 
 
 
245d4fa
cbbb1a3
 
 
 
 
 
d7128fd
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import streamlit as st
import pandas as pd
from io import StringIO
from generation import process_scores
from model import AzureAgent, GPTAgent

# Page header: the benchmark title in the main area, plus a heading for the
# sidebar that holds the model configuration widgets.
st.title('JobFair: A Benchmark for Fairness in LLM Employment Decision')
with st.sidebar:
    st.title('Model Settings')

# Define a function to manage state initialization
def initialize_state():
    """Seed ``st.session_state`` with a default for every key that is missing.

    Keys that already exist are left untouched. Calling this on every
    Streamlit rerun therefore preserves values the user has already entered.
    """
    # One mapping keeps each key next to its default, so the two can never
    # drift out of alignment.
    session_defaults = {
        "model_submitted": False,
        "api_key": "",
        "endpoint_url": "",
        "deployment_name": "",
        "temperature": 0.5,
        "max_tokens": 150,
        "data_processed": False,
        "group_name": "",
        "privilege_label": "",
        "protect_label": "",
        "num_run": 1,
        "uploaded_file": None,
    }
    absent = [name for name in session_defaults if name not in st.session_state]
    for name in absent:
        st.session_state[name] = session_defaults[name]

# Seed any missing session keys with their defaults on every rerun.
initialize_state()

# Model selection and configuration
model_type = st.sidebar.radio("Select the type of agent", ('GPTAgent', 'AzureAgent'))
st.session_state.api_key = st.sidebar.text_input("API Key", type="password", value=st.session_state.api_key)
st.session_state.endpoint_url = st.sidebar.text_input("Endpoint URL", value=st.session_state.endpoint_url)
st.session_state.deployment_name = st.sidebar.text_input("Model Name", value=st.session_state.deployment_name)
# Only the GPTAgent constructor is given an API version; AzureAgent is built without one.
api_version = '2024-02-15-preview' if model_type == 'GPTAgent' else ''
st.session_state.temperature = st.sidebar.slider("Temperature", 0.0, 1.0, st.session_state.temperature, 0.01)
st.session_state.max_tokens = st.sidebar.number_input("Max Tokens", 1, 1000, st.session_state.max_tokens)

if st.sidebar.button("Reset Model Info"):
    # Reset all state to defaults. initialize_state() only fills in keys that
    # are *missing*, so calling it here used to leave every value unchanged.
    # Instead, delete every session key. The rerun below re-seeds them through
    # the initialize_state() call at the top of this section.
    for key in list(st.session_state.keys()):
        del st.session_state[key]
    # st.rerun superseded st.experimental_rerun in newer Streamlit releases.
    # Prefer it when present, and fall back for older installs.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()

if st.sidebar.button("Submit Model Info"):
    st.session_state.model_submitted = True

# File selection
# Start with no data. When "Upload" is chosen but nothing has been uploaded
# yet, ``df`` previously stayed unbound, and the later ``df is not None``
# check raised NameError.
df = None
file_options = st.radio("Choose file source:", ["Upload", "Example"])
if file_options == "Example":
    # Example prompts are read from prompt_test.csv, resolved relative to the
    # working directory.
    df = pd.read_csv("prompt_test.csv")
else:
    st.session_state.uploaded_file = st.file_uploader("Choose a file")
    if st.session_state.uploaded_file is not None:
        # "utf-8-sig" decodes plain UTF-8 identically and also strips a leading
        # BOM (common in spreadsheet exports). Otherwise the BOM would end up
        # glued to the first column name.
        data = StringIO(st.session_state.uploaded_file.getvalue().decode("utf-8-sig"))
        df = pd.read_csv(data)

# Experiment actions are offered only after the model info has been submitted
# and a dataframe is available.
if st.session_state.model_submitted and df is not None:
    # NOTE(review): no widget in this file sets group_name, privilege_label,
    # protect_label or num_run. Unless something else does, process_scores
    # receives the initialize_state() defaults ("" labels, num_run=1).
    # Confirm whether the experiment-settings inputs are missing.
    if st.button('Process Data'):
        state = st.session_state
        # Build the agent matching the sidebar choice; only GPTAgent takes the
        # API version.
        if model_type == 'AzureAgent':
            agent = AzureAgent(state.api_key, state.endpoint_url, state.deployment_name)
        else:
            agent = GPTAgent(state.api_key, state.endpoint_url, state.deployment_name, api_version)

        # Run the scoring while the spinner is visible, then record completion.
        with st.spinner('Processing data...'):
            generation_params = dict(temperature=state.temperature, max_tokens=state.max_tokens)
            df = process_scores(
                df,
                state.num_run,
                generation_params,
                state.privilege_label,
                state.protect_label,
                agent,
                state.group_name,
            )
            state.data_processed = True
        st.write('Processed Data:', df)

    if st.button("Reset Experiment Settings"):
        # Restore the experiment-related session keys to their initial values.
        # NOTE(review): while "Upload" is selected, uploaded_file is reassigned
        # from the file uploader on the next run. Clearing it here therefore
        # does not clear the uploader widget itself.
        experiment_defaults = (
            ("group_name", ""),
            ("privilege_label", ""),
            ("protect_label", ""),
            ("num_run", 1),
            ("data_processed", False),
            ("uploaded_file", None),
        )
        for setting, initial in experiment_defaults:
            st.session_state[setting] = initial