In [None]:
import os

import pandas as pd

from util.injection import process_scores_multiple
from util.model import AzureAgent, GPTAgent, Claude3Agent
from util.prompt import PROMPT_TEMPLATE

def run_experiment(api_key, model_type, deployment_name, temperature, max_tokens, occupation,
                   sample_size, group_name, privilege_label, protect_label, num_run, prompt_template, endpoint_url=None,
                   file_path="resume_subsampled.csv", random_state=42, api_version="2024-02-15-preview"):
    """Run one experiment: sample rows for an occupation and process them with an LLM agent.

    Loads the CSV at *file_path*, keeps the rows whose ``Occupation`` column
    equals *occupation*, draws a reproducible sample of *sample_size* rows,
    builds the agent selected by *model_type*, and hands everything to
    ``process_scores_multiple``.

    Args:
        api_key: Credential passed to the agent constructor.
        model_type: One of ``'AzureAgent'``, ``'GPTAgent'`` or ``'Claude3Agent'``.
        deployment_name: Deployment/model name passed to the agent.
        temperature: Forwarded to the agent via the ``parameters`` dict.
        max_tokens: Forwarded to the agent via the ``parameters`` dict.
        occupation: Value of the ``Occupation`` column to keep.
        sample_size: Number of rows to sample (without replacement).
        group_name, privilege_label, protect_label, num_run, prompt_template:
            Forwarded unchanged to ``process_scores_multiple``.
        endpoint_url: Endpoint for ``AzureAgent``/``GPTAgent``; not used by
            ``Claude3Agent``.
        file_path: CSV to load (defaults to the previously hard-coded file).
        random_state: Seed for ``DataFrame.sample`` so runs are reproducible.
        api_version: API version passed to ``GPTAgent``.

    Returns:
        Whatever ``process_scores_multiple`` returns (presumably a DataFrame;
        the notebook calls ``.head()``/``.to_csv`` on it).

    Raises:
        ValueError: If *model_type* is not supported, or fewer than
            *sample_size* rows match *occupation*.
    """
    supported_model_types = ("AzureAgent", "GPTAgent", "Claude3Agent")
    # Validate before the (slow) CSV load. Previously any unknown value,
    # e.g. a typo, silently fell through to Claude3Agent.
    if model_type not in supported_model_types:
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected one of {supported_model_types}."
        )

    # Load data
    df = pd.read_csv(file_path)

    # Filter data by occupation
    df = df[df["Occupation"] == occupation]
    available = len(df)
    # Sampling without replacement cannot exceed the population; fail with a
    # message that says which occupation/file was short.
    if available < sample_size:
        raise ValueError(
            f"Requested sample_size={sample_size} but only {available} rows "
            f"have Occupation == {occupation!r} in {file_path!r}."
        )
    df = df.sample(n=sample_size, random_state=random_state)

    # Initialize the agent
    if model_type == "AzureAgent":
        agent = AzureAgent(api_key, endpoint_url, deployment_name)
    elif model_type == "GPTAgent":
        agent = GPTAgent(api_key, endpoint_url, deployment_name, api_version)
    else:  # "Claude3Agent" (guaranteed by the validation above)
        agent = Claude3Agent(api_key, deployment_name)

    # Process data
    parameters = {"temperature": temperature, "max_tokens": max_tokens}
    return process_scores_multiple(df, num_run, parameters, privilege_label, protect_label, agent, group_name, occupation, prompt_template)

# Set experiment parameters
# The API key is read from the environment instead of being hard-coded:
# a key committed to source or a shared notebook is effectively leaked.
# Any key previously written here should be revoked and rotated.
api_key = os.environ.get("AZURE_OPENAI_API_KEY")
if not api_key:
    raise RuntimeError(
        "Set the AZURE_OPENAI_API_KEY environment variable before running this cell."
    )
model_type = "GPTAgent"  # or "AzureAgent" or "Claude3Agent"
deployment_name = "gpt4-1106"
temperature = 0.0
max_tokens = 300
# NOTE(review): not passed to run_experiment below; keep it in sync with the CSV that call loads.
file_path = "resume_subsampled.csv"  # or path to your file
occupation = "FINANCE"
sample_size = 100
group_name = "Gender"
privilege_label = "Male"
protect_label = "Female"
num_run = 1
prompt_template = PROMPT_TEMPLATE
endpoint_url = "https://safeguard-monitor.openai.azure.com/"

# Run experiment
results = run_experiment(api_key, model_type, deployment_name, temperature, max_tokens, occupation,
                         sample_size, group_name, privilege_label, protect_label, num_run, prompt_template, endpoint_url)

# Save results to a CSV file. to_csv does not create missing parent
# directories, so ensure result/ exists first.
os.makedirs("result", exist_ok=True)
results.to_csv(f'result/{occupation}_results.csv', index=False)

# Display results. A notebook renders only the cell's last expression, so this
# must come last. Before to_csv (which returns None) it was never shown.
results.head()


Processing 100 entries with 1 runs each.


Processing runs:   0%|          | 0/1 [00:00<?, ?run/s]
Processing entries:   0%|          | 0/100 [00:00<?, ?entry/s][A
Processing entries:   1%|          | 1/100 [00:47<1:17:58, 47.26s/entry][A
Processing entries:   2%|▏         | 2/100 [01:15<58:51, 36.04s/entry]  [A
Processing entries:   3%|▎         | 3/100 [01:49<56:30, 34.95s/entry][A
Processing entries:   4%|▍         | 4/100 [02:21<54:34, 34.11s/entry][A
Processing entries:   5%|▌         | 5/100 [02:59<56:11, 35.49s/entry][A
Processing entries:   6%|▌         | 6/100 [03:35<55:33, 35.46s/entry][A
Processing entries:   7%|▋         | 7/100 [04:12<55:48, 36.00s/entry][A
Processing entries:   8%|▊         | 8/100 [04:52<57:20, 37.40s/entry][A
Processing entries:   9%|▉         | 9/100 [05:19<51:31, 33.97s/entry][A
Processing entries:  10%|█         | 10/100 [15:46<5:25:34, 217.06s/entry][A
Processing entries:  11%|█         | 11/100 [16:11<3:55:07, 158.51s/entry][A
Processing entries:  12%|█▏        | 12/100 [17:15<3