patrickvonplaten committed
Commit cf42a95 • 1 Parent(s): a85c2d0
Files changed (2)
  1. after_fix_log.txt +3 -12
  2. after_fix_pretrained_log.txt +8 -148
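
Both log files come from a PyTorch/Flax equivalence check: each checkpoint named below is loaded in both frameworks, and its logits, loss, and per-layer gradient norms are compared. The checking script itself is not part of this commit, so the snippet below is only a minimal sketch of the forward-pass half of such a check; the model choice, dummy inputs, and tolerance are assumptions, not the actual setup.

# Hypothetical sketch only -- the actual comparison script is not included in this commit.
import numpy as np
import torch
from transformers import BertForMaskedLM, FlaxBertForMaskedLM

model_id = "hf-internal-testing/tiny-random-bert"   # one of the checkpoints checked below
input_ids = np.arange(10, dtype="int64")[None, :]   # dummy batch of token ids (assumption)

pt_model = BertForMaskedLM.from_pretrained(model_id)
fx_model = FlaxBertForMaskedLM.from_pretrained(model_id, from_pt=True)

with torch.no_grad():
    pt_logits = pt_model(torch.from_numpy(input_ids)).logits.numpy()
fx_logits = np.asarray(fx_model(input_ids).logits)

# "Checking logits match": flag the checkpoint if the outputs diverge beyond a tolerance.
max_diff = np.abs(pt_logits - fx_logits).max()
print(("✅" if max_diff < 1e-3 else "❌") + f" max abs diff between PT and Flax logits: {max_diff}")
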
after_fix_log.txt CHANGED
@@ -9,10 +9,7 @@ Flax loss: 6.887884140014648, PyTorch loss: 6.887884616851807
9
  βœ… All grads pass
10
  --------------------------Checking rel gradients match--------------------------
11
  ❌ Layer ('roberta', 'encoder', 'layer', '0', 'attention', 'self', 'key', 'bias') has PT grad norm 7.584575871001642e-13 and flax grad norm 6.388195094436666e-13.
12
- ❌ Layer ('roberta', 'encoder', 'layer', '1', 'attention', 'self', 'key', 'bias') has PT grad norm 7.811836030477415e-13 and flax grad norm 6.42668156105447e-13.
13
- ❌ Layer ('roberta', 'encoder', 'layer', '2', 'attention', 'self', 'key', 'bias') has PT grad norm 8.422985074175993e-13 and flax grad norm 6.414080963405844e-13.
14
- ❌ Layer ('roberta', 'encoder', 'layer', '3', 'attention', 'self', 'key', 'bias') has PT grad norm 8.625919531608794e-13 and flax grad norm 7.699825477734679e-13.
15
- ❌ Layer ('roberta', 'encoder', 'layer', '4', 'attention', 'self', 'key', 'bias') has PT grad norm 1.0383360837806777e-12 and flax grad norm 6.049140680551568e-13.
+ ...
16
  =========================================
17
  Check hf-internal-testing/tiny-random-bert ...
18
  --------------------------Checking logits match--------------------------
@@ -25,10 +22,7 @@ Flax loss: 7.036032199859619, PyTorch loss: 7.036032676696777
25
  βœ… All grads pass
26
  --------------------------Checking rel gradients match--------------------------
27
  ❌ Layer ('bert', 'encoder', 'layer', '0', 'attention', 'self', 'key', 'bias') has PT grad norm 5.234438642080785e-13 and flax grad norm 4.935363641205004e-13.
28
- ❌ Layer ('bert', 'encoder', 'layer', '1', 'attention', 'self', 'key', 'bias') has PT grad norm 9.028551018787356e-13 and flax grad norm 6.16219206737989e-13.
29
- ❌ Layer ('bert', 'encoder', 'layer', '2', 'attention', 'self', 'key', 'bias') has PT grad norm 8.728350616056535e-13 and flax grad norm 6.037235598596591e-13.
30
- ❌ Layer ('bert', 'encoder', 'layer', '3', 'attention', 'self', 'key', 'bias') has PT grad norm 8.327751465850297e-13 and flax grad norm 7.390156737431541e-13.
31
- ❌ Layer ('bert', 'encoder', 'layer', '4', 'attention', 'self', 'key', 'bias') has PT grad norm 7.404479048130075e-13 and flax grad norm 7.178592030705755e-13.
+ ...
32
  =========================================
33
  Check hf-internal-testing/tiny-random-t5 ...
34
  --------------------------Checking logits match--------------------------
@@ -55,8 +49,5 @@ Flax loss: 6.919522285461426, PyTorch loss: 6.919522285461426
55
  ❌ Layer ('final_logits_bias',) has PT grad norm 0.0 and flax grad norm 0.0.
56
  ❌ Layer ('model', 'decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 1.1293364247239035e-13 and flax grad norm 7.444291358479557e-14.
57
  ❌ Layer ('model', 'decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias') has PT grad norm 1.9028742882613858e-13 and flax grad norm 1.0847509820726894e-13.
58
- ❌ Layer ('model', 'decoder', 'layers', '1', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 1.0747876384459981e-13 and flax grad norm 1.1924105637346055e-13.
59
- ❌ Layer ('model', 'decoder', 'layers', '1', 'self_attn', 'k_proj', 'bias') has PT grad norm 2.0553104032074165e-13 and flax grad norm 2.416926793251395e-13.
60
- ❌ Layer ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias') has PT grad norm 1.6505221806215232e-16 and flax grad norm 8.704786207109524e-17.
61
- ❌ Layer ('model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'bias') has PT grad norm 2.145616114665214e-16 and flax grad norm 1.6750639014770325e-16.
+ ...
62
  =========================================
after_fix_pretrained_log.txt CHANGED
@@ -9,17 +9,7 @@ Flax loss: 14.801228523254395, PyTorch loss: 14.801219940185547
9
  βœ… All grads pass
10
  --------------------------Checking rel gradients match--------------------------
11
  ❌ Layer ('roberta', 'encoder', 'layer', '0', 'attention', 'self', 'key', 'bias') has PT grad norm 6.889232651019483e-08 and flax grad norm 5.7956174970286156e-08.
12
- ❌ Layer ('roberta', 'encoder', 'layer', '1', 'attention', 'self', 'key', 'bias') has PT grad norm 7.026115156349988e-08 and flax grad norm 6.282134989987753e-08.
13
- ❌ Layer ('roberta', 'encoder', 'layer', '10', 'attention', 'self', 'key', 'bias') has PT grad norm 2.273949206710313e-08 and flax grad norm 1.8748883334751554e-08.
14
- ❌ Layer ('roberta', 'encoder', 'layer', '11', 'attention', 'self', 'key', 'bias') has PT grad norm 2.9379741306456708e-08 and flax grad norm 2.6026933497291793e-08.
15
- ❌ Layer ('roberta', 'encoder', 'layer', '2', 'attention', 'self', 'key', 'bias') has PT grad norm 6.197853963385569e-08 and flax grad norm 5.317058082709991e-08.
16
- ❌ Layer ('roberta', 'encoder', 'layer', '3', 'attention', 'self', 'key', 'bias') has PT grad norm 7.359258802352997e-08 and flax grad norm 8.573702814373974e-08.
17
- ❌ Layer ('roberta', 'encoder', 'layer', '4', 'attention', 'self', 'key', 'bias') has PT grad norm 5.1634213349416314e-08 and flax grad norm 5.744939457485998e-08.
18
- ❌ Layer ('roberta', 'encoder', 'layer', '5', 'attention', 'self', 'key', 'bias') has PT grad norm 4.652720519970899e-08 and flax grad norm 6.121346984855336e-08.
19
- ❌ Layer ('roberta', 'encoder', 'layer', '6', 'attention', 'self', 'key', 'bias') has PT grad norm 3.8810604507943935e-08 and flax grad norm 4.2490388096894094e-08.
20
- ❌ Layer ('roberta', 'encoder', 'layer', '7', 'attention', 'self', 'key', 'bias') has PT grad norm 3.7450202938771326e-08 and flax grad norm 3.219445687818734e-08.
21
- ❌ Layer ('roberta', 'encoder', 'layer', '8', 'attention', 'self', 'key', 'bias') has PT grad norm 3.3088259243641005e-08 and flax grad norm 2.6118801343955056e-08.
22
- ❌ Layer ('roberta', 'encoder', 'layer', '9', 'attention', 'self', 'key', 'bias') has PT grad norm 2.6417508180998084e-08 and flax grad norm 2.415968225477627e-08.
+ ...
23
  =========================================
24
  Check bert-base-cased ...
25
  --------------------------Checking logits match--------------------------
@@ -32,16 +22,7 @@ Flax loss: 13.967159271240234, PyTorch loss: 13.967162132263184
32
  βœ… All grads pass
33
  --------------------------Checking rel gradients match--------------------------
34
  ❌ Layer ('bert', 'encoder', 'layer', '0', 'attention', 'self', 'key', 'bias') has PT grad norm 8.025740783068613e-08 and flax grad norm 8.381563532111613e-08.
35
- ❌ Layer ('bert', 'encoder', 'layer', '1', 'attention', 'self', 'key', 'bias') has PT grad norm 7.262840284738559e-08 and flax grad norm 5.0372555904232286e-08.
36
- ❌ Layer ('bert', 'encoder', 'layer', '10', 'attention', 'self', 'key', 'bias') has PT grad norm 2.6523425233904163e-08 and flax grad norm 2.7082945663892133e-08.
37
- ❌ Layer ('bert', 'encoder', 'layer', '11', 'attention', 'self', 'key', 'bias') has PT grad norm 2.9038789151059063e-08 and flax grad norm 3.3138192634396546e-08.
38
- ❌ Layer ('bert', 'encoder', 'layer', '2', 'attention', 'self', 'key', 'bias') has PT grad norm 5.880680831182872e-08 and flax grad norm 5.04786008548308e-08.
39
- ❌ Layer ('bert', 'encoder', 'layer', '3', 'attention', 'self', 'key', 'bias') has PT grad norm 4.705585965325554e-08 and flax grad norm 4.983893475696277e-08.
40
- ❌ Layer ('bert', 'encoder', 'layer', '4', 'attention', 'self', 'key', 'bias') has PT grad norm 6.595875134962625e-08 and flax grad norm 5.823812543326312e-08.
41
- ❌ Layer ('bert', 'encoder', 'layer', '5', 'attention', 'self', 'key', 'bias') has PT grad norm 4.716540402682767e-08 and flax grad norm 6.053270595884896e-08.
42
- ❌ Layer ('bert', 'encoder', 'layer', '6', 'attention', 'self', 'key', 'bias') has PT grad norm 5.4432636176215965e-08 and flax grad norm 4.0700697923057305e-08.
43
- ❌ Layer ('bert', 'encoder', 'layer', '8', 'attention', 'self', 'key', 'bias') has PT grad norm 4.059621971919114e-08 and flax grad norm 4.575255374561493e-08.
44
- ❌ Layer ('bert', 'encoder', 'layer', '9', 'attention', 'self', 'key', 'bias') has PT grad norm 2.9032529269557017e-08 and flax grad norm 2.659336217902819e-08.
+ ...
45
  =========================================
46
  Check t5-small ...
47
  --------------------------Checking logits match--------------------------
@@ -65,104 +46,14 @@ Flax loss: 13.993148803710938, PyTorch loss: 13.993138313293457
65
  --------------------------Checking gradients match--------------------------
66
  ❌ Layer ('model', 'decoder', 'layers', '0', 'fc1', 'kernel') has PT grad norm 11.655710220336914 and flax grad norm 11.6015625.
67
  ❌ Layer ('model', 'decoder', 'layers', '0', 'fc2', 'kernel') has PT grad norm 7.740886211395264 and flax grad norm 7.71484375.
68
- ❌ Layer ('model', 'decoder', 'layers', '1', 'fc1', 'kernel') has PT grad norm 6.798306465148926 and flax grad norm 6.765625.
69
- ❌ Layer ('model', 'decoder', 'layers', '1', 'fc2', 'kernel') has PT grad norm 7.071859836578369 and flax grad norm 7.05078125.
70
- ❌ Layer ('model', 'decoder', 'layers', '10', 'fc1', 'kernel') has PT grad norm 16.904926300048828 and flax grad norm 16.859375.
71
- ❌ Layer ('model', 'decoder', 'layers', '10', 'fc2', 'kernel') has PT grad norm 6.783661842346191 and flax grad norm 6.765625.
72
  ❌ Layer ('model', 'decoder', 'layers', '10', 'self_attn', 'v_proj', 'kernel') has PT grad norm 6.97633171081543 and flax grad norm 6.96484375.
73
- ❌ Layer ('model', 'decoder', 'layers', '11', 'fc1', 'kernel') has PT grad norm 13.733073234558105 and flax grad norm 13.6875.
74
- ❌ Layer ('model', 'decoder', 'layers', '11', 'fc2', 'kernel') has PT grad norm 6.311193466186523 and flax grad norm 6.29296875.
75
- ❌ Layer ('model', 'decoder', 'layers', '2', 'fc1', 'kernel') has PT grad norm 6.043461322784424 and flax grad norm 6.01171875.
76
- ❌ Layer ('model', 'decoder', 'layers', '2', 'fc2', 'kernel') has PT grad norm 8.091109275817871 and flax grad norm 8.0703125.
77
- ❌ Layer ('model', 'decoder', 'layers', '3', 'fc1', 'kernel') has PT grad norm 6.561250686645508 and flax grad norm 6.52734375.
78
- ❌ Layer ('model', 'decoder', 'layers', '3', 'fc2', 'kernel') has PT grad norm 8.535536766052246 and flax grad norm 8.5.
79
- ❌ Layer ('model', 'decoder', 'layers', '4', 'fc1', 'kernel') has PT grad norm 5.882389545440674 and flax grad norm 5.859375.
80
- ❌ Layer ('model', 'decoder', 'layers', '4', 'fc2', 'kernel') has PT grad norm 8.772762298583984 and flax grad norm 8.75.
81
- ❌ Layer ('model', 'decoder', 'layers', '5', 'fc1', 'kernel') has PT grad norm 4.559173107147217 and flax grad norm 4.5390625.
82
- ❌ Layer ('model', 'decoder', 'layers', '5', 'fc2', 'kernel') has PT grad norm 7.053295612335205 and flax grad norm 7.03515625.
83
- ❌ Layer ('model', 'decoder', 'layers', '6', 'fc1', 'kernel') has PT grad norm 4.724750995635986 and flax grad norm 4.703125.
84
- ❌ Layer ('model', 'decoder', 'layers', '6', 'fc2', 'kernel') has PT grad norm 6.6051740646362305 and flax grad norm 6.5859375.
85
- ❌ Layer ('model', 'decoder', 'layers', '7', 'fc1', 'kernel') has PT grad norm 3.9028773307800293 and flax grad norm 3.884765625.
86
- ❌ Layer ('model', 'decoder', 'layers', '7', 'fc2', 'kernel') has PT grad norm 6.16121244430542 and flax grad norm 6.14453125.
87
- ❌ Layer ('model', 'decoder', 'layers', '8', 'fc1', 'kernel') has PT grad norm 3.8737242221832275 and flax grad norm 3.85546875.
88
- ❌ Layer ('model', 'decoder', 'layers', '8', 'fc2', 'kernel') has PT grad norm 6.476221084594727 and flax grad norm 6.45703125.
89
- ❌ Layer ('model', 'decoder', 'layers', '9', 'fc1', 'kernel') has PT grad norm 6.240624904632568 and flax grad norm 6.203125.
90
- ❌ Layer ('model', 'decoder', 'layers', '9', 'fc2', 'kernel') has PT grad norm 6.872060775756836 and flax grad norm 6.8515625.
91
- ❌ Layer ('model', 'encoder', 'layers', '0', 'fc1', 'kernel') has PT grad norm 18.140913009643555 and flax grad norm 18.0625.
92
- ❌ Layer ('model', 'encoder', 'layers', '0', 'fc2', 'kernel') has PT grad norm 13.278300285339355 and flax grad norm 13.2265625.
93
- ❌ Layer ('model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel') has PT grad norm 13.971784591674805 and flax grad norm 13.9609375.
94
- ❌ Layer ('model', 'encoder', 'layers', '1', 'fc1', 'kernel') has PT grad norm 9.453910827636719 and flax grad norm 9.4140625.
95
- ❌ Layer ('model', 'encoder', 'layers', '1', 'fc2', 'kernel') has PT grad norm 5.271415710449219 and flax grad norm 5.25.
96
- ❌ Layer ('model', 'encoder', 'layers', '10', 'fc1', 'kernel') has PT grad norm 13.269391059875488 and flax grad norm 13.2109375.
97
- ❌ Layer ('model', 'encoder', 'layers', '10', 'fc2', 'kernel') has PT grad norm 15.780173301696777 and flax grad norm 15.734375.
98
- ❌ Layer ('model', 'encoder', 'layers', '10', 'self_attn', 'v_proj', 'kernel') has PT grad norm 21.07574462890625 and flax grad norm 21.0625.
99
- ❌ Layer ('model', 'encoder', 'layers', '11', 'fc1', 'kernel') has PT grad norm 16.847095489501953 and flax grad norm 16.765625.
100
- ❌ Layer ('model', 'encoder', 'layers', '11', 'fc2', 'kernel') has PT grad norm 17.480010986328125 and flax grad norm 17.421875.
101
- ❌ Layer ('model', 'encoder', 'layers', '11', 'self_attn', 'out_proj', 'kernel') has PT grad norm 26.91538429260254 and flax grad norm 26.890625.
102
- ❌ Layer ('model', 'encoder', 'layers', '11', 'self_attn', 'v_proj', 'kernel') has PT grad norm 26.706096649169922 and flax grad norm 26.6875.
103
- ❌ Layer ('model', 'encoder', 'layers', '2', 'fc1', 'kernel') has PT grad norm 6.756587505340576 and flax grad norm 6.7265625.
104
- ❌ Layer ('model', 'encoder', 'layers', '2', 'fc2', 'kernel') has PT grad norm 5.000077724456787 and flax grad norm 4.98046875.
105
- ❌ Layer ('model', 'encoder', 'layers', '2', 'self_attn', 'out_proj', 'kernel') has PT grad norm 18.72007942199707 and flax grad norm 18.703125.
106
- ❌ Layer ('model', 'encoder', 'layers', '3', 'fc1', 'kernel') has PT grad norm 9.560458183288574 and flax grad norm 9.515625.
107
- ❌ Layer ('model', 'encoder', 'layers', '3', 'fc2', 'kernel') has PT grad norm 8.892805099487305 and flax grad norm 8.859375.
108
- ❌ Layer ('model', 'encoder', 'layers', '4', 'fc1', 'kernel') has PT grad norm 8.845908164978027 and flax grad norm 8.8046875.
109
- ❌ Layer ('model', 'encoder', 'layers', '4', 'fc2', 'kernel') has PT grad norm 8.670329093933105 and flax grad norm 8.640625.
110
- ❌ Layer ('model', 'encoder', 'layers', '4', 'self_attn', 'out_proj', 'kernel') has PT grad norm 18.135663986206055 and flax grad norm 18.125.
111
- ❌ Layer ('model', 'encoder', 'layers', '5', 'fc1', 'kernel') has PT grad norm 10.071086883544922 and flax grad norm 10.03125.
112
- ❌ Layer ('model', 'encoder', 'layers', '5', 'fc2', 'kernel') has PT grad norm 9.528592109680176 and flax grad norm 9.5.
113
- ❌ Layer ('model', 'encoder', 'layers', '5', 'self_attn', 'out_proj', 'kernel') has PT grad norm 16.202266693115234 and flax grad norm 16.1875.
114
- ❌ Layer ('model', 'encoder', 'layers', '5', 'self_attn', 'v_proj', 'kernel') has PT grad norm 17.79180908203125 and flax grad norm 17.78125.
115
- ❌ Layer ('model', 'encoder', 'layers', '6', 'fc1', 'kernel') has PT grad norm 12.527167320251465 and flax grad norm 12.46875.
116
- ❌ Layer ('model', 'encoder', 'layers', '6', 'fc2', 'kernel') has PT grad norm 11.495430946350098 and flax grad norm 11.453125.
117
- ❌ Layer ('model', 'encoder', 'layers', '6', 'self_attn', 'v_proj', 'kernel') has PT grad norm 18.312782287597656 and flax grad norm 18.296875.
118
- ❌ Layer ('model', 'encoder', 'layers', '7', 'fc1', 'kernel') has PT grad norm 11.963201522827148 and flax grad norm 11.9140625.
119
- ❌ Layer ('model', 'encoder', 'layers', '7', 'fc2', 'kernel') has PT grad norm 13.052857398986816 and flax grad norm 13.0078125.
120
- ❌ Layer ('model', 'encoder', 'layers', '7', 'self_attn', 'out_proj', 'kernel') has PT grad norm 17.2364501953125 and flax grad norm 17.21875.
121
- ❌ Layer ('model', 'encoder', 'layers', '7', 'self_attn', 'v_proj', 'kernel') has PT grad norm 18.88938331604004 and flax grad norm 18.875.
122
- ❌ Layer ('model', 'encoder', 'layers', '8', 'fc1', 'kernel') has PT grad norm 11.773221969604492 and flax grad norm 11.71875.
123
- ❌ Layer ('model', 'encoder', 'layers', '8', 'fc2', 'kernel') has PT grad norm 14.441213607788086 and flax grad norm 14.3984375.
124
- ❌ Layer ('model', 'encoder', 'layers', '8', 'self_attn', 'out_proj', 'kernel') has PT grad norm 19.045316696166992 and flax grad norm 19.03125.
125
- ❌ Layer ('model', 'encoder', 'layers', '8', 'self_attn', 'v_proj', 'kernel') has PT grad norm 18.37466049194336 and flax grad norm 18.359375.
126
- ❌ Layer ('model', 'encoder', 'layers', '9', 'fc1', 'kernel') has PT grad norm 12.223063468933105 and flax grad norm 12.1640625.
127
- ❌ Layer ('model', 'encoder', 'layers', '9', 'fc2', 'kernel') has PT grad norm 15.896522521972656 and flax grad norm 15.8515625.
+ ...
128
  --------------------------Checking rel gradients match--------------------------
129
  ❌ Layer ('final_logits_bias',) has PT grad norm 0.0 and flax grad norm 0.0.
130
  ❌ Layer ('model', 'decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 8.274865592738934e-08 and flax grad norm 0.0.
131
- ❌ Layer ('model', 'decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias') has PT grad norm 3.530680814378684e-08 and flax grad norm 0.0.
132
- ❌ Layer ('model', 'decoder', 'layers', '1', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 9.156278935051887e-08 and flax grad norm 0.0.
133
- ❌ Layer ('model', 'decoder', 'layers', '1', 'self_attn', 'k_proj', 'bias') has PT grad norm 2.7762926180230352e-08 and flax grad norm 0.0.
134
- ❌ Layer ('model', 'decoder', 'layers', '10', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 2.314587632668008e-08 and flax grad norm 0.0.
135
- ❌ Layer ('model', 'decoder', 'layers', '10', 'self_attn', 'k_proj', 'bias') has PT grad norm 1.0350275303494527e-08 and flax grad norm 0.0.
136
- ❌ Layer ('model', 'decoder', 'layers', '11', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 2.4003222520718737e-08 and flax grad norm 0.0.
137
- ❌ Layer ('model', 'decoder', 'layers', '11', 'self_attn', 'k_proj', 'bias') has PT grad norm 1.0850777165671843e-08 and flax grad norm 0.0.
138
- ❌ Layer ('model', 'decoder', 'layers', '2', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 8.753153935003866e-08 and flax grad norm 0.0.
139
- ❌ Layer ('model', 'decoder', 'layers', '2', 'self_attn', 'k_proj', 'bias') has PT grad norm 3.340263532436438e-08 and flax grad norm 0.0.
140
- ❌ Layer ('model', 'decoder', 'layers', '3', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 7.864576190286243e-08 and flax grad norm 0.0.
141
- ❌ Layer ('model', 'decoder', 'layers', '3', 'self_attn', 'k_proj', 'bias') has PT grad norm 3.161582284860742e-08 and flax grad norm 0.0.
142
- ❌ Layer ('model', 'decoder', 'layers', '4', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 3.1698593971896116e-08 and flax grad norm 0.0.
143
- ❌ Layer ('model', 'decoder', 'layers', '4', 'self_attn', 'k_proj', 'bias') has PT grad norm 3.397210690536667e-08 and flax grad norm 0.0.
144
- ❌ Layer ('model', 'decoder', 'layers', '5', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 4.269594100492213e-08 and flax grad norm 0.0.
145
- ❌ Layer ('model', 'decoder', 'layers', '5', 'self_attn', 'k_proj', 'bias') has PT grad norm 2.3111518032692402e-08 and flax grad norm 0.0.
146
- ❌ Layer ('model', 'decoder', 'layers', '6', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 3.222339373110117e-08 and flax grad norm 0.0.
147
- ❌ Layer ('model', 'decoder', 'layers', '6', 'self_attn', 'k_proj', 'bias') has PT grad norm 2.2391466458770992e-08 and flax grad norm 0.0.
148
- ❌ Layer ('model', 'decoder', 'layers', '7', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 2.4886496419185278e-08 and flax grad norm 0.0.
149
- ❌ Layer ('model', 'decoder', 'layers', '7', 'self_attn', 'k_proj', 'bias') has PT grad norm 1.1538945798861278e-08 and flax grad norm 0.0.
150
- ❌ Layer ('model', 'decoder', 'layers', '8', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 2.3322632713984603e-08 and flax grad norm 0.0.
151
- ❌ Layer ('model', 'decoder', 'layers', '8', 'self_attn', 'k_proj', 'bias') has PT grad norm 1.525707205018989e-08 and flax grad norm 0.0.
152
- ❌ Layer ('model', 'decoder', 'layers', '9', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 2.3752901867624132e-08 and flax grad norm 0.0.
153
- ❌ Layer ('model', 'decoder', 'layers', '9', 'self_attn', 'k_proj', 'bias') has PT grad norm 1.4804857784156411e-08 and flax grad norm 0.0.
154
- ❌ Layer ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias') has PT grad norm 8.215602775862862e-08 and flax grad norm 0.0.
155
- ❌ Layer ('model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'bias') has PT grad norm 1.3884370275718538e-07 and flax grad norm 0.0.
156
- ❌ Layer ('model', 'encoder', 'layers', '10', 'self_attn', 'k_proj', 'bias') has PT grad norm 6.668184937552724e-08 and flax grad norm 0.0.
157
- ❌ Layer ('model', 'encoder', 'layers', '11', 'self_attn', 'k_proj', 'bias') has PT grad norm 8.486110658623147e-08 and flax grad norm 0.0.
158
- ❌ Layer ('model', 'encoder', 'layers', '2', 'self_attn', 'k_proj', 'bias') has PT grad norm 7.912614563565512e-08 and flax grad norm 0.0.
159
- ❌ Layer ('model', 'encoder', 'layers', '3', 'self_attn', 'k_proj', 'bias') has PT grad norm 7.959246062227976e-08 and flax grad norm 0.0.
160
- ❌ Layer ('model', 'encoder', 'layers', '4', 'self_attn', 'k_proj', 'bias') has PT grad norm 1.1111747255654336e-07 and flax grad norm 0.0.
161
- ❌ Layer ('model', 'encoder', 'layers', '5', 'self_attn', 'k_proj', 'bias') has PT grad norm 1.0121711824240265e-07 and flax grad norm 0.0.
162
- ❌ Layer ('model', 'encoder', 'layers', '6', 'self_attn', 'k_proj', 'bias') has PT grad norm 7.735735607639072e-08 and flax grad norm 0.0.
163
- ❌ Layer ('model', 'encoder', 'layers', '7', 'self_attn', 'k_proj', 'bias') has PT grad norm 1.0352330548357713e-07 and flax grad norm 0.0.
164
- ❌ Layer ('model', 'encoder', 'layers', '8', 'self_attn', 'k_proj', 'bias') has PT grad norm 8.3155640595578e-08 and flax grad norm 0.0.
165
- ❌ Layer ('model', 'encoder', 'layers', '9', 'self_attn', 'k_proj', 'bias') has PT grad norm 6.08824493042448e-08 and flax grad norm 0.0.
+ ❌ Layer ('model', 'decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias') has PT grad norm 2.2391466458770992e-08 and flax grad norm 0.0.
+ ❌ Layer ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias') has PT grad norm 8.3155640595578e-08 and flax grad norm 0.0.
+ ...
166
  =========================================
167
  Check facebook/bart-large-cnn ...
168
  --------------------------Checking logits match--------------------------
@@ -174,39 +65,8 @@ Flax loss: 13.418181419372559, PyTorch loss: 13.418176651000977
174
  --------------------------Checking gradients match--------------------------
175
  βœ… All grads pass
176
  --------------------------Checking rel gradients match--------------------------
177
- ❌ Layer ('final_logits_bias',) has PT grad norm 0.0 and flax grad norm 0.0.
178
  ❌ Layer ('model', 'decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 3.5387660091146245e-07 and flax grad norm 4.874667069998395e-07.
179
  ❌ Layer ('model', 'decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias') has PT grad norm 6.254911966152576e-08 and flax grad norm 6.927437112835833e-08.
180
- ❌ Layer ('model', 'decoder', 'layers', '1', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 2.0864062832970376e-07 and flax grad norm 1.8239754240312323e-07.
181
- ❌ Layer ('model', 'decoder', 'layers', '1', 'self_attn', 'k_proj', 'bias') has PT grad norm 6.265895535761956e-08 and flax grad norm 6.79184637419894e-08.
182
- ❌ Layer ('model', 'decoder', 'layers', '10', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 2.3348372124587513e-08 and flax grad norm 1.9192864186834413e-08.
183
- ❌ Layer ('model', 'decoder', 'layers', '10', 'self_attn', 'k_proj', 'bias') has PT grad norm 2.108701835368265e-08 and flax grad norm 1.938536442480654e-08.
184
- ❌ Layer ('model', 'decoder', 'layers', '11', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 2.4383044916476138e-08 and flax grad norm 1.890670375814807e-08.
185
- ❌ Layer ('model', 'decoder', 'layers', '11', 'self_attn', 'k_proj', 'bias') has PT grad norm 1.4334801790027996e-08 and flax grad norm 1.3059753278810149e-08.
186
- ❌ Layer ('model', 'decoder', 'layers', '2', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 2.0973679681901558e-07 and flax grad norm 2.6336286396144715e-07.
187
- ❌ Layer ('model', 'decoder', 'layers', '2', 'self_attn', 'k_proj', 'bias') has PT grad norm 8.528008521579977e-08 and flax grad norm 1.286241939624233e-07.
188
- ❌ Layer ('model', 'decoder', 'layers', '3', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 1.4659757141544105e-07 and flax grad norm 1.4895935862568876e-07.
189
- ❌ Layer ('model', 'decoder', 'layers', '3', 'self_attn', 'k_proj', 'bias') has PT grad norm 1.2112769809391466e-07 and flax grad norm 1.382354497536653e-07.
190
- ❌ Layer ('model', 'decoder', 'layers', '4', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 6.575358924010288e-08 and flax grad norm 7.240216604031957e-08.
191
- ❌ Layer ('model', 'decoder', 'layers', '4', 'self_attn', 'k_proj', 'bias') has PT grad norm 5.0567738441031906e-08 and flax grad norm 5.817578241362753e-08.
192
- ❌ Layer ('model', 'decoder', 'layers', '5', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 2.5854723162410664e-07 and flax grad norm 2.9657505251634575e-07.
193
- ❌ Layer ('model', 'decoder', 'layers', '5', 'self_attn', 'k_proj', 'bias') has PT grad norm 8.977868048987148e-08 and flax grad norm 8.695727160556999e-08.
194
- ❌ Layer ('model', 'decoder', 'layers', '6', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 1.271907592581556e-07 and flax grad norm 1.5420768306739774e-07.
195
- ❌ Layer ('model', 'decoder', 'layers', '6', 'self_attn', 'k_proj', 'bias') has PT grad norm 3.32223777377294e-08 and flax grad norm 2.9034252335691235e-08.
196
- ❌ Layer ('model', 'decoder', 'layers', '7', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 2.932817011469524e-08 and flax grad norm 2.990202219166349e-08.
197
- ❌ Layer ('model', 'decoder', 'layers', '7', 'self_attn', 'k_proj', 'bias') has PT grad norm 2.2563149570942187e-08 and flax grad norm 1.903124946522894e-08.
198
- ❌ Layer ('model', 'decoder', 'layers', '8', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 2.1031839381180362e-08 and flax grad norm 1.922012415889185e-08.
199
- ❌ Layer ('model', 'decoder', 'layers', '8', 'self_attn', 'k_proj', 'bias') has PT grad norm 1.9617312219111227e-08 and flax grad norm 1.901776514046105e-08.
200
- ❌ Layer ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias') has PT grad norm 9.498415209918676e-08 and flax grad norm 1.0714285281210323e-07.
201
- ❌ Layer ('model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'bias') has PT grad norm 4.2883741002697207e-08 and flax grad norm 3.708849050099161e-08.
202
- ❌ Layer ('model', 'encoder', 'layers', '10', 'self_attn', 'k_proj', 'bias') has PT grad norm 6.720534884152585e-08 and flax grad norm 5.363398614122161e-08.
203
- ❌ Layer ('model', 'encoder', 'layers', '11', 'self_attn', 'k_proj', 'bias') has PT grad norm 7.474503149751399e-08 and flax grad norm 7.271213320336756e-08.
204
- ❌ Layer ('model', 'encoder', 'layers', '2', 'self_attn', 'k_proj', 'bias') has PT grad norm 2.8607770374833308e-08 and flax grad norm 2.428515344377047e-08.
205
- ❌ Layer ('model', 'encoder', 'layers', '3', 'self_attn', 'k_proj', 'bias') has PT grad norm 6.400713203902342e-08 and flax grad norm 5.387828849734433e-08.
206
- ❌ Layer ('model', 'encoder', 'layers', '4', 'self_attn', 'k_proj', 'bias') has PT grad norm 4.873627545975978e-08 and flax grad norm 4.757723104376055e-08.
207
- ❌ Layer ('model', 'encoder', 'layers', '5', 'self_attn', 'k_proj', 'bias') has PT grad norm 5.0619281211083944e-08 and flax grad norm 4.67279193117065e-08.
208
- ❌ Layer ('model', 'encoder', 'layers', '6', 'self_attn', 'k_proj', 'bias') has PT grad norm 5.6844903895125753e-08 and flax grad norm 6.739185920423552e-08.
209
- ❌ Layer ('model', 'encoder', 'layers', '7', 'self_attn', 'k_proj', 'bias') has PT grad norm 5.603576624935158e-08 and flax grad norm 5.457893337279529e-08.
210
- ❌ Layer ('model', 'encoder', 'layers', '8', 'self_attn', 'k_proj', 'bias') has PT grad norm 5.864935914701164e-08 and flax grad norm 6.345069891722233e-08.
211
- ❌ Layer ('model', 'encoder', 'layers', '9', 'self_attn', 'k_proj', 'bias') has PT grad norm 6.470781244161117e-08 and flax grad norm 6.696199505995537e-08.
+ ❌ Layer ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias') has PT grad norm 5.864935914701164e-08 and flax grad norm 6.345069891722233e-08.
+ ...
212
  =========================================
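
The "PT grad norm" / "flax grad norm" lines above compare per-parameter gradient norms between the two frameworks, under both an absolute and a relative tolerance. The snippet below is a rough sketch of how such numbers could be produced; the real script is not part of this commit, so the stand-in loss, the key matching, and the tolerance here are assumptions.

# Hypothetical sketch only -- not the script that generated the logs above.
import jax
import jax.numpy as jnp
import numpy as np
import torch
from flax.core import unfreeze
from flax.traverse_util import flatten_dict
from transformers import BertForMaskedLM, FlaxBertForMaskedLM

model_id = "hf-internal-testing/tiny-random-bert"
input_ids = np.arange(10, dtype="int64")[None, :]   # dummy inputs (assumption)

pt_model = BertForMaskedLM.from_pretrained(model_id)
fx_model = FlaxBertForMaskedLM.from_pretrained(model_id, from_pt=True)

# PyTorch side: backprop a simple scalar (sum of logits stands in for the real loss).
pt_model(torch.from_numpy(input_ids)).logits.sum().backward()
pt_norms = {tuple(name.split(".")): param.grad.norm().item()
            for name, param in pt_model.named_parameters() if param.grad is not None}

# Flax side: differentiate the same scalar with respect to the parameter pytree.
def scalar_out(params):
    return fx_model(input_ids, params=params).logits.sum()

fx_grads = jax.grad(scalar_out)(fx_model.params)
fx_norms = {key: float(jnp.linalg.norm(grad)) for key, grad in flatten_dict(unfreeze(fx_grads)).items()}

# Compare norms layer by layer. Bias keys line up directly between the frameworks;
# kernel/weight keys would need remapping and transposition, omitted in this sketch.
for key, fx_norm in fx_norms.items():
    pt_norm = pt_norms.get(key)
    if pt_norm is None:
        continue
    # Flag layers whose norms disagree beyond a tolerance; the logs above run both
    # an absolute and a relative variant of this check.
    if not np.isclose(pt_norm, fx_norm, rtol=1e-3):
        print(f"❌ Layer {key} has PT grad norm {pt_norm} and flax grad norm {fx_norm}.")
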