Add figure 3 reproducability data and code
Browse files
figures/README.md
CHANGED
|
@@ -1,12 +1,13 @@
|
|
| 1 |
# Figure Data Extraction and Reproduction
|
| 2 |
|
| 3 |
-
This folder contains the data used in Figure 2 and Figure 4 from the Training Lipschitz Transformer paper and scripts to reproduce the figures from the saved CSV files.
|
| 4 |
|
| 5 |
## Files
|
| 6 |
|
| 7 |
- `reproduce_figures.py`: Script to reproduce the figures from the saved CSV files
|
| 8 |
- `requirements.txt`: Python dependencies required to run the scripts
|
| 9 |
- `figure_2/`: Directory containing the CSV files for each subplot of Figure 2
|
|
|
|
| 10 |
- `figure_4/`: Directory containing the CSV files for each subplot of Figure 4
|
| 11 |
|
| 12 |
## Usage
|
|
@@ -26,8 +27,9 @@ python reproduce_figures.py
|
|
| 26 |
```
|
| 27 |
|
| 28 |
This will create:
|
| 29 |
-
- `figure_2_reproduced.pdf`: Recreation of Figure 2
|
| 30 |
-
- `
|
|
|
|
| 31 |
|
| 32 |
## CSV File Structure
|
| 33 |
|
|
@@ -37,6 +39,10 @@ Each CSV file contains the processed data for its respective subplot:
|
|
| 37 |
- **figure_2_subplot_1.csv**: Contains points used to plot the frontier of validation loss vs. Lipschitz constant with columns for technique, learning rate, w_max, final validation loss, Lipschitz constant, optimizer, etc.
|
| 38 |
- **figure_2_subplot_2_3.csv**: Contains results for top validation accuracy model for each of our tested techniques with columns for technique, learning rate, w_max, final validation accuracy, Lipschitz constant, optimizer, etc.
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
### Figure 4 Files
|
| 41 |
- **figure_4_subplot_1.csv**: MLP model results with frontier points for different optimizers and techniques
|
| 42 |
- **figure_4_subplot_2.csv**: Transformer model results with frontier points for different optimizers and techniques
|
|
@@ -49,4 +55,4 @@ The reproduction script creates pixel-perfect recreations of the original figure
|
|
| 49 |
- Matching legend positioning and styling
|
| 50 |
- Equivalent subplot layouts and spacing
|
| 51 |
|
| 52 |
-
This ensures full reproducibility of
|
|
|
|
| 1 |
# Figure Data Extraction and Reproduction
|
| 2 |
|
| 3 |
+
This folder contains the data used in Figure 2, Figure 3, and Figure 4 from the Training Lipschitz Transformer paper and scripts to reproduce the figures from the saved CSV files.
|
| 4 |
|
| 5 |
## Files
|
| 6 |
|
| 7 |
- `reproduce_figures.py`: Script to reproduce the figures from the saved CSV files
|
| 8 |
- `requirements.txt`: Python dependencies required to run the scripts
|
| 9 |
- `figure_2/`: Directory containing the CSV files for each subplot of Figure 2
|
| 10 |
+
- `figure_3/`: Directory containing the CSV files and a PDF for Figure 3
|
| 11 |
- `figure_4/`: Directory containing the CSV files for each subplot of Figure 4
|
| 12 |
|
| 13 |
## Usage
|
|
|
|
| 27 |
```
|
| 28 |
|
| 29 |
This will create:
|
| 30 |
+
- `figure_2_reproduced.pdf`: Recreation of Figure 2
|
| 31 |
+
- `figure_3_reproduced.pdf`: Recreation of the right panel of Figure 3
|
| 32 |
+
- `figure_4_reproduced.pdf`: Recreation of Figure 4
|
| 33 |
|
| 34 |
## CSV File Structure
|
| 35 |
|
|
|
|
| 39 |
- **figure_2_subplot_1.csv**: Contains points used to plot the frontier of validation loss vs. Lipschitz constant with columns for technique, learning rate, w_max, final validation loss, Lipschitz constant, optimizer, etc.
|
| 40 |
- **figure_2_subplot_2_3.csv**: Contains results for top validation accuracy model for each of our tested techniques with columns for technique, learning rate, w_max, final validation accuracy, Lipschitz constant, optimizer, etc.
|
| 41 |
|
| 42 |
+
### Figure 3 Files
|
| 43 |
+
- **figure_3_subplot_1.pdf**: Contains the left panel of Figure 3 with adversarial examples pre-made from the models contained in the `models/MLPs/` directory
|
| 44 |
+
- **figure_3_subplot_2.csv**: Contains adversarial robustness data with columns for model_name, epsilon (adversarial perturbation budget), accuracy, avg_correct_prob (mean probability for correct class), and prob_error_bar (error bars for probability measurements)
|
| 45 |
+
|
| 46 |
### Figure 4 Files
|
| 47 |
- **figure_4_subplot_1.csv**: MLP model results with frontier points for different optimizers and techniques
|
| 48 |
- **figure_4_subplot_2.csv**: Transformer model results with frontier points for different optimizers and techniques
|
|
|
|
| 55 |
- Matching legend positioning and styling
|
| 56 |
- Equivalent subplot layouts and spacing
|
| 57 |
|
| 58 |
+
This ensures full reproducibility of Figure 2 (Lipschitz constraint comparison), Figure 3 (adversarial robustness analysis), and Figure 4 (MLP and Transformer optimizer comparisons) from the saved CSV data.
|
figures/figure_3/figure_3_subplot_1.pdf
ADDED
|
Binary file (47 kB). View file
|
|
|
figures/figure_3/figure_3_subplot_2.csv
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model_name,epsilon,epsilon_index,accuracy,avg_correct_prob,prob_error_bar
|
| 2 |
+
Lipschitz bound 15.2 (Muon + soft cap),0.0,0,0.470703125,0.2465411126613617,0.006872239056974649
|
| 3 |
+
Lipschitz bound 15.2 (Muon + soft cap),0.4,1,0.46484375,0.24188876152038574,0.00677172327414155
|
| 4 |
+
Lipschitz bound 15.2 (Muon + soft cap),0.8,2,0.4609375,0.23624548316001892,0.006661087274551392
|
| 5 |
+
Lipschitz bound 15.2 (Muon + soft cap),1.2,3,0.416015625,0.22944872081279755,0.006539473310112953
|
| 6 |
+
Lipschitz bound 15.2 (Muon + soft cap),1.6,4,0.384765625,0.22136151790618896,0.006398873869329691
|
| 7 |
+
Lipschitz bound 15.2 (Muon + soft cap),2.0,5,0.34375,0.21216988563537598,0.006248640362173319
|
| 8 |
+
Lipschitz bound 15.2 (Muon + soft cap),2.4,6,0.3046875,0.2017553150653839,0.006085531786084175
|
| 9 |
+
Lipschitz bound 15.2 (Muon + soft cap),2.8,7,0.255859375,0.1905008852481842,0.005909574683755636
|
| 10 |
+
Lipschitz bound 15.2 (Muon + soft cap),3.2,8,0.205078125,0.17851822078227997,0.005721117369830608
|
| 11 |
+
Lipschitz bound 15.2 (Muon + soft cap),3.6,9,0.15625,0.166222482919693,0.00552830146625638
|
| 12 |
+
Lipschitz bound 15.2 (Muon + soft cap),4.0,10,0.12109375,0.1539309024810791,0.005329667124897242
|
| 13 |
+
Lipschitz bound 15.2 (Muon + soft cap),4.4,11,0.099609375,0.1419779360294342,0.005133118946105242
|
| 14 |
+
Lipschitz bound 15.2 (Muon + soft cap),4.8,12,0.080078125,0.1307675540447235,0.004940425511449575
|
| 15 |
+
Lipschitz bound 15.2 (Muon + soft cap),5.2,13,0.064453125,0.12042088061571121,0.004750879481434822
|
| 16 |
+
Lipschitz bound 7618.8 (Adam + weight decay),0.0,0,0.455078125,0.3451777398586273,0.011264396831393242
|
| 17 |
+
Lipschitz bound 7618.8 (Adam + weight decay),0.4,1,0.427734375,0.31623339653015137,0.010893518105149269
|
| 18 |
+
Lipschitz bound 7618.8 (Adam + weight decay),0.8,2,0.326171875,0.26872894167900085,0.01072117779403925
|
| 19 |
+
Lipschitz bound 7618.8 (Adam + weight decay),1.2,3,0.234375,0.2093334048986435,0.01011421624571085
|
| 20 |
+
Lipschitz bound 7618.8 (Adam + weight decay),1.6,4,0.158203125,0.1528899371623993,0.008851265534758568
|
| 21 |
+
Lipschitz bound 7618.8 (Adam + weight decay),2.0,5,0.095703125,0.10684726387262344,0.007450764533132315
|
| 22 |
+
Lipschitz bound 7618.8 (Adam + weight decay),2.4,6,0.064453125,0.07626770436763763,0.006211025640368462
|
| 23 |
+
Lipschitz bound 7618.8 (Adam + weight decay),2.8,7,0.03125,0.05315832793712616,0.004932188894599676
|
| 24 |
+
Lipschitz bound 7618.8 (Adam + weight decay),3.2,8,0.015625,0.03924941271543503,0.004173320718109608
|
| 25 |
+
Lipschitz bound 7618.8 (Adam + weight decay),3.6,9,0.013671875,0.03031872771680355,0.0036359354853630066
|
| 26 |
+
Lipschitz bound 7618.8 (Adam + weight decay),4.0,10,0.009765625,0.02403259463608265,0.003225660650059581
|
| 27 |
+
Lipschitz bound 7618.8 (Adam + weight decay),4.4,11,0.009765625,0.02026309445500374,0.0029816720634698868
|
| 28 |
+
Lipschitz bound 7618.8 (Adam + weight decay),4.8,12,0.0078125,0.01729888655245304,0.0027463138103485107
|
| 29 |
+
Lipschitz bound 7618.8 (Adam + weight decay),5.2,13,0.005859375,0.015457117930054665,0.002608406590297818
|
figures/reproduce_figures.py
CHANGED
|
@@ -20,17 +20,21 @@ def load_csv_data():
|
|
| 20 |
"""Load all CSV files containing the figure data"""
|
| 21 |
|
| 22 |
fig_2_data_dir = Path("figure_2")
|
|
|
|
| 23 |
fig_4_data_dir = Path("figure_4")
|
| 24 |
|
| 25 |
# Load Figure 2 data
|
| 26 |
fig2_subplot1 = pd.read_csv(fig_2_data_dir / "figure_2_subplot_1.csv")
|
| 27 |
fig2_subplot2_3 = pd.read_csv(fig_2_data_dir / "figure_2_subplot_2_3.csv") # Used for both subplot 2 and 3
|
| 28 |
|
|
|
|
|
|
|
|
|
|
| 29 |
# Load Figure 4 data
|
| 30 |
fig4_subplot1 = pd.read_csv(fig_4_data_dir / "figure_4_subplot_1.csv")
|
| 31 |
fig4_subplot2 = pd.read_csv(fig_4_data_dir / "figure_4_subplot_2.csv")
|
| 32 |
|
| 33 |
-
return fig2_subplot1, fig2_subplot2_3, fig4_subplot1, fig4_subplot2
|
| 34 |
|
| 35 |
def safe_eval_list(list_str):
|
| 36 |
"""Safely evaluate string representation of list"""
|
|
@@ -236,6 +240,70 @@ def create_figure_2(highlight_points, results_df):
|
|
| 236 |
plt.savefig("figure_2_reproduced.pdf", format='pdf', bbox_inches='tight')
|
| 237 |
plt.show()
|
| 238 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
def create_figure_4(MLP_highlight_points, transformer_highlight_points):
|
| 240 |
"""Create Figure 4: MLP vs Transformer comparison"""
|
| 241 |
|
|
@@ -383,19 +451,23 @@ def create_figure_4(MLP_highlight_points, transformer_highlight_points):
|
|
| 383 |
plt.savefig("figure_4_reproduced.pdf", dpi=600, bbox_inches='tight')
|
| 384 |
plt.show()
|
| 385 |
|
|
|
|
| 386 |
def main():
|
| 387 |
-
"""Main function to load data and create
|
| 388 |
|
| 389 |
print("Loading CSV data...")
|
| 390 |
-
fig2_subplot1, fig2_subplot2_3, fig4_subplot1, fig4_subplot2 = load_csv_data()
|
| 391 |
|
| 392 |
print("Creating Figure 2...")
|
| 393 |
create_figure_2(fig2_subplot1, fig2_subplot2_3)
|
| 394 |
|
|
|
|
|
|
|
|
|
|
| 395 |
print("Creating Figure 4...")
|
| 396 |
create_figure_4(fig4_subplot1, fig4_subplot2)
|
| 397 |
|
| 398 |
-
print("Figures saved as 'figure_2_reproduced.pdf' and 'figure_4_reproduced.pdf'")
|
| 399 |
|
| 400 |
if __name__ == "__main__":
|
| 401 |
main()
|
|
|
|
| 20 |
"""Load all CSV files containing the figure data"""
|
| 21 |
|
| 22 |
fig_2_data_dir = Path("figure_2")
|
| 23 |
+
fig_3_data_dir = Path("figure_3")
|
| 24 |
fig_4_data_dir = Path("figure_4")
|
| 25 |
|
| 26 |
# Load Figure 2 data
|
| 27 |
fig2_subplot1 = pd.read_csv(fig_2_data_dir / "figure_2_subplot_1.csv")
|
| 28 |
fig2_subplot2_3 = pd.read_csv(fig_2_data_dir / "figure_2_subplot_2_3.csv") # Used for both subplot 2 and 3
|
| 29 |
|
| 30 |
+
# Load Figure 3 data
|
| 31 |
+
fig3_data = pd.read_csv(fig_3_data_dir / "figure_3_subplot_2.csv")
|
| 32 |
+
|
| 33 |
# Load Figure 4 data
|
| 34 |
fig4_subplot1 = pd.read_csv(fig_4_data_dir / "figure_4_subplot_1.csv")
|
| 35 |
fig4_subplot2 = pd.read_csv(fig_4_data_dir / "figure_4_subplot_2.csv")
|
| 36 |
|
| 37 |
+
return fig2_subplot1, fig2_subplot2_3, fig3_data, fig4_subplot1, fig4_subplot2
|
| 38 |
|
| 39 |
def safe_eval_list(list_str):
|
| 40 |
"""Safely evaluate string representation of list"""
|
|
|
|
| 240 |
plt.savefig("figure_2_reproduced.pdf", format='pdf', bbox_inches='tight')
|
| 241 |
plt.show()
|
| 242 |
|
| 243 |
+
|
| 244 |
+
def create_figure_3(df):
|
| 245 |
+
"""Create Figure 3: Adversarial robustness comparison"""
|
| 246 |
+
|
| 247 |
+
# Extract unique epsilon values and find epsilon range
|
| 248 |
+
epsilons = sorted(df['epsilon'].unique())
|
| 249 |
+
epsilons_upto = len(epsilons) # Use all available epsilon values
|
| 250 |
+
|
| 251 |
+
# Create the model info for plotting (extract from CSV)
|
| 252 |
+
models = []
|
| 253 |
+
for model_name in df['model_name'].unique():
|
| 254 |
+
model_data = df[df['model_name'] == model_name].copy().sort_values(by='epsilon')
|
| 255 |
+
|
| 256 |
+
# Determine color based on model name
|
| 257 |
+
if "Muon" in model_name or "soft cap" in model_name:
|
| 258 |
+
color = "royalblue"
|
| 259 |
+
else:
|
| 260 |
+
color = "#7F7F7F"
|
| 261 |
+
|
| 262 |
+
models.append({
|
| 263 |
+
"name": model_name,
|
| 264 |
+
"color": color,
|
| 265 |
+
"accuracies": model_data['accuracy'].tolist(),
|
| 266 |
+
"avg_correct_probs": model_data['avg_correct_prob'].tolist(),
|
| 267 |
+
"error_bars": model_data['prob_error_bar'].tolist()
|
| 268 |
+
})
|
| 269 |
+
|
| 270 |
+
# Create a figure with two subplots stacked vertically
|
| 271 |
+
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 5), sharex=True)
|
| 272 |
+
|
| 273 |
+
# Plot accuracy for each model (top subplot)
|
| 274 |
+
for model in models:
|
| 275 |
+
ax1.plot(epsilons[:epsilons_upto], model["accuracies"][:epsilons_upto], 'o-',
|
| 276 |
+
linewidth=3, markersize=5,
|
| 277 |
+
label=model["name"], color=model["color"])
|
| 278 |
+
ax1.set_xticks(epsilons[::2])
|
| 279 |
+
|
| 280 |
+
# Plot probability with error bars for each model (bottom subplot)
|
| 281 |
+
for model in models:
|
| 282 |
+
ax2.errorbar(epsilons[:epsilons_upto], model["avg_correct_probs"][:epsilons_upto],
|
| 283 |
+
yerr=model["error_bars"][:epsilons_upto], fmt='o-',
|
| 284 |
+
linewidth=3, markersize=5, capsize=5, elinewidth=1.5,
|
| 285 |
+
label=model["name"], color=model["color"])
|
| 286 |
+
ax2.set_xticks(epsilons[::2])
|
| 287 |
+
|
| 288 |
+
# Configure top subplot (accuracy)
|
| 289 |
+
ax1.set_ylabel('Accuracy (top 1)', fontsize=12)
|
| 290 |
+
ax1.set_ylim(0, 0.5)
|
| 291 |
+
ax1.tick_params(axis='y', labelsize=12)
|
| 292 |
+
ax1.legend(fontsize=12, frameon=False, borderpad=0.2, handletextpad=0.5, labelspacing=0.2, loc='upper center', bbox_to_anchor=(0.5, 1.38))
|
| 293 |
+
|
| 294 |
+
# Configure bottom subplot (probability)
|
| 295 |
+
ax2.set_xlabel('Budget of adversarial perturbation (ε)', fontsize=12)
|
| 296 |
+
ax2.set_ylabel('Mean p(correct class)', fontsize=12)
|
| 297 |
+
ax2.tick_params(axis='both', labelsize=12)
|
| 298 |
+
|
| 299 |
+
# Set x-ticks for both subplots
|
| 300 |
+
plt.xticks(epsilons[::2])
|
| 301 |
+
|
| 302 |
+
plt.tight_layout()
|
| 303 |
+
plt.savefig("figure_3_reproduced.pdf", format='pdf', bbox_inches='tight')
|
| 304 |
+
plt.show()
|
| 305 |
+
|
| 306 |
+
|
| 307 |
def create_figure_4(MLP_highlight_points, transformer_highlight_points):
|
| 308 |
"""Create Figure 4: MLP vs Transformer comparison"""
|
| 309 |
|
|
|
|
| 451 |
plt.savefig("figure_4_reproduced.pdf", dpi=600, bbox_inches='tight')
|
| 452 |
plt.show()
|
| 453 |
|
| 454 |
+
|
| 455 |
def main():
|
| 456 |
+
"""Main function to load data and create all figures"""
|
| 457 |
|
| 458 |
print("Loading CSV data...")
|
| 459 |
+
fig2_subplot1, fig2_subplot2_3, fig3_data, fig4_subplot1, fig4_subplot2 = load_csv_data()
|
| 460 |
|
| 461 |
print("Creating Figure 2...")
|
| 462 |
create_figure_2(fig2_subplot1, fig2_subplot2_3)
|
| 463 |
|
| 464 |
+
print("Creating Figure 3...")
|
| 465 |
+
create_figure_3(fig3_data)
|
| 466 |
+
|
| 467 |
print("Creating Figure 4...")
|
| 468 |
create_figure_4(fig4_subplot1, fig4_subplot2)
|
| 469 |
|
| 470 |
+
print("Figures saved as 'figure_2_reproduced.pdf', 'figure_3_reproduced.pdf' and 'figure_4_reproduced.pdf'")
|
| 471 |
|
| 472 |
if __name__ == "__main__":
|
| 473 |
main()
|