Zekun Wu committed on
Commit
39654c5
1 Parent(s): 245d4fa
Files changed (1) hide show
  1. app.py +31 -56
app.py CHANGED
@@ -4,92 +4,67 @@ from io import StringIO
4
  from generation import process_scores
5
  from model import AzureAgent, GPTAgent
6
 
7
- # Initialize session state variables if they don't already exist
8
- def initialize_state():
9
- if 'data_processed' not in st.session_state:
10
- st.session_state.data_processed = False
11
- if 'api_key' not in st.session_state:
12
- st.session_state.api_key = ""
13
- if 'endpoint_url' not in st.session_state:
14
- st.session_state.endpoint_url = ""
15
- if 'deployment_name' not in st.session_state:
16
- st.session_state.deployment_name = ""
17
- if 'temperature' not in st.session_state:
18
- st.session_state.temperature = 0.5
19
- if 'max_tokens' not in st.session_state:
20
- st.session_state.max_tokens = 150
21
- if 'group_name' not in st.session_state:
22
- st.session_state.group_name = ""
23
- if 'privilege_label' not in st.session_state:
24
- st.session_state.privilege_label = ""
25
- if 'protect_label' not in st.session_state:
26
- st.session_state.protect_label = ""
27
- if 'num_run' not in st.session_state:
28
- st.session_state.num_run = 1
29
-
30
- initialize_state()
31
-
32
  # Set up the Streamlit interface
33
  st.title('JobFair: A Benchmark for Fairness in LLM Employment Decision')
34
  st.sidebar.title('Model Settings')
35
 
 
 
 
 
 
 
 
 
 
 
 
36
  # Model selection and configuration
37
  model_type = st.sidebar.radio("Select the type of agent", ('GPTAgent', 'AzureAgent'))
38
- api_key = st.sidebar.text_input("API Key", type="password", value=st.session_state.api_key)
39
- endpoint_url = st.sidebar.text_input("Endpoint URL", value=st.session_state.endpoint_url)
40
- deployment_name = st.sidebar.text_input("Model Name", value=st.session_state.deployment_name)
41
  api_version = '2024-02-15-preview' if model_type == 'GPTAgent' else ''
42
- temperature = st.sidebar.slider("Temperature", 0.0, 1.0, st.session_state.temperature, 0.01)
43
- max_tokens = st.sidebar.number_input("Max Tokens", 1, 1000, st.session_state.max_tokens)
44
 
45
- # Reset buttons for model information
46
  if st.sidebar.button("Reset Model Info"):
47
- st.session_state.api_key = ""
48
- st.session_state.endpoint_url = ""
49
- st.session_state.deployment_name = ""
50
- st.session_state.temperature = 0.5
51
- st.session_state.max_tokens = 150
52
  st.experimental_rerun()
53
 
54
- submit_model_info = st.sidebar.button("Submit Model Info")
 
55
 
56
- # Data upload and processing with reset option
57
- if submit_model_info:
58
- parameters = {"temperature": temperature, "max_tokens": max_tokens}
59
 
60
- group_name = st.text_input("Group Name", value=st.session_state.group_name)
61
- privilege_label = st.text_input("Privilege Name", value=st.session_state.privilege_label)
62
- protect_label = st.text_input("Protect Name", value=st.session_state.protect_label)
63
- num_run = st.number_input("Number of runs", min_value=1, value=st.session_state.num_run)
64
  uploaded_file = st.file_uploader("Choose a file")
65
 
66
- # Reset button for experiment settings
67
  if st.button("Reset Experiment Settings"):
68
  st.session_state.group_name = ""
69
  st.session_state.privilege_label = ""
70
  st.session_state.protect_label = ""
71
  st.session_state.num_run = 1
72
  st.session_state.data_processed = False
73
- st.experimental_rerun()
74
 
75
- if uploaded_file is not None:
76
  data = StringIO(uploaded_file.getvalue().decode("utf-8"))
77
  df = pd.read_csv(data)
78
 
79
- process_button = st.button('Process Data')
80
-
81
- if process_button and not st.session_state.data_processed:
82
  # Initialize the correct agent based on model type
83
  if model_type == 'AzureAgent':
84
- agent = AzureAgent(api_key, endpoint_url, deployment_name)
85
  else:
86
- agent = GPTAgent(api_key, endpoint_url, deployment_name, api_version)
87
 
88
  # Process data and display results
