Blair Yang commited on
Commit
db9289d
1 Parent(s): a53cf00
Files changed (1) hide show
  1. app.py +10 -9
app.py CHANGED
@@ -24,25 +24,26 @@ def generate_plot(meta_index, topic_index):
24
 
25
  data = pd.read_csv(f"data/{meta_topic}/response_rec.csv", sep=",")
26
 
27
- topic_data = data[data['sub_topic'] == topic]
28
 
29
  # Compute human and llm accuracy
30
  topic_data['human_acc'] = topic_data['no_correct_human'] / topic_data['no_responses_human'].replace(0, np.nan)
31
  topic_data['llm_acc'] = topic_data['no_correct_llm'] / topic_data['no_responses_llm'].replace(0, np.nan)
32
 
33
- # Calculate mean and standard deviation for the sample data
34
- mean_data = topic_data.groupby('model_name').mean().reset_index()
35
- std_deviation = topic_data.groupby('model_name').std().reset_index()
 
36
 
37
  # Prepare the plot data
38
  plot_data = []
39
 
40
- # Define a consistent color scheme
41
  colors = ['#FFA07A', '#20B2AA', '#778899'] # Light Salmon, Light Sea Green, Light Slate Gray
42
- opacities = [0.7, 0.7, 0.7] # Opacity for average bars
43
-
44
  # Add bars with error bars for the averages
45
- for acc_type, color, opacity in zip(['oracle_acc', 'human_acc', 'llm_acc'], colors, opacities):
46
  plot_data.append(go.Bar(
47
  x=mean_data['model_name'],
48
  y=mean_data[acc_type],
@@ -52,7 +53,7 @@ def generate_plot(meta_index, topic_index):
52
  visible=True
53
  ),
54
  name=acc_type.split('_')[0].capitalize(),
55
- marker=dict(color=color, opacity=opacity)
56
  ))
57
 
58
  # Layout
 
24
 
25
  data = pd.read_csv(f"data/{meta_topic}/response_rec.csv", sep=",")
26
 
27
+ topic_data = data.loc[data['sub_topic'] == topic].copy()
28
 
29
  # Compute human and llm accuracy
30
  topic_data['human_acc'] = topic_data['no_correct_human'] / topic_data['no_responses_human'].replace(0, np.nan)
31
  topic_data['llm_acc'] = topic_data['no_correct_llm'] / topic_data['no_responses_llm'].replace(0, np.nan)
32
 
33
+ # Selecting only numeric columns for aggregation
34
+ numeric_cols = ['no_responses_human', 'no_correct_human', 'no_responses_llm', 'no_correct_llm', 'oracle_acc', 'human_acc', 'llm_acc']
35
+ mean_data = topic_data.groupby('model_name')[numeric_cols].mean().reset_index()
36
+ std_deviation = topic_data.groupby('model_name')[numeric_cols].std().reset_index()
37
 
38
  # Prepare the plot data
39
  plot_data = []
40
 
41
+ # Define a consistent color scheme with different opacities
42
  colors = ['#FFA07A', '#20B2AA', '#778899'] # Light Salmon, Light Sea Green, Light Slate Gray
43
+ acc_types = ['oracle_acc', 'human_acc', 'llm_acc']
44
+
45
  # Add bars with error bars for the averages
46
+ for acc_type, color in zip(acc_types, colors):
47
  plot_data.append(go.Bar(
48
  x=mean_data['model_name'],
49
  y=mean_data[acc_type],
 
53
  visible=True
54
  ),
55
  name=acc_type.split('_')[0].capitalize(),
56
+ marker=dict(color=color)
57
  ))
58
 
59
  # Layout