phess2 committed on
Commit e40d189 · 1 Parent(s): fc0a201

Add figure 3 reproducibility data and code

figures/README.md CHANGED
@@ -1,12 +1,13 @@
 # Figure Data Extraction and Reproduction
 
-This folder contains the data used in Figure 2 and Figure 4 from the Training Lipschitz Transformer paper and scripts to reproduce the figures from the saved CSV files.
+This folder contains the data used in Figure 2, Figure 3, and Figure 4 from the Training Lipschitz Transformer paper and scripts to reproduce the figures from the saved CSV files.
 
 ## Files
 
 - `reproduce_figures.py`: Script to reproduce the figures from the saved CSV files
 - `requirements.txt`: Python dependencies required to run the scripts
 - `figure_2/`: Directory containing the CSV files for each subplot of Figure 2
+- `figure_3/`: Directory containing the CSV files and a PDF for Figure 3
 - `figure_4/`: Directory containing the CSV files for each subplot of Figure 4
 
 ## Usage
@@ -26,8 +27,9 @@ python reproduce_figures.py
 ```
 
 This will create:
-- `figure_2_reproduced.pdf`: Recreation of Figure 2 from the original notebook
-- `figure_4_reproduced.pdf`: Recreation of Figure 4 from the original notebook
+- `figure_2_reproduced.pdf`: Recreation of Figure 2
+- `figure_3_reproduced.pdf`: Recreation of the right panel of Figure 3
+- `figure_4_reproduced.pdf`: Recreation of Figure 4
 
 ## CSV File Structure
 
@@ -37,6 +39,10 @@ Each CSV file contains the processed data for its respective subplot:
 - **figure_2_subplot_1.csv**: Contains points used to plot the frontier of validation loss vs. Lipschitz constant with columns for technique, learning rate, w_max, final validation loss, Lipschitz constant, optimizer, etc.
 - **figure_2_subplot_2_3.csv**: Contains results for top validation accuracy model for each of our tested techniques with columns for technique, learning rate, w_max, final validation accuracy, Lipschitz constant, optimizer, etc.
 
+### Figure 3 Files
+- **figure_3_subplot_1.pdf**: Contains the left panel of Figure 3 with adversarial examples pre-made from the models contained in the `models/MLPs/` directory
+- **figure_3_subplot_2.csv**: Contains adversarial robustness data with columns for model_name, epsilon (adversarial perturbation budget), accuracy, avg_correct_prob (mean probability for correct class), and prob_error_bar (error bars for probability measurements)
+
 ### Figure 4 Files
 - **figure_4_subplot_1.csv**: MLP model results with frontier points for different optimizers and techniques
 - **figure_4_subplot_2.csv**: Transformer model results with frontier points for different optimizers and techniques
@@ -49,4 +55,4 @@ The reproduction script creates pixel-perfect recreations of the original figure
 - Matching legend positioning and styling
 - Equivalent subplot layouts and spacing
 
-This ensures full reproducibility of the figures from the saved CSV data.
+This ensures full reproducibility of Figure 2 (Lipschitz constraint comparison), Figure 3 (adversarial robustness analysis), and Figure 4 (MLP and Transformer optimizer comparisons) from the saved CSV data.
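Review note: the README entry for `figure_3_subplot_2.csv` lists five columns, while the CSV added in this commit also carries an `epsilon_index` column. A minimal sketch of a schema check that would surface such drift is below. `DOCUMENTED_COLUMNS` and `check_columns` are hypothetical names, not part of the repository, and the sample uses two rows copied from the CSV rather than reading the file.

```python
import csv
import io

# Columns the README documents for figure_3_subplot_2.csv.
DOCUMENTED_COLUMNS = [
    "model_name", "epsilon", "accuracy", "avg_correct_prob", "prob_error_bar",
]

def check_columns(fieldnames):
    """Hypothetical helper: return (missing, undocumented) column names."""
    missing = [c for c in DOCUMENTED_COLUMNS if c not in fieldnames]
    undocumented = [c for c in fieldnames if c not in DOCUMENTED_COLUMNS]
    return missing, undocumented

# Header and two rows copied from the CSV in this commit.
SAMPLE = io.StringIO(
    "model_name,epsilon,epsilon_index,accuracy,avg_correct_prob,prob_error_bar\n"
    "Lipschitz bound 15.2 (Muon + soft cap),0.0,0,0.470703125,0.2465411126613617,0.006872239056974649\n"
    "Lipschitz bound 7618.8 (Adam + weight decay),0.0,0,0.455078125,0.3451777398586273,0.011264396831393242\n"
)

reader = csv.DictReader(SAMPLE)
missing, undocumented = check_columns(reader.fieldnames)
print("missing:", missing)
print("undocumented:", undocumented)
```

On the sample, nothing documented is missing and `epsilon_index` is reported as undocumented, so either the README or the CSV could be adjusted to match.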
figures/figure_3/figure_3_subplot_1.pdf ADDED
Binary file (47 kB)
 
figures/figure_3/figure_3_subplot_2.csv ADDED
@@ -0,0 +1,29 @@
+model_name,epsilon,epsilon_index,accuracy,avg_correct_prob,prob_error_bar
+Lipschitz bound 15.2 (Muon + soft cap),0.0,0,0.470703125,0.2465411126613617,0.006872239056974649
+Lipschitz bound 15.2 (Muon + soft cap),0.4,1,0.46484375,0.24188876152038574,0.00677172327414155
+Lipschitz bound 15.2 (Muon + soft cap),0.8,2,0.4609375,0.23624548316001892,0.006661087274551392
+Lipschitz bound 15.2 (Muon + soft cap),1.2,3,0.416015625,0.22944872081279755,0.006539473310112953
+Lipschitz bound 15.2 (Muon + soft cap),1.6,4,0.384765625,0.22136151790618896,0.006398873869329691
+Lipschitz bound 15.2 (Muon + soft cap),2.0,5,0.34375,0.21216988563537598,0.006248640362173319
+Lipschitz bound 15.2 (Muon + soft cap),2.4,6,0.3046875,0.2017553150653839,0.006085531786084175
+Lipschitz bound 15.2 (Muon + soft cap),2.8,7,0.255859375,0.1905008852481842,0.005909574683755636
+Lipschitz bound 15.2 (Muon + soft cap),3.2,8,0.205078125,0.17851822078227997,0.005721117369830608
+Lipschitz bound 15.2 (Muon + soft cap),3.6,9,0.15625,0.166222482919693,0.00552830146625638
+Lipschitz bound 15.2 (Muon + soft cap),4.0,10,0.12109375,0.1539309024810791,0.005329667124897242
+Lipschitz bound 15.2 (Muon + soft cap),4.4,11,0.099609375,0.1419779360294342,0.005133118946105242
+Lipschitz bound 15.2 (Muon + soft cap),4.8,12,0.080078125,0.1307675540447235,0.004940425511449575
+Lipschitz bound 15.2 (Muon + soft cap),5.2,13,0.064453125,0.12042088061571121,0.004750879481434822
+Lipschitz bound 7618.8 (Adam + weight decay),0.0,0,0.455078125,0.3451777398586273,0.011264396831393242
+Lipschitz bound 7618.8 (Adam + weight decay),0.4,1,0.427734375,0.31623339653015137,0.010893518105149269
+Lipschitz bound 7618.8 (Adam + weight decay),0.8,2,0.326171875,0.26872894167900085,0.01072117779403925
+Lipschitz bound 7618.8 (Adam + weight decay),1.2,3,0.234375,0.2093334048986435,0.01011421624571085
+Lipschitz bound 7618.8 (Adam + weight decay),1.6,4,0.158203125,0.1528899371623993,0.008851265534758568
+Lipschitz bound 7618.8 (Adam + weight decay),2.0,5,0.095703125,0.10684726387262344,0.007450764533132315
+Lipschitz bound 7618.8 (Adam + weight decay),2.4,6,0.064453125,0.07626770436763763,0.006211025640368462
+Lipschitz bound 7618.8 (Adam + weight decay),2.8,7,0.03125,0.05315832793712616,0.004932188894599676
+Lipschitz bound 7618.8 (Adam + weight decay),3.2,8,0.015625,0.03924941271543503,0.004173320718109608
+Lipschitz bound 7618.8 (Adam + weight decay),3.6,9,0.013671875,0.03031872771680355,0.0036359354853630066
+Lipschitz bound 7618.8 (Adam + weight decay),4.0,10,0.009765625,0.02403259463608265,0.003225660650059581
+Lipschitz bound 7618.8 (Adam + weight decay),4.4,11,0.009765625,0.02026309445500374,0.0029816720634698868
+Lipschitz bound 7618.8 (Adam + weight decay),4.8,12,0.0078125,0.01729888655245304,0.0027463138103485107
+Lipschitz bound 7618.8 (Adam + weight decay),5.2,13,0.005859375,0.015457117930054665,0.002608406590297818
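Review note: one way to read this table is to ask what fraction of clean (ε = 0) accuracy each model keeps at a given perturbation budget. The sketch below does that for ε = 2.0 using four rows copied from the CSV above. `retained_accuracy` is a hypothetical helper, not something the repository provides, and it uses the stdlib `csv` module instead of pandas to stay self-contained.

```python
import csv
import io

# Four rows (epsilon 0.0 and 2.0 for each model) copied from the CSV above.
ROWS = io.StringIO(
    "model_name,epsilon,epsilon_index,accuracy,avg_correct_prob,prob_error_bar\n"
    "Lipschitz bound 15.2 (Muon + soft cap),0.0,0,0.470703125,0.2465411126613617,0.006872239056974649\n"
    "Lipschitz bound 15.2 (Muon + soft cap),2.0,5,0.34375,0.21216988563537598,0.006248640362173319\n"
    "Lipschitz bound 7618.8 (Adam + weight decay),0.0,0,0.455078125,0.3451777398586273,0.011264396831393242\n"
    "Lipschitz bound 7618.8 (Adam + weight decay),2.0,5,0.095703125,0.10684726387262344,0.007450764533132315\n"
)

def retained_accuracy(rows, epsilon):
    """Map model_name -> accuracy at `epsilon` divided by accuracy at epsilon 0."""
    clean, attacked = {}, {}
    for row in rows:
        eps = float(row["epsilon"])
        if eps == 0.0:
            clean[row["model_name"]] = float(row["accuracy"])
        if eps == epsilon:
            attacked[row["model_name"]] = float(row["accuracy"])
    # Only report models that have both a clean and an attacked measurement.
    return {name: attacked[name] / clean[name] for name in clean if name in attacked}

retained = retained_accuracy(csv.DictReader(ROWS), epsilon=2.0)
for name, fraction in retained.items():
    print(f"{name}: {fraction:.2f}")
```

On these rows the Muon + soft cap model keeps roughly 0.73 of its clean accuracy at ε = 2.0, against roughly 0.21 for the Adam + weight decay model.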
figures/reproduce_figures.py CHANGED
@@ -20,17 +20,21 @@ def load_csv_data():
     """Load all CSV files containing the figure data"""
 
     fig_2_data_dir = Path("figure_2")
+    fig_3_data_dir = Path("figure_3")
     fig_4_data_dir = Path("figure_4")
 
     # Load Figure 2 data
     fig2_subplot1 = pd.read_csv(fig_2_data_dir / "figure_2_subplot_1.csv")
     fig2_subplot2_3 = pd.read_csv(fig_2_data_dir / "figure_2_subplot_2_3.csv") # Used for both subplot 2 and 3
 
+    # Load Figure 3 data
+    fig3_data = pd.read_csv(fig_3_data_dir / "figure_3_subplot_2.csv")
+
     # Load Figure 4 data
     fig4_subplot1 = pd.read_csv(fig_4_data_dir / "figure_4_subplot_1.csv")
     fig4_subplot2 = pd.read_csv(fig_4_data_dir / "figure_4_subplot_2.csv")
 
-    return fig2_subplot1, fig2_subplot2_3, fig4_subplot1, fig4_subplot2
+    return fig2_subplot1, fig2_subplot2_3, fig3_data, fig4_subplot1, fig4_subplot2
 
 def safe_eval_list(list_str):
     """Safely evaluate string representation of list"""
@@ -236,6 +240,70 @@ def create_figure_2(highlight_points, results_df):
     plt.savefig("figure_2_reproduced.pdf", format='pdf', bbox_inches='tight')
     plt.show()
 
+
+def create_figure_3(df):
+    """Create Figure 3: Adversarial robustness comparison"""
+
+    # Extract unique epsilon values and find epsilon range
+    epsilons = sorted(df['epsilon'].unique())
+    epsilons_upto = len(epsilons) # Use all available epsilon values
+
+    # Create the model info for plotting (extract from CSV)
+    models = []
+    for model_name in df['model_name'].unique():
+        model_data = df[df['model_name'] == model_name].copy().sort_values(by='epsilon')
+
+        # Determine color based on model name
+        if "Muon" in model_name or "soft cap" in model_name:
+            color = "royalblue"
+        else:
+            color = "#7F7F7F"
+
+        models.append({
+            "name": model_name,
+            "color": color,
+            "accuracies": model_data['accuracy'].tolist(),
+            "avg_correct_probs": model_data['avg_correct_prob'].tolist(),
+            "error_bars": model_data['prob_error_bar'].tolist()
+        })
+
+    # Create a figure with two subplots stacked vertically
+    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 5), sharex=True)
+
+    # Plot accuracy for each model (top subplot)
+    for model in models:
+        ax1.plot(epsilons[:epsilons_upto], model["accuracies"][:epsilons_upto], 'o-',
+                 linewidth=3, markersize=5,
+                 label=model["name"], color=model["color"])
+    ax1.set_xticks(epsilons[::2])
+
+    # Plot probability with error bars for each model (bottom subplot)
+    for model in models:
+        ax2.errorbar(epsilons[:epsilons_upto], model["avg_correct_probs"][:epsilons_upto],
+                     yerr=model["error_bars"][:epsilons_upto], fmt='o-',
+                     linewidth=3, markersize=5, capsize=5, elinewidth=1.5,
+                     label=model["name"], color=model["color"])
+    ax2.set_xticks(epsilons[::2])
+
+    # Configure top subplot (accuracy)
+    ax1.set_ylabel('Accuracy (top 1)', fontsize=12)
+    ax1.set_ylim(0, 0.5)
+    ax1.tick_params(axis='y', labelsize=12)
+    ax1.legend(fontsize=12, frameon=False, borderpad=0.2, handletextpad=0.5, labelspacing=0.2, loc='upper center', bbox_to_anchor=(0.5, 1.38))
+
+    # Configure bottom subplot (probability)
+    ax2.set_xlabel('Budget of adversarial perturbation (ε)', fontsize=12)
+    ax2.set_ylabel('Mean p(correct class)', fontsize=12)
+    ax2.tick_params(axis='both', labelsize=12)
+
+    # Set x-ticks for both subplots
+    plt.xticks(epsilons[::2])
+
+    plt.tight_layout()
+    plt.savefig("figure_3_reproduced.pdf", format='pdf', bbox_inches='tight')
+    plt.show()
+
+
 def create_figure_4(MLP_highlight_points, transformer_highlight_points):
     """Create Figure 4: MLP vs Transformer comparison"""
 
@@ -383,19 +451,23 @@ def create_figure_4(MLP_highlight_points, transformer_highlight_points):
     plt.savefig("figure_4_reproduced.pdf", dpi=600, bbox_inches='tight')
     plt.show()
 
+
 def main():
-    """Main function to load data and create both figures"""
+    """Main function to load data and create all figures"""
 
     print("Loading CSV data...")
-    fig2_subplot1, fig2_subplot2_3, fig4_subplot1, fig4_subplot2 = load_csv_data()
+    fig2_subplot1, fig2_subplot2_3, fig3_data, fig4_subplot1, fig4_subplot2 = load_csv_data()
 
     print("Creating Figure 2...")
     create_figure_2(fig2_subplot1, fig2_subplot2_3)
 
+    print("Creating Figure 3...")
+    create_figure_3(fig3_data)
+
     print("Creating Figure 4...")
     create_figure_4(fig4_subplot1, fig4_subplot2)
 
-    print("Figures saved as 'figure_2_reproduced.pdf' and 'figure_4_reproduced.pdf'")
+    print("Figures saved as 'figure_2_reproduced.pdf', 'figure_3_reproduced.pdf' and 'figure_4_reproduced.pdf'")
 
 if __name__ == "__main__":
     main()
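
Review note: `load_csv_data()` builds its paths from `Path("figure_2")`, `Path("figure_3")`, and `Path("figure_4")`, which resolve against the current working directory, so the script appears to assume it is run from inside `figures/`. A minimal sketch of a pre-flight check that reports missing inputs before any plotting starts is below. `CSV_FILES` and `missing_csv_files` are hypothetical names, not part of the script; the relative paths are the ones that appear in the diff.

```python
import tempfile
from pathlib import Path

# Relative CSV locations, as used by load_csv_data() in this diff.
CSV_FILES = {
    "fig2_subplot1": "figure_2/figure_2_subplot_1.csv",
    "fig2_subplot2_3": "figure_2/figure_2_subplot_2_3.csv",
    "fig3_data": "figure_3/figure_3_subplot_2.csv",
    "fig4_subplot1": "figure_4/figure_4_subplot_1.csv",
    "fig4_subplot2": "figure_4/figure_4_subplot_2.csv",
}

def missing_csv_files(base_dir):
    """Hypothetical helper: sorted names of expected CSVs not found under base_dir."""
    base = Path(base_dir)
    return sorted(name for name, rel in CSV_FILES.items() if not (base / rel).is_file())

# Demo in a scratch directory that contains only the Figure 3 CSV.
with tempfile.TemporaryDirectory() as tmp:
    fig3_path = Path(tmp) / CSV_FILES["fig3_data"]
    fig3_path.parent.mkdir(parents=True)
    fig3_path.write_text("model_name,epsilon\n")
    demo_missing = missing_csv_files(tmp)

print(demo_missing)
```

In the script itself, such a check could be called with `Path(__file__).resolve().parent` as `base_dir`, which would also make the run independent of the working directory if the loaders used the same base.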