89
  with st.spinner('Processing data...'):
90
- df = process_scores(df, num_run, parameters, privilege_label, protect_label, agent, group_name)
91
  st.session_state.data_processed = True # Mark as processed
92
-
93
  st.write('Processed Data:', df)
94
- elif process_button and st.session_state.data_processed:
95
- st.warning("Data already processed for this session. Reset or re-upload to process new data.")
 
4
  from generation import process_scores
5
  from model import AzureAgent, GPTAgent
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  # Set up the Streamlit interface
8
  st.title('JobFair: A Benchmark for Fairness in LLM Employment Decision')
9
  st.sidebar.title('Model Settings')
10
 
11
+ # Define a function to manage state initialization
12
+ def initialize_state():
13
+ keys = ["model_submitted", "api_key", "endpoint_url", "deployment_name", "temperature", "max_tokens",
14
+ "data_processed", "group_name", "privilege_label", "protect_label", "num_run"]
15
+ defaults = [False, "", "", "", 0.5, 150, False, "", "", "", 1]
16
+ for key, default in zip(keys, defaults):
17
+ if key not in st.session_state:
18
+ st.session_state[key] = default
19
+
20
+ initialize_state()
21
+
22
  # Model selection and configuration
23
  model_type = st.sidebar.radio("Select the type of agent", ('GPTAgent', 'AzureAgent'))
24
+ st.session_state.api_key = st.sidebar.text_input("API Key", type="password", value=st.session_state.api_key)
25
+ st.session_state.endpoint_url = st.sidebar.text_input("Endpoint URL", value=st.session_state.endpoint_url)
26
+ st.session_state.deployment_name = st.sidebar.text_input("Model Name", value=st.session_state.deployment_name)
27
  api_version = '2024-02-15-preview' if model_type == 'GPTAgent' else ''
28
+ st.session_state.temperature = st.sidebar.slider("Temperature", 0.0, 1.0, st.session_state.temperature, 0.01)
29
+ st.session_state.max_tokens = st.sidebar.number_input("Max Tokens", 1, 1000, st.session_state.max_tokens)
30
 
 
31
  if st.sidebar.button("Reset Model Info"):
32
+ initialize_state() # Reset all state to defaults
 
 
 
 
33
  st.experimental_rerun()
34
 
35
+ if st.sidebar.button("Submit Model Info"):
36
+ st.session_state.model_submitted = True
37
 
38
+ # Ensure experiment settings are only shown if model info is submitted
39
+ if st.session_state.model_submitted:
40
+ parameters = {"temperature": st.session_state.temperature, "max_tokens": st.session_state.max_tokens}
41
 
42
+ st.session_state.group_name = st.text_input("Group Name", value=st.session_state.group_name)
43
+ st.session_state.privilege_label = st.text_input("Privilege Name", value=st.session_state.privilege_label)
44
+ st.session_state.protect_label = st.text_input("Protect Name", value=st.session_state.protect_label)
45
+ st.session_state.num_run = st.number_input("Number of runs", min_value=1, value=st.session_state.num_run)
46
  uploaded_file = st.file_uploader("Choose a file")
47
 
 
48
  if st.button("Reset Experiment Settings"):
49
  st.session_state.group_name = ""
50
  st.session_state.privilege_label = ""
51
  st.session_state.protect_label = ""
52
  st.session_state.num_run = 1
53
  st.session_state.data_processed = False
 
54
 
55
+ if uploaded_file is not None and not st.session_state.data_processed:
56
  data = StringIO(uploaded_file.getvalue().decode("utf-8"))
57
  df = pd.read_csv(data)
58
 
59
+ if st.button('Process Data'):
 
 
60
  # Initialize the correct agent based on model type
61
  if model_type == 'AzureAgent':
62
+ agent = AzureAgent(st.session_state.api_key, st.session_state.endpoint_url, st.session_state.deployment_name)
63
  else:
64
+ agent = GPTAgent(st.session_state.api_key, st.session_state.endpoint_url, st.session_state.deployment_name, api_version)
65
 
66
  # Process data and display results
67
  with st.spinner('Processing data...'):
68
+ df = process_scores(df, st.session_state.num_run, parameters, st.session_state.privilege_label, st.session_state.protect_label, agent, st.session_state.group_name)
69
  st.session_state.data_processed = True # Mark as processed
 
70
  st.write('Processed Data:', df)