Rename figure 2 and figure 4 to be accurate
- figures/README.md +5 -5
- figures/figure_2/figure_2_subplot_1.csv +96 -42
- figures/{figure_4/figure_4_subplot_2.csv → figure_2/figure_2_subplot_2.csv} +0 -0
- figures/figure_4/figure_4_subplot_1.csv +42 -96
- figures/{figure_2/figure_2_subplot_2_3.csv → figure_4/figure_4_subplot_2_3.csv} +0 -0
- figures/reproduce_figures.py +221 -220
figures/README.md
CHANGED
@@ -36,16 +36,16 @@ This will create:
 Each CSV file contains the processed data for its respective subplot:

 ### Figure 2 Files
-- **figure_2_subplot_1.csv**:
-- **
+- **figure_2_subplot_1.csv**: MLP model results with frontier points for different optimizers and techniques
+- **figure_2_subplot_2.csv**: Transformer model results with frontier points for different optimizers and techniques

 ### Figure 3 Files
 - **figure_3_subplot_1.pdf**: Contains the left panel of Figure 3 with adversarial examples pre-made from the models contained in the `models/MLPs/` directory
 - **figure_3_subplot_2.csv**: Contains adversarial robustness data with columns for model_name, epsilon (adversarial perturbation budget), accuracy, avg_correct_prob (mean probability for correct class), and prob_error_bar (error bars for probability measurements)

 ### Figure 4 Files
-- **figure_4_subplot_1.csv**:
-- **
+- **figure_4_subplot_1.csv**: Contains points used to plot the frontier of validation loss vs. Lipschitz constant with columns for technique, learning rate, w_max, final validation loss, Lipschitz constant, optimizer, etc.
+- **figure_4_subplot_2_3.csv**: Contains results for top validation accuracy model for each of our tested techniques with columns for technique, learning rate, w_max, final validation accuracy, Lipschitz constant, optimizer, etc.

 ## Reproducibility

@@ -55,4 +55,4 @@ The reproduction script creates pixel-perfect recreations of the original figure
 - Matching legend positioning and styling
 - Equivalent subplot layouts and spacing

-This ensures full reproducibility of Figure 2 (
+This ensures full reproducibility of Figure 2 (MLP and Transformer optimizer comparisons), Figure 3 (adversarial robustness analysis), and Figure 4 (Lipschitz weight constraints comparison) from the saved CSV data.
|
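The README entries in this hunk describe the CSVs as holding "frontier points" of validation loss versus Lipschitz constant. One plausible reading, and only an assumption since the body of `reproduce_figures.py` is not shown in this diff, is that runs are bucketed by the base-10 log of the Lipschitz constant and the run with the lowest validation loss in each bucket is kept. A minimal sketch of that idea follows; the `frontier` helper, the fixed bucket width, and the three toy runs are all hypothetical and are not taken from the repository or its CSVs:

```python
import math


def frontier(runs, bin_width=0.1):
    """Keep the lowest-validation-loss run per log10(Lipschitz) bucket.

    Hypothetical helper: the bucketing scheme is an assumption, not the
    repository's actual method.
    """
    best = {}
    for run in runs:
        # Index of the fixed-width bucket along the log10(Lipschitz) axis.
        bucket = math.floor(math.log10(run["lipschitz"]) / bin_width)
        current = best.get(bucket)
        if current is None or run["final_val_loss"] < current["final_val_loss"]:
            best[bucket] = run
    # Order the surviving points by Lipschitz constant for plotting.
    return sorted(best.values(), key=lambda r: r["lipschitz"])


# Toy runs with made-up numbers, for illustration only.
runs = [
    {"technique": "spec_normalize", "lipschitz": 110.0, "final_val_loss": 1.30},
    {"technique": "soft_cap", "lipschitz": 120.0, "final_val_loss": 1.20},
    {"technique": "none", "lipschitz": 500.0, "final_val_loss": 1.25},
]

for point in frontier(runs):
    print(point["technique"], point["lipschitz"], point["final_val_loss"])
```

The first two toy runs fall into the same bucket, so only the one with the lower validation loss survives. Applying the helper once per optimizer would give one frontier per optimizer.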
figures/figure_2/figure_2_subplot_1.csv
CHANGED
|
@@ -1,42 +1,96 @@
|
|
| 1 |
-
technique,lr,w_max,final_train_loss,final_val_loss,final_train_acc,final_val_acc,spectral_wd,wd,lipschitz,optim,log_lipschitz,log_bin
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
none,0.
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
none,
|
| 20 |
-
none,0.
|
| 21 |
-
|
| 22 |
-
spec_hammer,
|
| 23 |
-
spec_hammer,0.
|
| 24 |
-
|
| 25 |
-
spec_wd,0.
|
| 26 |
-
spec_hammer,0.
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
spec_wd,0.
|
| 33 |
-
spec_wd,0.
|
| 34 |
-
spec_wd,0.
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
spec_wd,
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
spec_wd,0.
|
| 41 |
-
spec_wd,0.
|
| 42 |
-
spec_wd,
|
| 1 |
+
technique,lr,w_max,final_train_loss,final_val_loss,final_train_acc,final_val_acc,spectral_wd,wd,lipschitz,model,optim,log_lipschitz,log_bin
|
| 2 |
+
spec_hammer,0.005011872336272719,5.0,0.9515809590617815,1.2619053721427917,0.6875,0.554882824420929,0.0,0.0,129.45342456643067,mlp,adam,2.1121135437676575,"(2.086, 2.186]"
|
| 3 |
+
spec_hammer,0.005011872336272719,6.0,0.8196051319440206,1.263916289806366,0.73828125,0.5677734613418579,0.0,0.0,222.53532028321317,mlp,adam,2.3473989509875897,"(2.286, 2.387]"
|
| 4 |
+
spec_hammer,0.005011872336272719,7.0,0.737185575067997,1.2809421181678773,0.765625,0.564453125,0.0,0.0,351.3358190879143,mlp,adam,2.5457224288675264,"(2.487, 2.588]"
|
| 5 |
+
none,0.0012589254117941675,1.0,1.0784005771080654,1.2941604197025298,0.619140625,0.54248046875,0.0,1.0,506.6237385752386,mlp,adam,2.7046855354164254,"(2.688, 2.788]"
|
| 6 |
+
spec_hammer,0.0025118864315095794,4.0,1.0748761147260666,1.3025354504585267,0.650390625,0.55078125,0.0,0.0,66.43864942526967,mlp,adam,1.822420795466376,"(1.784, 1.885]"
|
| 7 |
+
none,0.0025118864315095794,1.0,1.0336566790938377,1.3073923587799072,0.66015625,0.5365234613418579,0.0,0.5994842503189409,724.9802792613901,mlp,adam,2.8603261931579698,"(2.788, 2.889]"
|
| 8 |
+
spec_normalize,0.0012589254117941675,8.0,1.1368991235891979,1.3106480538845062,0.60546875,0.5386719107627869,0.0,0.0,427.1077822976168,mlp,adam,2.6305374847448286,"(2.588, 2.688]"
|
| 9 |
+
none,0.0012589254117941675,1.0,0.97108194231987,1.3244979202747345,0.669921875,0.5448242425918579,0.0,0.0774263682681127,850.8291475823788,mlp,adam,2.929842359483344,"(2.889, 2.989]"
|
| 10 |
+
none,0.005011872336272719,1.0,1.0748743638396263,1.3389107167720795,0.61328125,0.53369140625,0.0,0.3593813663804626,976.9287740280438,mlp,adam,2.9898629013091567,"(2.989, 3.089]"
|
| 11 |
+
spec_wd,0.01,1.0,0.8677087550361952,1.3512717723846435,0.748046875,0.536816418170929,0.027825594022071243,0.0,250.04869287505062,mlp,adam,2.398024588623193,"(2.387, 2.487]"
|
| 12 |
+
spec_wd,0.0025118864315095794,1.0,1.1031840691963832,1.3541878283023834,0.6796875,0.534472644329071,0.01,0.0,101.3752057008878,mlp,adam,2.0059317484490014,"(1.985, 2.086]"
|
| 13 |
+
none,0.0025118864315095794,1.0,0.941372700035572,1.3675941169261931,0.671875,0.53271484375,0.0,0.1291549665014884,1525.730299032476,mlp,adam,3.1834777708444437,"(3.089, 3.19]"
|
| 14 |
+
none,0.0025118864315095794,1.0,0.921451210975647,1.3749362766742705,0.68359375,0.53125,0.0,0.046415888336127774,2049.1846266518255,mlp,adam,3.311581089061943,"(3.29, 3.391]"
|
| 15 |
+
none,0.005011872336272719,1.0,1.0828657895326614,1.3764999449253081,0.615234375,0.5179687738418579,0.0,0.21544346900318834,1818.0184033366806,mlp,adam,3.25959827516049,"(3.19, 3.29]"
|
| 16 |
+
none,0.005011872336272719,1.0,1.0658802141745884,1.3928723454475402,0.591796875,0.518750011920929,0.0,0.1291549665014884,4647.603675304063,mlp,adam,3.6672290864719645,"(3.591, 3.692]"
|
| 17 |
+
spec_wd,0.005011872336272719,1.0,1.1334923108418782,1.3999023735523224,0.69921875,0.5406250357627869,0.027825594022071243,0.0,78.03448865238366,mlp,adam,1.8922865888570233,"(1.885, 1.985]"
|
| 18 |
+
none,0.005011872336272719,1.0,1.025273084640503,1.435728758573532,0.638671875,0.51904296875,0.0,0.0774263682681127,6130.108515767662,mlp,adam,3.7874681625089472,"(3.692, 3.792]"
|
| 19 |
+
none,0.01,1.0,1.243887220819791,1.4366242825984954,0.564453125,0.49951171875,0.0,0.1291549665014884,8180.71540845687,mlp,adam,3.912791284644351,"(3.892, 3.993]"
|
| 20 |
+
none,0.00015848931924611142,1.0,1.3828533291816711,1.4380092024803162,0.51953125,0.4942382872104645,0.0,0.046415888336127774,188.82877449249958,mlp,adam,2.276068174554859,"(2.186, 2.286]"
|
| 21 |
+
spec_wd,0.0025118864315095794,1.0,1.2781847169001896,1.4496609568595886,0.626953125,0.5259765982627869,0.016681005372000592,0.0,50.70119493006016,mlp,adam,1.705018194943119,"(1.684, 1.784]"
|
| 22 |
+
spec_hammer,0.0025118864315095794,3.0,1.3556161274512608,1.453220111131668,0.556640625,0.5142578482627869,0.0,0.0,28.18393379817141,mlp,adam,1.4500016100510382,"(1.383, 1.483]"
|
| 23 |
+
spec_hammer,0.07943282347242814,3.0,1.6074494669834773,1.6066371738910674,0.458984375,0.4454101622104645,0.0,0.0,32.22924571340854,mlp,adam,1.5082501414931675,"(1.483, 1.584]"
|
| 24 |
+
spec_wd,0.005011872336272719,1.0,1.5296800186236699,1.6472486913204194,0.5546875,0.4999023377895355,0.046415888336127774,0.0,23.30438244803372,mlp,adam,1.3674375988813583,"(1.283, 1.383]"
|
| 25 |
+
spec_wd,0.0025118864315095794,1.0,1.5805873225132625,1.6653954088687897,0.546875,0.49394533038139343,0.027825594022071243,0.0,18.804192228471603,mlp,adam,1.2742546821746796,"(1.182, 1.283]"
|
| 26 |
+
spec_hammer,0.1584893192461114,3.0,1.6966613233089447,1.6707084357738495,0.4296875,0.41621094942092896,0.0,0.0,39.67848391525369,mlp,adam,1.5985550697367212,"(1.584, 1.684]"
|
| 27 |
+
spec_wd,0.000630957344480193,1.0,1.7244041413068771,1.758846378326416,0.46484375,0.455078125,0.01,0.0,11.477750861916165,mlp,adam,1.0598567936406547,"(0.981, 1.082]"
|
| 28 |
+
orthogonal,0.01,2.0,1.7524752169847488,1.7782266974449157,0.453125,0.4374023377895355,0.0,0.0,7.964808580580092,mlp,adam,0.9011753427992026,"(0.881, 0.981]"
|
| 29 |
+
spec_hammer,0.07943282347242814,2.0,1.8279538204272587,1.8279983818531036,0.37890625,0.39179688692092896,0.0,0.0,12.094601317362297,mlp,adam,1.0825915569870004,"(1.082, 1.182]"
|
| 30 |
+
spec_wd,0.00031622776601683794,1.0,1.9295574476321538,1.9457779407501221,0.4140625,0.41083985567092896,0.01,0.0,5.443420300581499,mlp,adam,0.7358718686136261,"(0.68, 0.781]"
|
| 31 |
+
spec_wd,0.000630957344480193,1.0,1.9567436476548512,1.975956916809082,0.427734375,0.41552734375,0.016681005372000592,0.0,4.784294033468242,mlp,adam,0.6798178627003927,"(0.58, 0.68]"
|
| 32 |
+
spec_wd,0.01,1.0,2.0174430906772614,2.074524438381195,0.384765625,0.4083007872104645,0.1291549665014884,0.0,3.7747499356204504,mlp,adam,0.5768881863818103,"(0.479, 0.58]"
|
| 33 |
+
spec_wd,0.00015848931924611142,1.0,2.089405337969462,2.097890019416809,0.37890625,0.36992189288139343,0.01,0.0,2.565830480395789,mlp,adam,0.40922795996539274,"(0.379, 0.479]"
|
| 34 |
+
spec_wd,0.00031622776601683794,1.0,2.1186704138914743,2.1283711552619935,0.37109375,0.3671875,0.016681005372000592,0.0,2.1046149632763105,mlp,adam,0.32317265379649435,"(0.279, 0.379]"
|
| 35 |
+
spec_wd,0.000630957344480193,1.0,2.141642322142919,2.153390038013458,0.35546875,0.36113283038139343,0.027825594022071243,0.0,1.752324585442925,mlp,adam,0.24361455423521955,"(0.178, 0.279]"
|
| 36 |
+
spec_wd,0.0012589254117941675,1.0,2.1743771135807037,2.1884663224220278,0.359375,0.35371094942092896,0.046415888336127774,0.0,1.3431498101158814,mlp,adam,0.12812445502027967,"(0.0779, 0.178]"
|
| 37 |
+
spec_wd,7.943282347242822e-05,1.0,2.2084874312082925,2.212341868877411,0.337890625,0.3106445372104645,0.01,0.0,0.9947789729433428,mlp,adam,-0.002273403168073618,"(-0.0225, 0.0779]"
|
| 38 |
+
spec_wd,0.00015848931924611142,1.0,2.227921575307846,2.2318353772163393,0.3203125,0.30195313692092896,0.016681005372000592,0.0,0.7608930861204106,mlp,adam,-0.11867636211234922,"(-0.123, -0.0225]"
|
| 39 |
+
spec_wd,0.00031622776601683794,1.0,2.2452704707781472,2.2492041110992433,0.322265625,0.2826171815395355,0.027825594022071243,0.0,0.5636803988918185,mlp,adam,-0.24896706683162642,"(-0.324, -0.223]"
|
| 40 |
+
spec_wd,0.000630957344480193,1.0,2.2568132082621255,2.2616291284561156,0.283203125,0.2513671815395355,0.046415888336127774,0.0,0.4342262761043406,mlp,adam,-0.3622838998547839,"(-0.424, -0.324]"
|
| 41 |
+
spec_wd,0.01,1.0,2.2436083257198334,2.2628151893615724,0.279296875,0.24101562798023224,0.21544346900318834,0.0,0.6067265186307454,mlp,adam,-0.21700702262096228,"(-0.223, -0.123]"
|
| 42 |
+
spec_wd,3.9810717055349695e-05,1.0,2.265566815932592,2.266888773441315,0.251953125,0.23554687201976776,0.01,0.0,0.36165306774861455,mlp,adam,-0.4417078466165141,"(-0.524, -0.424]"
|
| 43 |
+
spec_wd,7.943282347242822e-05,1.0,2.2741187711556754,2.2754802346229552,0.228515625,0.21855469048023224,0.016681005372000592,0.0,0.27120713254860657,mlp,adam,-0.5666988930285408,"(-0.625, -0.524]"
|
| 44 |
+
spec_wd,0.00015848931924611142,1.0,2.2818308671315513,2.283188211917877,0.189453125,0.18427734076976776,0.027825594022071243,0.0,0.19219393191999723,mlp,adam,-0.7162603282973528,"(-0.725, -0.625]"
|
| 45 |
+
spec_wd,0.00031622776601683794,1.0,2.28839044769605,2.28968540430069,0.15625,0.15996094048023224,0.046415888336127774,0.0,0.12763887726802262,mlp,adam,-0.8940170245854444,"(-0.926, -0.826]"
|
| 46 |
+
spec_wd,0.005011872336272719,1.0,2.286033570766449,2.291467344760895,0.173828125,0.16953125596046448,0.21544346900318834,0.0,0.15584988833724137,mlp,adam,-0.8072935045487537,"(-0.826, -0.725]"
|
| 47 |
+
spec_wd,0.31622776601683794,1.0,2.287269373734792,2.301726531982422,0.09765625,0.10732422024011612,0.1291549665014884,0.0,2498.0827439935415,mlp,adam,3.397606819412279,"(3.391, 3.491]"
|
| 48 |
+
spec_normalize,0.01995262314968879,2.0,2.3025851249694824,2.3025851249694824,0.08984375,0.09990234673023224,0.0,0.0,7.344295024871721,mlp,adam,0.8659501144214397,"(0.781, 0.881]"
|
| 49 |
+
spec_wd,0.1584893192461114,1.0,2.1716247498989105,2.3687571406364443,0.173828125,0.16093750298023224,0.046415888336127774,0.0,3483.274928819459,mlp,adam,3.5419877539176285,"(3.491, 3.591]"
|
| 50 |
+
spec_normalize,1.0,6.0,0.6135424623886744,1.145600324869156,0.857421875,0.603710949420929,0.0,0.0,216.73144763249007,mlp,muon,2.3359219318198665,"(2.286, 2.387]"
|
| 51 |
+
none,1.584893192461114,0.0,0.49193960676590603,1.1503977954387665,0.900390625,0.5956054925918579,0.0,0.1,326.6314794336425,mlp,muon,2.5140580379802833,"(2.487, 2.588]"
|
| 52 |
+
none,3.981071705534973,0.0,0.7261431738734245,1.1708129107952119,0.802734375,0.5909180045127869,0.0,0.1,296.2291798364921,mlp,muon,2.4716278361493904,"(2.387, 2.487]"
|
| 53 |
+
soft_cap,0.6309573444801934,6.0,0.546201099952062,1.178132140636444,0.9140625,0.58642578125,0.0,0.0,113.50021642736891,mlp,muon,2.054996689662379,"(1.985, 2.086]"
|
| 54 |
+
spec_normalize,3.981071705534973,8.0,0.5286508773763975,1.1918570876121521,0.857421875,0.589550793170929,0.0,0.0,512.7184240622878,mlp,muon,2.70987892369127,"(2.688, 2.788]"
|
| 55 |
+
spec_normalize,0.3981071705534973,5.0,0.8090496485431989,1.1927893519401551,0.798828125,0.5970703363418579,0.0,0.0,125.49701822317675,mlp,muon,2.0986334072146304,"(2.086, 2.186]"
|
| 56 |
+
soft_cap,1.584893192461114,7.0,0.46543631081779796,1.1975364863872529,0.92578125,0.5869140625,0.0,0.0,170.1667443487106,mlp,muon,2.230874689961104,"(2.186, 2.286]"
|
| 57 |
+
hard_cap,0.6309573444801934,5.0,0.5606682449579239,1.206486052274704,0.919921875,0.587207019329071,0.0,0.0,92.23833076924721,mlp,muon,1.9649114349602712,"(1.885, 1.985]"
|
| 58 |
+
soft_cap,0.3981071705534973,5.0,0.7737894381086031,1.2196079075336457,0.833984375,0.584667980670929,0.0,0.0,67.99641583474592,mlp,muon,1.8324860211736487,"(1.784, 1.885]"
|
| 59 |
+
spec_normalize,6.309573444801936,9.0,0.6397276371717453,1.2377217531204223,0.779296875,0.565136730670929,0.0,0.0,730.5870873089353,mlp,muon,2.863671992047765,"(2.788, 2.889]"
|
| 60 |
+
none,6.309573444801936,0.0,0.5984488800168037,1.2627243876457215,0.8125,0.561230480670929,0.0,0.046415888336127774,791.3652552469557,mlp,muon,2.8983769787253495,"(2.889, 2.989]"
|
| 61 |
+
spec_normalize,10.0,10.0,0.8100952530900637,1.2630029141902923,0.712890625,0.5615234375,0.0,0.0,1002.3103625525505,mlp,muon,3.001002220406808,"(2.989, 3.089]"
|
| 62 |
+
hard_cap,0.6309573444801934,4.0,0.9015005355079969,1.271504408121109,0.77734375,0.57177734375,0.0,0.0,47.20795187848079,mlp,muon,1.6740151589322285,"(1.584, 1.684]"
|
| 63 |
+
soft_cap,10.0,10.0,0.8678075224161148,1.2838713943958282,0.6875,0.5531250238418579,0.0,0.0,392.5746399125229,mlp,muon,2.5939222410057474,"(2.588, 2.688]"
|
| 64 |
+
soft_cap,0.025118864315095794,4.0,0.9869052916765213,1.3121922612190247,0.693359375,0.5497070550918579,0.0,0.0,54.07914961694873,mlp,muon,1.7330298537995121,"(1.684, 1.784]"
|
| 65 |
+
soft_cap,0.3981071705534973,4.0,1.0925154983997345,1.3367777287960052,0.697265625,0.560351550579071,0.0,0.0,35.62367706625288,mlp,muon,1.5517387451991804,"(1.483, 1.584]"
|
| 66 |
+
none,10.0,0.0,1.0049275904893875,1.342218542098999,0.658203125,0.5269531607627869,0.0,0.021544346900318832,1701.8186317345078,mlp,muon,3.230913274059976,"(3.19, 3.29]"
|
| 67 |
+
none,2.5118864315095824,0.0,0.8702319413423538,1.3664021372795105,0.708984375,0.538769543170929,0.0,0.01,4266.531814529947,mlp,muon,3.6300749884131167,"(3.591, 3.692]"
|
| 68 |
+
none,6.309573444801936,0.0,1.3039097587267559,1.3991230189800263,0.505859375,0.503222644329071,0.0,0.01,7049.651521704099,mlp,muon,3.8481676494819683,"(3.792, 3.892]"
|
| 69 |
+
spec_hammer,0.15848931924611143,3.0,1.1873172769943874,1.416911232471466,0.685546875,0.54150390625,0.0,0.0,30.371566399121594,mlp,muon,1.4824671910276543,"(1.383, 1.483]"
|
| 70 |
+
none,3.981071705534973,0.0,0.22559045689801374,1.666180557012558,0.947265625,0.54443359375,0.0,0.021544346900318832,2482.9099898930976,mlp,muon,3.394960975856903,"(3.391, 3.491]"
|
| 71 |
+
spec_wd,0.015848931924611134,1.0,1.6119112869103749,1.7064568161964417,0.5859375,0.5047851800918579,0.01,0.0,13.85554856777875,mlp,muon,1.1416237250112864,"(1.082, 1.182]"
|
| 72 |
+
spec_hammer,0.3981071705534973,2.0,1.6606111526489258,1.737447053194046,0.54296875,0.4800781309604645,0.0,0.0,11.200719275761823,mlp,muon,1.049245912622338,"(0.981, 1.082]"
|
| 73 |
+
none,0.1,0.0,1.7180798798799515,1.7425432860851289,0.4609375,0.447265625,0.0,0.46415888336127775,9.430748150434574,mlp,muon,0.9745461471081676,"(0.881, 0.981]"
|
| 74 |
+
spec_wd,0.15848931924611143,1.0,1.7357934166987736,1.8928824245929718,0.69921875,0.5130859613418579,0.0774263682681127,0.0,16.699798610106914,mlp,muon,1.2227112338394006,"(1.182, 1.283]"
|
| 75 |
+
none,0.3981071705534973,0.0,0.008818869362585247,1.9311249792575835,1.0,0.565722644329071,0.0,0.046415888336127774,1266.201594041739,mlp,muon,3.102502855926169,"(3.089, 3.19]"
|
| 76 |
+
spec_wd,0.01,1.0,1.929346779982249,1.9589250326156615,0.46484375,0.4449218809604645,0.01,0.0,4.75861730167534,mlp,muon,0.677480779298928,"(0.58, 0.68]"
|
| 77 |
+
spec_wd,0.025118864315095794,1.0,2.0198851923147836,2.054336929321289,0.478515625,0.4500976502895355,0.027825594022071243,0.0,3.738501995033509,mlp,muon,0.57269761674191,"(0.479, 0.58]"
|
| 78 |
+
spec_wd,0.1,1.0,1.990190014243126,2.0631633162498475,0.53515625,0.470703125,0.0774263682681127,0.0,5.913547971048514,mlp,muon,0.7718481241836456,"(0.68, 0.781]"
|
| 79 |
+
spec_wd,0.25118864315095807,1.0,1.9745069791873295,2.0720305681228637,0.568359375,0.48359376192092896,0.1291549665014884,0.0,7.554608223885079,mlp,muon,0.8782119470844948,"(0.781, 0.881]"
|
| 80 |
+
spec_wd,0.039810717055349734,1.0,2.0878397126992545,2.1199615478515623,0.458984375,0.43583986163139343,0.046415888336127774,0.0,2.834448725282898,mlp,muon,0.4524686050814032,"(0.379, 0.479]"
|
| 81 |
+
spec_hammer,1.0,1.0,2.147181282440821,2.161078703403473,0.365234375,0.3597656190395355,0.0,0.0,2.23825218773244,mlp,muon,0.3499090176571464,"(0.279, 0.379]"
|
| 82 |
+
spec_hammer,0.6309573444801936,1.0,2.1550684571266174,2.168361949920654,0.3515625,0.33867189288139343,0.0,0.0,1.8674988418364986,mlp,muon,0.2712603411518681,"(0.178, 0.279]"
|
| 83 |
+
spec_hammer,0.25118864315095824,1.0,2.1698385576407113,2.1785532593727113,0.345703125,0.33867189288139343,0.0,0.0,1.4152101372317512,mlp,muon,0.15082093078920447,"(0.0779, 0.178]"
|
| 84 |
+
spec_hammer,0.1,1.0,2.1790361801783242,2.185564732551575,0.345703125,0.3270507752895355,0.0,0.0,1.1631840225552137,mlp,muon,0.0656484281059515,"(-0.0225, 0.0779]"
|
| 85 |
+
spec_wd,0.025118864315095794,1.0,2.225555419921875,2.2346877694129943,0.361328125,0.3545898497104645,0.046415888336127774,0.0,0.7616373136447698,mlp,muon,-0.11825178742681001,"(-0.123, -0.0225]"
|
| 86 |
+
spec_wd,0.039810717055349734,1.0,2.2539119025071463,2.262084949016571,0.34765625,0.3252929747104645,0.0774263682681127,0.0,0.47417001799999065,mlp,muon,-0.3240659101402314,"(-0.424, -0.324]"
|
| 87 |
+
spec_wd,0.6309573444801934,1.0,2.2518043716748557,2.2770498752593995,0.330078125,0.30644533038139343,0.3593813663804626,0.0,0.6785443717909011,mlp,muon,-0.168421747423765,"(-0.223, -0.123]"
|
| 88 |
+
spec_wd,0.01,1.0,2.2768599589665732,2.2787588238716125,0.259765625,0.23408202826976776,0.027825594022071243,0.0,0.22047154838900057,mlp,muon,-0.6566474478057232,"(-0.725, -0.625]"
|
| 89 |
+
spec_wd,0.15848931924611143,1.0,2.2774324218432107,2.2858771085739136,0.3125,0.28583985567092896,0.21544346900318834,0.0,0.2575040777159906,mlp,muon,-0.5892158892809198,"(-0.625, -0.524]"
|
| 90 |
+
spec_wd,0.015848931924611134,1.0,2.285846769809723,2.287567436695099,0.234375,0.2099609375,0.046415888336127774,0.0,0.1426509697946006,mlp,muon,-0.8457252715006871,"(-0.926, -0.826]"
|
| 91 |
+
spec_wd,2.5118864315095824,1.0,2.2647158205509186,2.2902453184127807,0.220703125,0.27363282442092896,0.5994842503189409,0.0,0.5779600219544359,mlp,muon,-0.23810220110378486,"(-0.324, -0.223]"
|
| 92 |
+
spec_wd,1.584893192461114,1.0,2.3011384308338165,2.302428185939789,0.08984375,0.10019531100988388,0.016681005372000592,0.0,3342.4251711889774,mlp,muon,3.524061693240384,"(3.491, 3.591]"
|
| 93 |
+
spec_wd,1.584893192461114,1.0,2.30130398273468,2.3028753638267516,0.08984375,0.10009765625,0.01,0.0,9837.359494148184,mlp,muon,3.9928785424342728,"(3.892, 3.993]"
|
| 94 |
+
spec_wd,6.309573444801936,1.0,2.3011400997638702,2.3037927389144897,0.08984375,0.09990234673023224,0.0774263682681127,0.0,2321.7090285990557,mlp,muon,3.365807790255239,"(3.29, 3.391]"
|
| 95 |
+
none,0.008111308307896872,2.0,1.3065174619356792,1.421818447113037,0.564453125,0.4927734434604645,0.0,0.1,7618.843781501872,mlp,adam,3.881889068791904,"(3.792, 3.892]"
|
| 96 |
+
soft_cap,0.23101297000831597,3.0,1.4630188594261806,1.5657442033290863,0.5625,0.504589855670929,0.0,0.0,15.239435025533158,mlp,muon,1.182968866620278,"(1.182, 1.283]"
|
figures/{figure_4/figure_4_subplot_2.csv → figure_2/figure_2_subplot_2.csv}
RENAMED
|
File without changes
|
figures/figure_4/figure_4_subplot_1.csv
CHANGED
|
@@ -1,96 +1,42 @@
|
|
| 1 |
-
technique,lr,w_max,final_train_loss,final_val_loss,final_train_acc,final_val_acc,spectral_wd,wd,lipschitz,
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
none,0.
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
none,
|
| 20 |
-
none,0.
|
| 21 |
-
|
| 22 |
-
spec_hammer,
|
| 23 |
-
spec_hammer,0.
|
| 24 |
-
|
| 25 |
-
spec_wd,0.
|
| 26 |
-
spec_hammer,0.
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
spec_wd,0.
|
| 33 |
-
spec_wd,0.
|
| 34 |
-
spec_wd,0.
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
spec_wd,
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
spec_wd,0.
|
| 41 |
-
spec_wd,0.
|
| 42 |
-
spec_wd,
|
| 43 |
-
spec_wd,7.943282347242822e-05,1.0,2.2741187711556754,2.2754802346229552,0.228515625,0.21855469048023224,0.016681005372000592,0.0,0.27120713254860657,mlp,adam,-0.5666988930285408,"(-0.625, -0.524]"
|
| 44 |
-
spec_wd,0.00015848931924611142,1.0,2.2818308671315513,2.283188211917877,0.189453125,0.18427734076976776,0.027825594022071243,0.0,0.19219393191999723,mlp,adam,-0.7162603282973528,"(-0.725, -0.625]"
|
| 45 |
-
spec_wd,0.00031622776601683794,1.0,2.28839044769605,2.28968540430069,0.15625,0.15996094048023224,0.046415888336127774,0.0,0.12763887726802262,mlp,adam,-0.8940170245854444,"(-0.926, -0.826]"
|
| 46 |
-
spec_wd,0.005011872336272719,1.0,2.286033570766449,2.291467344760895,0.173828125,0.16953125596046448,0.21544346900318834,0.0,0.15584988833724137,mlp,adam,-0.8072935045487537,"(-0.826, -0.725]"
|
| 47 |
-
spec_wd,0.31622776601683794,1.0,2.287269373734792,2.301726531982422,0.09765625,0.10732422024011612,0.1291549665014884,0.0,2498.0827439935415,mlp,adam,3.397606819412279,"(3.391, 3.491]"
|
| 48 |
-
spec_normalize,0.01995262314968879,2.0,2.3025851249694824,2.3025851249694824,0.08984375,0.09990234673023224,0.0,0.0,7.344295024871721,mlp,adam,0.8659501144214397,"(0.781, 0.881]"
|
| 49 |
-
spec_wd,0.1584893192461114,1.0,2.1716247498989105,2.3687571406364443,0.173828125,0.16093750298023224,0.046415888336127774,0.0,3483.274928819459,mlp,adam,3.5419877539176285,"(3.491, 3.591]"
|
| 50 |
-
spec_normalize,1.0,6.0,0.6135424623886744,1.145600324869156,0.857421875,0.603710949420929,0.0,0.0,216.73144763249007,mlp,muon,2.3359219318198665,"(2.286, 2.387]"
|
| 51 |
-
none,1.584893192461114,0.0,0.49193960676590603,1.1503977954387665,0.900390625,0.5956054925918579,0.0,0.1,326.6314794336425,mlp,muon,2.5140580379802833,"(2.487, 2.588]"
|
| 52 |
-
none,3.981071705534973,0.0,0.7261431738734245,1.1708129107952119,0.802734375,0.5909180045127869,0.0,0.1,296.2291798364921,mlp,muon,2.4716278361493904,"(2.387, 2.487]"
|
| 53 |
-
soft_cap,0.6309573444801934,6.0,0.546201099952062,1.178132140636444,0.9140625,0.58642578125,0.0,0.0,113.50021642736891,mlp,muon,2.054996689662379,"(1.985, 2.086]"
|
| 54 |
-
spec_normalize,3.981071705534973,8.0,0.5286508773763975,1.1918570876121521,0.857421875,0.589550793170929,0.0,0.0,512.7184240622878,mlp,muon,2.70987892369127,"(2.688, 2.788]"
|
| 55 |
-
spec_normalize,0.3981071705534973,5.0,0.8090496485431989,1.1927893519401551,0.798828125,0.5970703363418579,0.0,0.0,125.49701822317675,mlp,muon,2.0986334072146304,"(2.086, 2.186]"
|
| 56 |
-
soft_cap,1.584893192461114,7.0,0.46543631081779796,1.1975364863872529,0.92578125,0.5869140625,0.0,0.0,170.1667443487106,mlp,muon,2.230874689961104,"(2.186, 2.286]"
|
| 57 |
-
hard_cap,0.6309573444801934,5.0,0.5606682449579239,1.206486052274704,0.919921875,0.587207019329071,0.0,0.0,92.23833076924721,mlp,muon,1.9649114349602712,"(1.885, 1.985]"
|
| 58 |
-
soft_cap,0.3981071705534973,5.0,0.7737894381086031,1.2196079075336457,0.833984375,0.584667980670929,0.0,0.0,67.99641583474592,mlp,muon,1.8324860211736487,"(1.784, 1.885]"
|
| 59 |
-
spec_normalize,6.309573444801936,9.0,0.6397276371717453,1.2377217531204223,0.779296875,0.565136730670929,0.0,0.0,730.5870873089353,mlp,muon,2.863671992047765,"(2.788, 2.889]"
|
| 60 |
-
none,6.309573444801936,0.0,0.5984488800168037,1.2627243876457215,0.8125,0.561230480670929,0.0,0.046415888336127774,791.3652552469557,mlp,muon,2.8983769787253495,"(2.889, 2.989]"
|
| 61 |
-
spec_normalize,10.0,10.0,0.8100952530900637,1.2630029141902923,0.712890625,0.5615234375,0.0,0.0,1002.3103625525505,mlp,muon,3.001002220406808,"(2.989, 3.089]"
|
| 62 |
-
hard_cap,0.6309573444801934,4.0,0.9015005355079969,1.271504408121109,0.77734375,0.57177734375,0.0,0.0,47.20795187848079,mlp,muon,1.6740151589322285,"(1.584, 1.684]"
|
| 63 |
-
soft_cap,10.0,10.0,0.8678075224161148,1.2838713943958282,0.6875,0.5531250238418579,0.0,0.0,392.5746399125229,mlp,muon,2.5939222410057474,"(2.588, 2.688]"
|
| 64 |
-
soft_cap,0.025118864315095794,4.0,0.9869052916765213,1.3121922612190247,0.693359375,0.5497070550918579,0.0,0.0,54.07914961694873,mlp,muon,1.7330298537995121,"(1.684, 1.784]"
|
| 65 |
-
soft_cap,0.3981071705534973,4.0,1.0925154983997345,1.3367777287960052,0.697265625,0.560351550579071,0.0,0.0,35.62367706625288,mlp,muon,1.5517387451991804,"(1.483, 1.584]"
|
| 66 |
-
none,10.0,0.0,1.0049275904893875,1.342218542098999,0.658203125,0.5269531607627869,0.0,0.021544346900318832,1701.8186317345078,mlp,muon,3.230913274059976,"(3.19, 3.29]"
|
| 67 |
-
none,2.5118864315095824,0.0,0.8702319413423538,1.3664021372795105,0.708984375,0.538769543170929,0.0,0.01,4266.531814529947,mlp,muon,3.6300749884131167,"(3.591, 3.692]"
|
| 68 |
-
none,6.309573444801936,0.0,1.3039097587267559,1.3991230189800263,0.505859375,0.503222644329071,0.0,0.01,7049.651521704099,mlp,muon,3.8481676494819683,"(3.792, 3.892]"
|
| 69 |
-
spec_hammer,0.15848931924611143,3.0,1.1873172769943874,1.416911232471466,0.685546875,0.54150390625,0.0,0.0,30.371566399121594,mlp,muon,1.4824671910276543,"(1.383, 1.483]"
|
| 70 |
-
none,3.981071705534973,0.0,0.22559045689801374,1.666180557012558,0.947265625,0.54443359375,0.0,0.021544346900318832,2482.9099898930976,mlp,muon,3.394960975856903,"(3.391, 3.491]"
|
| 71 |
-
spec_wd,0.015848931924611134,1.0,1.6119112869103749,1.7064568161964417,0.5859375,0.5047851800918579,0.01,0.0,13.85554856777875,mlp,muon,1.1416237250112864,"(1.082, 1.182]"
|
| 72 |
-
spec_hammer,0.3981071705534973,2.0,1.6606111526489258,1.737447053194046,0.54296875,0.4800781309604645,0.0,0.0,11.200719275761823,mlp,muon,1.049245912622338,"(0.981, 1.082]"
|
| 73 |
-
none,0.1,0.0,1.7180798798799515,1.7425432860851289,0.4609375,0.447265625,0.0,0.46415888336127775,9.430748150434574,mlp,muon,0.9745461471081676,"(0.881, 0.981]"
|
| 74 |
-
spec_wd,0.15848931924611143,1.0,1.7357934166987736,1.8928824245929718,0.69921875,0.5130859613418579,0.0774263682681127,0.0,16.699798610106914,mlp,muon,1.2227112338394006,"(1.182, 1.283]"
|
| 75 |
-
none,0.3981071705534973,0.0,0.008818869362585247,1.9311249792575835,1.0,0.565722644329071,0.0,0.046415888336127774,1266.201594041739,mlp,muon,3.102502855926169,"(3.089, 3.19]"
|
| 76 |
-
spec_wd,0.01,1.0,1.929346779982249,1.9589250326156615,0.46484375,0.4449218809604645,0.01,0.0,4.75861730167534,mlp,muon,0.677480779298928,"(0.58, 0.68]"
|
| 77 |
-
spec_wd,0.025118864315095794,1.0,2.0198851923147836,2.054336929321289,0.478515625,0.4500976502895355,0.027825594022071243,0.0,3.738501995033509,mlp,muon,0.57269761674191,"(0.479, 0.58]"
|
| 78 |
-
spec_wd,0.1,1.0,1.990190014243126,2.0631633162498475,0.53515625,0.470703125,0.0774263682681127,0.0,5.913547971048514,mlp,muon,0.7718481241836456,"(0.68, 0.781]"
|
| 79 |
-
spec_wd,0.25118864315095807,1.0,1.9745069791873295,2.0720305681228637,0.568359375,0.48359376192092896,0.1291549665014884,0.0,7.554608223885079,mlp,muon,0.8782119470844948,"(0.781, 0.881]"
|
| 80 |
-
spec_wd,0.039810717055349734,1.0,2.0878397126992545,2.1199615478515623,0.458984375,0.43583986163139343,0.046415888336127774,0.0,2.834448725282898,mlp,muon,0.4524686050814032,"(0.379, 0.479]"
|
| 81 |
-
spec_hammer,1.0,1.0,2.147181282440821,2.161078703403473,0.365234375,0.3597656190395355,0.0,0.0,2.23825218773244,mlp,muon,0.3499090176571464,"(0.279, 0.379]"
|
| 82 |
-
spec_hammer,0.6309573444801936,1.0,2.1550684571266174,2.168361949920654,0.3515625,0.33867189288139343,0.0,0.0,1.8674988418364986,mlp,muon,0.2712603411518681,"(0.178, 0.279]"
|
| 83 |
-
spec_hammer,0.25118864315095824,1.0,2.1698385576407113,2.1785532593727113,0.345703125,0.33867189288139343,0.0,0.0,1.4152101372317512,mlp,muon,0.15082093078920447,"(0.0779, 0.178]"
|
| 84 |
-
spec_hammer,0.1,1.0,2.1790361801783242,2.185564732551575,0.345703125,0.3270507752895355,0.0,0.0,1.1631840225552137,mlp,muon,0.0656484281059515,"(-0.0225, 0.0779]"
|
| 85 |
-
spec_wd,0.025118864315095794,1.0,2.225555419921875,2.2346877694129943,0.361328125,0.3545898497104645,0.046415888336127774,0.0,0.7616373136447698,mlp,muon,-0.11825178742681001,"(-0.123, -0.0225]"
|
| 86 |
-
spec_wd,0.039810717055349734,1.0,2.2539119025071463,2.262084949016571,0.34765625,0.3252929747104645,0.0774263682681127,0.0,0.47417001799999065,mlp,muon,-0.3240659101402314,"(-0.424, -0.324]"
|
| 87 |
-
spec_wd,0.6309573444801934,1.0,2.2518043716748557,2.2770498752593995,0.330078125,0.30644533038139343,0.3593813663804626,0.0,0.6785443717909011,mlp,muon,-0.168421747423765,"(-0.223, -0.123]"
|
| 88 |
-
spec_wd,0.01,1.0,2.2768599589665732,2.2787588238716125,0.259765625,0.23408202826976776,0.027825594022071243,0.0,0.22047154838900057,mlp,muon,-0.6566474478057232,"(-0.725, -0.625]"
|
| 89 |
-
spec_wd,0.15848931924611143,1.0,2.2774324218432107,2.2858771085739136,0.3125,0.28583985567092896,0.21544346900318834,0.0,0.2575040777159906,mlp,muon,-0.5892158892809198,"(-0.625, -0.524]"
|
| 90 |
-
spec_wd,0.015848931924611134,1.0,2.285846769809723,2.287567436695099,0.234375,0.2099609375,0.046415888336127774,0.0,0.1426509697946006,mlp,muon,-0.8457252715006871,"(-0.926, -0.826]"
|
| 91 |
-
spec_wd,2.5118864315095824,1.0,2.2647158205509186,2.2902453184127807,0.220703125,0.27363282442092896,0.5994842503189409,0.0,0.5779600219544359,mlp,muon,-0.23810220110378486,"(-0.324, -0.223]"
|
| 92 |
-
spec_wd,1.584893192461114,1.0,2.3011384308338165,2.302428185939789,0.08984375,0.10019531100988388,0.016681005372000592,0.0,3342.4251711889774,mlp,muon,3.524061693240384,"(3.491, 3.591]"
|
| 93 |
-
spec_wd,1.584893192461114,1.0,2.30130398273468,2.3028753638267516,0.08984375,0.10009765625,0.01,0.0,9837.359494148184,mlp,muon,3.9928785424342728,"(3.892, 3.993]"
|
| 94 |
-
spec_wd,6.309573444801936,1.0,2.3011400997638702,2.3037927389144897,0.08984375,0.09990234673023224,0.0774263682681127,0.0,2321.7090285990557,mlp,muon,3.365807790255239,"(3.29, 3.391]"
|
| 95 |
-
none,0.008111308307896872,2.0,1.3065174619356792,1.421818447113037,0.564453125,0.4927734434604645,0.0,0.1,7618.843781501872,mlp,adam,3.881889068791904,"(3.792, 3.892]"
|
| 96 |
-
soft_cap,0.23101297000831597,3.0,1.4630188594261806,1.5657442033290863,0.5625,0.504589855670929,0.0,0.0,15.239435025533158,mlp,muon,1.182968866620278,"(1.182, 1.283]"
|
|
|
|
| 1 |
+
technique,lr,w_max,final_train_loss,final_val_loss,final_train_acc,final_val_acc,spectral_wd,wd,lipschitz,optim,log_lipschitz,log_bin
|
| 2 |
+
spec_normalize,1.0,6,0.6135424623886744,1.145600324869156,0.857421875,0.603710949420929,0.0,0.0,216.73144763249007,muon,2.3359219318198665,"(2.285, 2.377]"
|
| 3 |
+
none,1.584893192461114,0,0.49193960676590603,1.1503977954387665,0.900390625,0.5956054925918579,0.0,0.1,326.6314794336425,muon,2.5140580379802833,"(2.469, 2.561]"
|
| 4 |
+
soft_cap,0.6309573444801934,6,0.546201099952062,1.178132140636444,0.9140625,0.58642578125,0.0,0.0,113.50021642736891,muon,2.054996689662379,"(2.009, 2.101]"
|
| 5 |
+
none,0.3981071705534973,0,0.28087465713421506,1.1853827595710755,0.974609375,0.591601550579071,0.0,0.1,379.7440522356983,muon,2.579490980424371,"(2.561, 2.653]"
|
| 6 |
+
none,6.309573444801936,0,0.8561444530884424,1.1901896238327025,0.736328125,0.577832043170929,0.0,0.1,287.3388110236937,muon,2.4583942903673064,"(2.377, 2.469]"
|
| 7 |
+
spec_normalize,3.981071705534973,8,0.5286508773763975,1.1918570876121521,0.857421875,0.589550793170929,0.0,0.0,512.7184240622878,muon,2.70987892369127,"(2.653, 2.745]"
|
| 8 |
+
soft_cap,1.584893192461114,7,0.46543631081779796,1.1975364863872529,0.92578125,0.5869140625,0.0,0.0,170.1667443487106,muon,2.230874689961104,"(2.193, 2.285]"
|
| 9 |
+
hard_cap,0.6309573444801934,5,0.5606682449579239,1.206486052274704,0.919921875,0.587207019329071,0.0,0.0,92.23833076924721,muon,1.9649114349602712,"(1.918, 2.009]"
|
| 10 |
+
soft_cap,0.3981071705534973,5,0.7737894381086031,1.2196079075336457,0.833984375,0.584667980670929,0.0,0.0,67.99641583474592,muon,1.8324860211736487,"(1.826, 1.918]"
|
| 11 |
+
soft_cap,0.1,6,0.4424842360119025,1.2208800554275512,0.951171875,0.575390636920929,0.0,0.0,131.2053366389354,muon,2.1179514998697213,"(2.101, 2.193]"
|
| 12 |
+
soft_cap,1.0,5,0.8516250724593798,1.2360183537006377,0.791015625,0.579882800579071,0.0,0.0,66.5756347283387,muon,1.8233153156697612,"(1.734, 1.826]"
|
| 13 |
+
spec_normalize,6.309573444801936,9,0.6397276371717453,1.2377217531204223,0.779296875,0.565136730670929,0.0,0.0,730.5870873089353,muon,2.863671992047765,"(2.837, 2.929]"
|
| 14 |
+
spec_normalize,10.0,10,0.8100952530900637,1.2630029141902923,0.712890625,0.5615234375,0.0,0.0,1002.3103625525505,muon,3.001002220406808,"(2.929, 3.021]"
|
| 15 |
+
hard_cap,0.6309573444801934,4,0.9015005355079969,1.271504408121109,0.77734375,0.57177734375,0.0,0.0,47.20795187848079,muon,1.6740151589322285,"(1.642, 1.734]"
|
| 16 |
+
soft_cap,0.06309573444801933,4,1.044749620060126,1.3355577886104584,0.7109375,0.5517578125,0.0,0.0,41.23087050223728,muon,1.6152225041034978,"(1.55, 1.642]"
|
| 17 |
+
none,10.0,0,1.0049275904893875,1.342218542098999,0.658203125,0.5269531607627869,0.0,0.021544346900318832,1701.8186317345078,muon,3.230913274059976,"(3.205, 3.297]"
|
| 18 |
+
soft_cap,0.6309573444801934,4,1.116555059949557,1.3439416527748107,0.685546875,0.5542969107627869,0.0,0.0,35.155756608646705,muon,1.5459964490904399,"(1.458, 1.55]"
|
| 19 |
+
none,2.5118864315095824,0,0.8702319413423538,1.3664021372795105,0.708984375,0.538769543170929,0.0,0.01,4266.531814529947,muon,3.6300749884131167,"(3.572, 3.664]"
|
| 20 |
+
none,0.025118864315095794,0,0.24326033890247345,1.3985329926013947,0.962890625,0.5589843988418579,0.0,0.046415888336127774,576.1887291490207,muon,2.7605647587755975,"(2.745, 2.837]"
|
| 21 |
+
none,6.309573444801936,0,1.3039097587267559,1.3991230189800263,0.505859375,0.503222644329071,0.0,0.01,7049.651521704099,muon,3.8481676494819683,"(3.756, 3.848]"
|
| 22 |
+
spec_hammer,6.309573444801936,10,0.9368437851468722,1.4277557492256165,0.673828125,0.5132812857627869,0.0,0.0,1122.555836802812,muon,3.0502079523588717,"(3.021, 3.113]"
|
| 23 |
+
spec_hammer,0.039810717055349734,3,1.2380208770434062,1.4304382085800171,0.642578125,0.524707019329071,0.0,0.0,28.26259951905857,muon,1.451212104606026,"(1.366, 1.458]"
|
| 24 |
+
none,3.981071705534973,0,0.22559045689801374,1.666180557012558,0.947265625,0.54443359375,0.0,0.021544346900318832,2482.9099898930976,muon,3.394960975856903,"(3.388, 3.48]"
|
| 25 |
+
spec_wd,0.015848931924611134,1,1.6119112869103749,1.7064568161964417,0.5859375,0.5047851800918579,0.01,0.0,13.85554856777875,muon,1.1416237250112864,"(1.09, 1.182]"
|
| 26 |
+
spec_hammer,0.3981071705534973,2,1.6606111526489258,1.737447053194046,0.54296875,0.4800781309604645,0.0,0.0,11.200719275761823,muon,1.049245912622338,"(0.998, 1.09]"
|
| 27 |
+
spec_hammer,0.25118864315095824,2,1.6719481895367305,1.742106318473816,0.525390625,0.4791015684604645,0.0,0.0,9.893169575184247,muon,0.9953354532233326,"(0.906, 0.998]"
|
| 28 |
+
spec_wd,0.15848931924611143,1,1.7357934166987736,1.8928824245929718,0.69921875,0.5130859613418579,0.0774263682681127,0.0,16.699798610106914,muon,1.2227112338394006,"(1.182, 1.274]"
|
| 29 |
+
spec_wd,0.01,1,1.929346779982249,1.9589250326156615,0.46484375,0.4449218809604645,0.01,0.0,4.75861730167534,muon,0.677480779298928,"(0.63, 0.722]"
|
| 30 |
+
spec_hammer,0.001584893192461114,2,1.9736527701218922,1.9803420543670653,0.376953125,0.3541015684604645,0.0,0.0,8.03332382971308,muon,0.9048952740744247,"(0.814, 0.906]"
|
| 31 |
+
none,0.25118864315095807,0,0.004763689454800139,2.0109056890010835,1.0,0.5713867545127869,0.0,0.046415888336127774,1336.1765925456536,muon,3.125863859411919,"(3.113, 3.205]"
|
| 32 |
+
spec_wd,0.025118864315095794,1,2.0198851923147836,2.054336929321289,0.478515625,0.4500976502895355,0.027825594022071243,0.0,3.738501995033509,muon,0.57269761674191,"(0.539, 0.63]"
|
| 33 |
+
spec_wd,0.1,1,1.990190014243126,2.0631633162498475,0.53515625,0.470703125,0.0774263682681127,0.0,5.913547971048514,muon,0.7718481241836456,"(0.722, 0.814]"
|
| 34 |
+
spec_wd,0.039810717055349734,1,2.0878397126992545,2.1199615478515623,0.458984375,0.43583986163139343,0.046415888336127774,0.0,2.834448725282898,muon,0.4524686050814032,"(0.447, 0.539]"
|
| 35 |
+
spec_hammer,1.0,1,2.147181282440821,2.161078703403473,0.365234375,0.3597656190395355,0.0,0.0,2.23825218773244,muon,0.3499090176571464,"(0.263, 0.355]"
|
| 36 |
+
spec_hammer,0.3981071705534973,1,2.1627643605073295,2.174478459358215,0.353515625,0.34345704317092896,0.0,0.0,1.6179561106364104,muon,0.20896673657974793,"(0.171, 0.263]"
|
| 37 |
+
spec_wd,0.15848931924611143,1,2.130100061496099,2.1772934913635256,0.462890625,0.4325195252895355,0.1291549665014884,0.0,2.7401330237465658,muon,0.437771646790013,"(0.355, 0.447]"
|
| 38 |
+
spec_hammer,0.25118864315095824,1,2.1698385576407113,2.1785532593727113,0.345703125,0.33867189288139343,0.0,0.0,1.4152101372317512,muon,0.15082093078920447,"(0.0788, 0.171]"
|
| 39 |
+
spec_hammer,0.1,1,2.1790361801783242,2.185564732551575,0.345703125,0.3270507752895355,0.0,0.0,1.1631840225552137,muon,0.0656484281059515,"(-0.0131, 0.0788]"
|
| 40 |
+
spec_wd,0.025118864315095794,1,2.225555419921875,2.2346877694129943,0.361328125,0.3545898497104645,0.046415888336127774,0.0,0.7616373136447698,muon,-0.11825178742681001,"(-0.197, -0.105]"
|
| 41 |
+
spec_wd,0.1,1,2.2288816571235657,2.2476592302322387,0.3984375,0.3721679747104645,0.1291549665014884,0.0,0.8792017653044303,muon,-0.055911448585927,"(-0.105, -0.0131]"
|
| 42 |
+
spec_wd,0.039810717055349734,1,2.2539119025071463,2.262084949016571,0.34765625,0.3252929747104645,0.0774263682681127,0.0,0.47417001799999065,muon,-0.3240659101402314,"(-0.381, -0.289]"
|
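The rows added to `figure_4_subplot_1.csv` above carry two derived columns, `log_lipschitz` and `log_bin`. From the values shown, `log_lipschitz` appears to be the base-10 logarithm of `lipschitz`, and `log_bin` looks like a half-open interval label of the kind `pandas.cut` produces. The following is a minimal sketch of a consistency check under those assumptions. The function name `check_log_columns` and the tolerance are hypothetical, and the sample rows are copied from the added side of the diff. Bin edges are printed rounded, so a value sitting exactly on an edge could be flagged spuriously.

```python
import csv
import io
import math

# Header and three rows copied verbatim from the added side of the diff.
SAMPLE = '''technique,lr,w_max,final_train_loss,final_val_loss,final_train_acc,final_val_acc,spectral_wd,wd,lipschitz,optim,log_lipschitz,log_bin
spec_normalize,1.0,6,0.6135424623886744,1.145600324869156,0.857421875,0.603710949420929,0.0,0.0,216.73144763249007,muon,2.3359219318198665,"(2.285, 2.377]"
soft_cap,0.6309573444801934,6,0.546201099952062,1.178132140636444,0.9140625,0.58642578125,0.0,0.0,113.50021642736891,muon,2.054996689662379,"(2.009, 2.101]"
spec_wd,0.025118864315095794,1,2.225555419921875,2.2346877694129943,0.361328125,0.3545898497104645,0.046415888336127774,0.0,0.7616373136447698,muon,-0.11825178742681001,"(-0.197, -0.105]"
'''


def check_log_columns(csv_text, tol=1e-6):
    """Return the rows whose derived columns disagree with `lipschitz`.

    A row is flagged when log_lipschitz differs from log10(lipschitz) by more
    than `tol`, or when log_lipschitz lies outside the (lo, hi] interval
    written in log_bin.
    """
    bad_rows = []
    for row in csv.DictReader(io.StringIO(csv_text)):
        lipschitz = float(row["lipschitz"])
        log_lipschitz = float(row["log_lipschitz"])
        # log_bin looks like "(2.285, 2.377]": open on the left, closed on the right.
        lo, hi = (float(edge) for edge in row["log_bin"].strip("(]").split(","))
        log_matches = abs(math.log10(lipschitz) - log_lipschitz) <= tol
        in_bin = lo < log_lipschitz <= hi
        if not (log_matches and in_bin):
            bad_rows.append(row)
    return bad_rows


if __name__ == "__main__":
    print(len(check_log_columns(SAMPLE)), "inconsistent rows")
```

Running the same function over the full file (by passing `open(path).read()`) would be one way to confirm that the swapped CSVs still have self-consistent derived columns after the rename.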
figures/{figure_2/figure_2_subplot_2_3.csv → figure_4/figure_4_subplot_2_3.csv}
RENAMED
|
File without changes
|
figures/reproduce_figures.py
CHANGED
|
@@ -25,16 +25,16 @@ def load_csv_data():
|
|
| 25 |
|
| 26 |
# Load Figure 2 data
|
| 27 |
fig2_subplot1 = pd.read_csv(fig_2_data_dir / "figure_2_subplot_1.csv")
|
| 28 |
-
|
| 29 |
|
| 30 |
# Load Figure 3 data
|
| 31 |
fig3_data = pd.read_csv(fig_3_data_dir / "figure_3_subplot_2.csv")
|
| 32 |
|
| 33 |
# Load Figure 4 data
|
| 34 |
fig4_subplot1 = pd.read_csv(fig_4_data_dir / "figure_4_subplot_1.csv")
|
| 35 |
-
|
| 36 |
|
| 37 |
-
return fig2_subplot1,
|
| 38 |
|
| 39 |
def safe_eval_list(list_str):
|
| 40 |
"""Safely evaluate string representation of list"""
|
|
@@ -43,8 +43,220 @@ def safe_eval_list(list_str):
|
|
| 43 |
except:
|
| 44 |
return []
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
|
| 49 |
tech_colors = {
|
| 50 |
"Spectral soft cap": "#8c564b", # brown
|
|
@@ -237,218 +449,7 @@ def create_figure_2(highlight_points, results_df):
|
|
| 237 |
|
| 238 |
plt.subplots_adjust(left=0.07, right=0.97, top=0.95, bottom=0.15)
|
| 239 |
|
| 240 |
-
plt.savefig("
|
| 241 |
-
plt.show()
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
def create_figure_3(df):
|
| 245 |
-
"""Create Figure 3: Adversarial robustness comparison"""
|
| 246 |
-
|
| 247 |
-
# Extract unique epsilon values and find epsilon range
|
| 248 |
-
epsilons = sorted(df['epsilon'].unique())
|
| 249 |
-
epsilons_upto = len(epsilons) # Use all available epsilon values
|
| 250 |
-
|
| 251 |
-
# Create the model info for plotting (extract from CSV)
|
| 252 |
-
models = []
|
| 253 |
-
for model_name in df['model_name'].unique():
|
| 254 |
-
model_data = df[df['model_name'] == model_name].copy().sort_values(by='epsilon')
|
| 255 |
-
|
| 256 |
-
# Determine color based on model name
|
| 257 |
-
if "Muon" in model_name or "soft cap" in model_name:
|
| 258 |
-
color = "royalblue"
|
| 259 |
-
else:
|
| 260 |
-
color = "#7F7F7F"
|
| 261 |
-
|
| 262 |
-
models.append({
|
| 263 |
-
"name": model_name,
|
| 264 |
-
"color": color,
|
| 265 |
-
"accuracies": model_data['accuracy'].tolist(),
|
| 266 |
-
"avg_correct_probs": model_data['avg_correct_prob'].tolist(),
|
| 267 |
-
"error_bars": model_data['prob_error_bar'].tolist()
|
| 268 |
-
})
|
| 269 |
-
|
| 270 |
-
# Create a figure with two subplots stacked vertically
|
| 271 |
-
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 5), sharex=True)
|
| 272 |
-
|
| 273 |
-
# Plot accuracy for each model (top subplot)
|
| 274 |
-
for model in models:
|
| 275 |
-
ax1.plot(epsilons[:epsilons_upto], model["accuracies"][:epsilons_upto], 'o-',
|
| 276 |
-
linewidth=3, markersize=5,
|
| 277 |
-
label=model["name"], color=model["color"])
|
| 278 |
-
ax1.set_xticks(epsilons[::2])
|
| 279 |
-
|
| 280 |
-
# Plot probability with error bars for each model (bottom subplot)
|
| 281 |
-
for model in models:
|
| 282 |
-
ax2.errorbar(epsilons[:epsilons_upto], model["avg_correct_probs"][:epsilons_upto],
|
| 283 |
-
yerr=model["error_bars"][:epsilons_upto], fmt='o-',
|
| 284 |
-
linewidth=3, markersize=5, capsize=5, elinewidth=1.5,
|
| 285 |
-
label=model["name"], color=model["color"])
|
| 286 |
-
ax2.set_xticks(epsilons[::2])
|
| 287 |
-
|
| 288 |
-
# Configure top subplot (accuracy)
|
| 289 |
-
ax1.set_ylabel('Accuracy (top 1)', fontsize=12)
|
| 290 |
-
ax1.set_ylim(0, 0.5)
|
| 291 |
-
ax1.tick_params(axis='y', labelsize=12)
|
| 292 |
-
ax1.legend(fontsize=12, frameon=False, borderpad=0.2, handletextpad=0.5, labelspacing=0.2, loc='upper center', bbox_to_anchor=(0.5, 1.38))
|
| 293 |
-
|
| 294 |
-
# Configure bottom subplot (probability)
|
| 295 |
-
ax2.set_xlabel('Budget of adversarial perturbation (ε)', fontsize=12)
|
| 296 |
-
ax2.set_ylabel('Mean p(correct class)', fontsize=12)
|
| 297 |
-
ax2.tick_params(axis='both', labelsize=12)
|
| 298 |
-
|
| 299 |
-
# Set x-ticks for both subplots
|
| 300 |
-
plt.xticks(epsilons[::2])
|
| 301 |
-
|
| 302 |
-
plt.tight_layout()
|
| 303 |
-
plt.savefig("figure_3_reproduced.pdf", format='pdf', bbox_inches='tight')
|
| 304 |
-
plt.show()
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
def create_figure_4(MLP_highlight_points, transformer_highlight_points):
|
| 308 |
-
"""Create Figure 4: MLP vs Transformer comparison"""
|
| 309 |
-
|
| 310 |
-
method_to_str = {
|
| 311 |
-
("adam", "none"): "Adam: weight decay",
|
| 312 |
-
("adam", "spec_normalize"): "Adam: spectral normalize",
|
| 313 |
-
("adam", "spec_hammer"): "Adam: spectral hammer",
|
| 314 |
-
("muon", "none"): "Muon: weight decay",
|
| 315 |
-
("muon", "spec_normalize"): "Muon: spectral normalize",
|
| 316 |
-
("muon", "soft_cap"): "Muon: soft cap",
|
| 317 |
-
}
|
| 318 |
-
|
| 319 |
-
# Define font sizes
|
| 320 |
-
label_fontsize = 16
|
| 321 |
-
title_fontsize = 16
|
| 322 |
-
tick_fontsize = 15
|
| 323 |
-
legend_fontsize = 15
|
| 324 |
-
|
| 325 |
-
# Create figure and gridspec with 3 panels (left-middle-right)
|
| 326 |
-
fig = plt.figure(figsize=(16, 3.5))
|
| 327 |
-
gs = gridspec.GridSpec(1, 3, width_ratios=[1, 1, 0.8], wspace=0.3)
|
| 328 |
-
|
| 329 |
-
# List of combinations to plot
|
| 330 |
-
combinations = [
|
| 331 |
-
("muon", "none"),
|
| 332 |
-
("muon", "spec_normalize"),
|
| 333 |
-
("muon", "soft_cap"),
|
| 334 |
-
("adam", "none"),
|
| 335 |
-
("adam", "spec_normalize"),
|
| 336 |
-
("adam", "spec_hammer"),
|
| 337 |
-
]
|
| 338 |
-
|
| 339 |
-
# List to store legend handles
|
| 340 |
-
legend_elements = []
|
| 341 |
-
|
| 342 |
-
# Function to create plot with given data
|
| 343 |
-
def create_plot(ax, data_points, title, show_ylabel=True):
|
| 344 |
-
for opt, project in combinations:
|
| 345 |
-
# Filter DataFrame for this optimizer and method combination
|
| 346 |
-
subset_df = data_points[(data_points['optim'] == opt) & (data_points['technique'] == project)]
|
| 347 |
-
|
| 348 |
-
if len(subset_df) == 0:
|
| 349 |
-
continue
|
| 350 |
-
|
| 351 |
-
# Extract x and y values from the DataFrame
|
| 352 |
-
x = subset_df['lipschitz'].values
|
| 353 |
-
y = subset_df['final_val_loss'].values
|
| 354 |
-
|
| 355 |
-
# Define colors based on optimizer
|
| 356 |
-
light_blue, light_green = "#7f7f7f", "royalblue"
|
| 357 |
-
color = light_green if opt == "muon" else light_blue
|
| 358 |
-
|
| 359 |
-
# Set markers based on project type
|
| 360 |
-
if project == "soft_cap":
|
| 361 |
-
marker = 'o'
|
| 362 |
-
markersize = 80
|
| 363 |
-
elif project == "none":
|
| 364 |
-
marker = 's' # square for weight decay
|
| 365 |
-
markersize = 80
|
| 366 |
-
elif project == "spec_normalize":
|
| 367 |
-
marker = 'P' # plus for spectral normalize
|
| 368 |
-
markersize = 90
|
| 369 |
-
elif project == "spec_hammer":
|
| 370 |
-
marker = 'v' # triangle down for spectral hammer
|
| 371 |
-
markersize = 90
|
| 372 |
-
|
| 373 |
-
# Plot the points
|
| 374 |
-
sc = ax.scatter(x, y, color=color, alpha=1, marker=marker, s=markersize, edgecolors='white', linewidth=0.5)
|
| 375 |
-
|
| 376 |
-
# Store the legend element if it's not already in the list
|
| 377 |
-
label = method_to_str[(opt, project)]
|
| 378 |
-
if not any(label == elem.get_label() for elem in legend_elements):
|
| 379 |
-
# Set appropriate marker sizes for the legend
|
| 380 |
-
legend_elements.append(
|
| 381 |
-
plt.Line2D([0], [0], marker=marker, color=color, linestyle='None',
|
| 382 |
-
label=label, markersize=10)
|
| 383 |
-
)
|
| 384 |
-
|
| 385 |
-
ax.set_xlabel("Lipschitz constant", fontsize=label_fontsize)
|
| 386 |
-
if show_ylabel:
|
| 387 |
-
ax.set_ylabel("Validation loss", fontsize=label_fontsize)
|
| 388 |
-
ax.set_xscale("log")
|
| 389 |
-
ax.set_yscale("log")
|
| 390 |
-
ax.set_title(title, fontsize=title_fontsize, pad=10)
|
| 391 |
-
|
| 392 |
-
# Set tick font sizes
|
| 393 |
-
ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)
|
| 394 |
-
|
| 395 |
-
# Custom log formatter for y-axis
|
| 396 |
-
class CustomLogFormatter(LogFormatter):
|
| 397 |
-
def __call__(self, x, pos=None):
|
| 398 |
-
# Format the tick label as a plain number
|
| 399 |
-
if x < 1:
|
| 400 |
-
return f"{x:.2f}"
|
| 401 |
-
elif x < 10:
|
| 402 |
-
return f"{x:.1f}"
|
| 403 |
-
else:
|
| 404 |
-
return f"{int(x)}"
|
| 405 |
-
|
| 406 |
-
def get_offset(self):
|
| 407 |
-
return ''
|
| 408 |
-
|
| 409 |
-
def set_locs(self, locs=None):
|
| 410 |
-
self.locs = locs
|
| 411 |
-
return
|
| 412 |
-
|
| 413 |
-
# Apply the custom formatter to the y-axis
|
| 414 |
-
ax.yaxis.set_major_formatter(CustomLogFormatter())
|
| 415 |
-
# Remove the minor tick labels which might still show scaling
|
| 416 |
-
ax.yaxis.set_minor_formatter(NullFormatter())
|
| 417 |
-
# Force the ticks to good positions
|
| 418 |
-
y_min, y_max = 1.05, 2.4 # Adjust based on your data range
|
| 419 |
-
ax.set_yticks(np.linspace(y_min, y_max, 5))
|
| 420 |
-
ax.xaxis.set_tick_params(which='minor', bottom=False)
|
| 421 |
-
ax.minorticks_off()
|
| 422 |
-
ax.grid(False)
|
| 423 |
-
|
| 424 |
-
# Create the side panels
|
| 425 |
-
ax1 = fig.add_subplot(gs[0])
|
| 426 |
-
create_plot(ax1, MLP_highlight_points, "MLP")
|
| 427 |
-
|
| 428 |
-
ax3 = fig.add_subplot(gs[1])
|
| 429 |
-
create_plot(ax3, transformer_highlight_points, "Transformer", show_ylabel=False)
|
| 430 |
-
|
| 431 |
-
# Create the rightmost panel for legend
|
| 432 |
-
ax2 = fig.add_subplot(gs[2])
|
| 433 |
-
ax2.axis('off') # Turn off axis for legend panel
|
| 434 |
-
|
| 435 |
-
# Add the combined legend to the rightmost panel
|
| 436 |
-
legend = ax2.legend(
|
| 437 |
-
handles=legend_elements,
|
| 438 |
-
loc='center',
|
| 439 |
-
fontsize=legend_fontsize,
|
| 440 |
-
title='Method:',
|
| 441 |
-
title_fontsize=title_fontsize,
|
| 442 |
-
frameon=False,
|
| 443 |
-
alignment='left',
|
| 444 |
-
bbox_to_anchor=(0.36, 0.5),
|
| 445 |
-
)
|
| 446 |
-
|
| 447 |
-
# Left align the legend title
|
| 448 |
-
legend.get_title().set_ha('left')
|
| 449 |
-
|
| 450 |
-
plt.tight_layout()
|
| 451 |
-
plt.savefig("figure_4_reproduced.pdf", dpi=600, bbox_inches='tight')
|
| 452 |
plt.show()
|
| 453 |
|
| 454 |
|
|
@@ -456,16 +457,16 @@ def main():
|
|
| 456 |
"""Main function to load data and create all figures"""
|
| 457 |
|
| 458 |
print("Loading CSV data...")
|
| 459 |
-
fig2_subplot1,
|
| 460 |
|
| 461 |
print("Creating Figure 2...")
|
| 462 |
-
create_figure_2(fig2_subplot1,
|
| 463 |
|
| 464 |
print("Creating Figure 3...")
|
| 465 |
create_figure_3(fig3_data)
|
| 466 |
|
| 467 |
print("Creating Figure 4...")
|
| 468 |
-
create_figure_4(fig4_subplot1,
|
| 469 |
|
| 470 |
print("Figures saved as 'figure_2_reproduced.pdf', 'figure_3_reproduced.pdf' and 'figure_4_reproduced.pdf'")
|
| 471 |
|
|
|
|
| 25 |
|
| 26 |
# Load Figure 2 data
|
| 27 |
fig2_subplot1 = pd.read_csv(fig_2_data_dir / "figure_2_subplot_1.csv")
|
| 28 |
+
fig2_subplot2 = pd.read_csv(fig_2_data_dir / "figure_2_subplot_2.csv")
|
| 29 |
|
| 30 |
# Load Figure 3 data
|
| 31 |
fig3_data = pd.read_csv(fig_3_data_dir / "figure_3_subplot_2.csv")
|
| 32 |
|
| 33 |
# Load Figure 4 data
|
| 34 |
fig4_subplot1 = pd.read_csv(fig_4_data_dir / "figure_4_subplot_1.csv")
|
| 35 |
+
fig4_subplot2_3 = pd.read_csv(fig_4_data_dir / "figure_4_subplot_2_3.csv") # Used for both subplot 2 and 3
|
| 36 |
|
| 37 |
+
return fig2_subplot1, fig2_subplot2, fig3_data, fig4_subplot1, fig4_subplot2_3
|
| 38 |
|
| 39 |
def safe_eval_list(list_str):
|
| 40 |
"""Safely evaluate string representation of list"""
|
|
|
|
| 43 |
except:
|
| 44 |
return []
|
| 45 |
|
| 46 |
+
|
| 47 |
+
def create_figure_2(MLP_highlight_points, transformer_highlight_points):
|
| 48 |
+
"""Create Figure 2: MLP vs Transformer comparison"""
|
| 49 |
+
|
| 50 |
+
method_to_str = {
|
| 51 |
+
("adam", "none"): "Adam: weight decay",
|
| 52 |
+
("adam", "spec_normalize"): "Adam: spectral normalize",
|
| 53 |
+
("adam", "spec_hammer"): "Adam: spectral hammer",
|
| 54 |
+
("muon", "none"): "Muon: weight decay",
|
| 55 |
+
("muon", "spec_normalize"): "Muon: spectral normalize",
|
| 56 |
+
("muon", "soft_cap"): "Muon: soft cap",
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
# Define font sizes
|
| 60 |
+
label_fontsize = 16
|
| 61 |
+
title_fontsize = 16
|
| 62 |
+
tick_fontsize = 15
|
| 63 |
+
legend_fontsize = 15
|
| 64 |
+
|
| 65 |
+
# Create figure and gridspec with 3 panels (left-middle-right)
|
| 66 |
+
fig = plt.figure(figsize=(16, 3.5))
|
| 67 |
+
gs = gridspec.GridSpec(1, 3, width_ratios=[1, 1, 0.8], wspace=0.3)
|
| 68 |
+
|
| 69 |
+
# List of combinations to plot
|
| 70 |
+
combinations = [
|
| 71 |
+
("muon", "none"),
|
| 72 |
+
("muon", "spec_normalize"),
|
| 73 |
+
("muon", "soft_cap"),
|
| 74 |
+
("adam", "none"),
|
| 75 |
+
("adam", "spec_normalize"),
|
| 76 |
+
("adam", "spec_hammer"),
|
| 77 |
+
]
|
| 78 |
+
|
| 79 |
+
# List to store legend handles
|
| 80 |
+
legend_elements = []
|
| 81 |
+
|
| 82 |
+
# Function to create plot with given data
|
| 83 |
+
def create_plot(ax, data_points, title, show_ylabel=True):
|
| 84 |
+
for opt, project in combinations:
|
| 85 |
+
# Filter DataFrame for this optimizer and method combination
|
| 86 |
+
subset_df = data_points[(data_points['optim'] == opt) & (data_points['technique'] == project)]
|
| 87 |
+
|
| 88 |
+
if len(subset_df) == 0:
|
| 89 |
+
continue
|
| 90 |
+
|
| 91 |
+
# Extract x and y values from the DataFrame
|
| 92 |
+
x = subset_df['lipschitz'].values
|
| 93 |
+
y = subset_df['final_val_loss'].values
|
| 94 |
+
|
| 95 |
+
# Define colors based on optimizer
|
| 96 |
+
light_blue, light_green = "#7f7f7f", "royalblue"
|
| 97 |
+
color = light_green if opt == "muon" else light_blue
|
| 98 |
+
|
| 99 |
+
# Set markers based on project type
|
| 100 |
+
if project == "soft_cap":
|
| 101 |
+
marker = 'o'
|
| 102 |
+
markersize = 80
|
| 103 |
+
elif project == "none":
|
| 104 |
+
marker = 's' # square for weight decay
|
| 105 |
+
markersize = 80
|
| 106 |
+
elif project == "spec_normalize":
|
| 107 |
+
marker = 'P' # plus for spectral normalize
|
| 108 |
+
markersize = 90
|
| 109 |
+
elif project == "spec_hammer":
|
| 110 |
+
marker = 'v' # triangle down for spectral hammer
|
| 111 |
+
markersize = 90
|
| 112 |
+
|
| 113 |
+
# Plot the points
|
| 114 |
+
sc = ax.scatter(x, y, color=color, alpha=1, marker=marker, s=markersize, edgecolors='white', linewidth=0.5)
|
| 115 |
+
|
| 116 |
+
# Store the legend element if it's not already in the list
|
| 117 |
+
label = method_to_str[(opt, project)]
|
| 118 |
+
if not any(label == elem.get_label() for elem in legend_elements):
|
| 119 |
+
# Set appropriate marker sizes for the legend
|
| 120 |
+
legend_elements.append(
|
| 121 |
+
plt.Line2D([0], [0], marker=marker, color=color, linestyle='None',
|
| 122 |
+
label=label, markersize=10)
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
ax.set_xlabel("Lipschitz constant", fontsize=label_fontsize)
|
| 126 |
+
if show_ylabel:
|
| 127 |
+
ax.set_ylabel("Validation loss", fontsize=label_fontsize)
|
| 128 |
+
ax.set_xscale("log")
|
| 129 |
+
ax.set_yscale("log")
|
| 130 |
+
ax.set_title(title, fontsize=title_fontsize, pad=10)
|
| 131 |
+
|
| 132 |
+
# Set tick font sizes
|
| 133 |
+
ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)
|
| 134 |
+
|
| 135 |
+
# Custom log formatter for y-axis
|
| 136 |
+
class CustomLogFormatter(LogFormatter):
|
| 137 |
+
def __call__(self, x, pos=None):
|
| 138 |
+
# Format the tick label as a plain number
|
| 139 |
+
if x < 1:
|
| 140 |
+
return f"{x:.2f}"
|
| 141 |
+
elif x < 10:
|
| 142 |
+
return f"{x:.1f}"
|
| 143 |
+
else:
|
| 144 |
+
return f"{int(x)}"
|
| 145 |
+
|
| 146 |
+
def get_offset(self):
|
| 147 |
+
return ''
|
| 148 |
+
|
| 149 |
+
def set_locs(self, locs=None):
|
| 150 |
+
self.locs = locs
|
| 151 |
+
return
|
| 152 |
+
|
| 153 |
+
# Apply the custom formatter to the y-axis
|
| 154 |
+
ax.yaxis.set_major_formatter(CustomLogFormatter())
|
| 155 |
+
# Remove the minor tick labels which might still show scaling
|
| 156 |
+
ax.yaxis.set_minor_formatter(NullFormatter())
|
| 157 |
+
# Force the ticks to good positions
|
| 158 |
+
y_min, y_max = 1.05, 2.4 # Adjust based on your data range
|
| 159 |
+
ax.set_yticks(np.linspace(y_min, y_max, 5))
|
| 160 |
+
ax.xaxis.set_tick_params(which='minor', bottom=False)
|
| 161 |
+
ax.minorticks_off()
|
| 162 |
+
ax.grid(False)
|
| 163 |
+
|
| 164 |
+
# Create the side panels
|
| 165 |
+
ax1 = fig.add_subplot(gs[0])
|
| 166 |
+
create_plot(ax1, MLP_highlight_points, "MLP")
|
| 167 |
+
|
| 168 |
+
ax3 = fig.add_subplot(gs[1])
|
| 169 |
+
create_plot(ax3, transformer_highlight_points, "Transformer", show_ylabel=False)
|
| 170 |
+
|
| 171 |
+
# Create the rightmost panel for legend
|
| 172 |
+
ax2 = fig.add_subplot(gs[2])
|
| 173 |
+
ax2.axis('off') # Turn off axis for legend panel
|
| 174 |
+
|
| 175 |
+
# Add the combined legend to the rightmost panel
|
| 176 |
+
legend = ax2.legend(
|
| 177 |
+
handles=legend_elements,
|
| 178 |
+
loc='center',
|
| 179 |
+
fontsize=legend_fontsize,
|
| 180 |
+
title='Method:',
|
| 181 |
+
title_fontsize=title_fontsize,
|
| 182 |
+
frameon=False,
|
| 183 |
+
alignment='left',
|
| 184 |
+
bbox_to_anchor=(0.36, 0.5),
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
# Left align the legend title
|
| 188 |
+
legend.get_title().set_ha('left')
|
| 189 |
+
|
| 190 |
+
plt.tight_layout()
|
| 191 |
+
plt.savefig("figure_2_reproduced.pdf", dpi=600, bbox_inches='tight')
|
| 192 |
+
plt.show()
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def create_figure_3(df):
|
| 196 |
+
"""Create Figure 3: Adversarial robustness comparison"""
|
| 197 |
+
|
| 198 |
+
# Extract unique epsilon values and find epsilon range
|
| 199 |
+
epsilons = sorted(df['epsilon'].unique())
|
| 200 |
+
epsilons_upto = len(epsilons) # Use all available epsilon values
|
| 201 |
+
|
| 202 |
+
# Create the model info for plotting (extract from CSV)
|
| 203 |
+
models = []
|
| 204 |
+
for model_name in df['model_name'].unique():
|
| 205 |
+
model_data = df[df['model_name'] == model_name].copy().sort_values(by='epsilon')
|
| 206 |
+
|
| 207 |
+
# Determine color based on model name
|
| 208 |
+
if "Muon" in model_name or "soft cap" in model_name:
|
| 209 |
+
color = "royalblue"
|
| 210 |
+
else:
|
| 211 |
+
color = "#7F7F7F"
|
| 212 |
+
|
| 213 |
+
models.append({
|
| 214 |
+
"name": model_name,
|
| 215 |
+
"color": color,
|
| 216 |
+
"accuracies": model_data['accuracy'].tolist(),
|
| 217 |
+
"avg_correct_probs": model_data['avg_correct_prob'].tolist(),
|
| 218 |
+
"error_bars": model_data['prob_error_bar'].tolist()
|
| 219 |
+
})
|
| 220 |
+
|
| 221 |
+
# Create a figure with two subplots stacked vertically
|
| 222 |
+
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 5), sharex=True)
|
| 223 |
+
|
| 224 |
+
# Plot accuracy for each model (top subplot)
|
| 225 |
+
for model in models:
|
| 226 |
+
ax1.plot(epsilons[:epsilons_upto], model["accuracies"][:epsilons_upto], 'o-',
|
| 227 |
+
linewidth=3, markersize=5,
|
| 228 |
+
label=model["name"], color=model["color"])
|
| 229 |
+
ax1.set_xticks(epsilons[::2])
|
| 230 |
+
|
| 231 |
+
# Plot probability with error bars for each model (bottom subplot)
|
| 232 |
+
for model in models:
|
| 233 |
+
ax2.errorbar(epsilons[:epsilons_upto], model["avg_correct_probs"][:epsilons_upto],
|
| 234 |
+
yerr=model["error_bars"][:epsilons_upto], fmt='o-',
|
| 235 |
+
linewidth=3, markersize=5, capsize=5, elinewidth=1.5,
|
| 236 |
+
label=model["name"], color=model["color"])
|
| 237 |
+
ax2.set_xticks(epsilons[::2])
|
| 238 |
+
|
| 239 |
+
# Configure top subplot (accuracy)
|
| 240 |
+
ax1.set_ylabel('Accuracy (top 1)', fontsize=12)
|
| 241 |
+
ax1.set_ylim(0, 0.5)
|
| 242 |
+
ax1.tick_params(axis='y', labelsize=12)
|
| 243 |
+
ax1.legend(fontsize=12, frameon=False, borderpad=0.2, handletextpad=0.5, labelspacing=0.2, loc='upper center', bbox_to_anchor=(0.5, 1.38))
|
| 244 |
+
|
| 245 |
+
# Configure bottom subplot (probability)
|
| 246 |
+
ax2.set_xlabel('Budget of adversarial perturbation (ε)', fontsize=12)
|
| 247 |
+
ax2.set_ylabel('Mean p(correct class)', fontsize=12)
|
| 248 |
+
ax2.tick_params(axis='both', labelsize=12)
|
| 249 |
+
|
| 250 |
+
# Set x-ticks for both subplots
|
| 251 |
+
plt.xticks(epsilons[::2])
|
| 252 |
+
|
| 253 |
+
plt.tight_layout()
|
| 254 |
+
plt.savefig("figure_3_reproduced.pdf", format='pdf', bbox_inches='tight')
|
| 255 |
+
plt.show()
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def create_figure_4(highlight_points, results_df):
|
| 259 |
+
"""Create Figure 4: Three-panel comparison plot"""
|
| 260 |
|
| 261 |
tech_colors = {
|
| 262 |
"Spectral soft cap": "#8c564b", # brown
|
|
|
|
| 449 |
|
| 450 |
plt.subplots_adjust(left=0.07, right=0.97, top=0.95, bottom=0.15)
|
| 451 |
|
| 452 |
+
plt.savefig("figure_4_reproduced.pdf", format='pdf', bbox_inches='tight')
|
| 453 |
plt.show()
|
| 454 |
|
| 455 |
|
|
|
|
| 457 |
"""Main function to load data and create all figures"""
|
| 458 |
|
| 459 |
print("Loading CSV data...")
|
| 460 |
+
fig2_subplot1, fig2_subplot2, fig3_data, fig4_subplot1, fig4_subplot2_3 = load_csv_data()
|
| 461 |
|
| 462 |
print("Creating Figure 2...")
|
| 463 |
+
create_figure_2(fig2_subplot1, fig2_subplot2)
|
| 464 |
|
| 465 |
print("Creating Figure 3...")
|
| 466 |
create_figure_3(fig3_data)
|
| 467 |
|
| 468 |
print("Creating Figure 4...")
|
| 469 |
+
create_figure_4(fig4_subplot1, fig4_subplot2_3)
|
| 470 |
|
| 471 |
print("Figures saved as 'figure_2_reproduced.pdf', 'figure_3_reproduced.pdf' and 'figure_4_reproduced.pdf'")
|
| 472 |
|
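Because this commit swaps CSV files between `figure_2/` and `figure_4/`, a stale checkout could leave `load_csv_data()` pointing at a file that no longer exists. The sketch below lists which expected files are missing before any plotting starts. The helper name `missing_figure_files` is hypothetical, the file names are taken from the `pd.read_csv` calls in the diff, and the `figure_3` directory name is an assumption because the diff does not show how `fig_3_data_dir` is defined.

```python
from pathlib import Path

# File names as read by load_csv_data() after this commit.
# The sub-directory names are assumptions based on the renamed paths in the
# diff; "figure_3" in particular is not shown there.
EXPECTED_FILES = {
    "figure_2": ["figure_2_subplot_1.csv", "figure_2_subplot_2.csv"],
    "figure_3": ["figure_3_subplot_2.csv"],
    "figure_4": ["figure_4_subplot_1.csv", "figure_4_subplot_2_3.csv"],
}


def missing_figure_files(figures_dir):
    """Return relative POSIX paths of expected CSV files not found on disk."""
    root = Path(figures_dir)
    missing = []
    for subdir, names in EXPECTED_FILES.items():
        for name in names:
            if not (root / subdir / name).is_file():
                missing.append((Path(subdir) / name).as_posix())
    return missing


if __name__ == "__main__":
    # Assumes the script is run from the repository root.
    for path in missing_figure_files("figures"):
        print("missing:", path)
```

Calling this at the top of `main()` and stopping when the list is non-empty would give a clearer message than the `FileNotFoundError` raised by `pd.read_csv`.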