
=========================================

Check facebook/bart-large ...

--------------------------Checking logits match--------------------------

Flax logits shape: (2, 64, 50265), PyTorch logits shape: torch.Size([2, 64, 50265])

✅ Difference between Flax and PyTorch is 0.00039315223693847656 (< 0.01)

--------------------------Checking losses match--------------------------

Flax loss: 15.027304649353027, PyTorch loss: 15.027304649353027

✅ Difference between Flax and PyTorch is 0.0 (< 0.01)

--------------------------Checking gradients match--------------------------

❌ Layer ('final_logits_bias',) has PT grad norm 0.0 and flax grad norm 0.09944064915180206.

❌ Layer ('model', 'decoder', 'layers', '0', 'fc1', 'kernel') has PT grad norm 13.111018180847168 and flax grad norm 13.0546875.

❌ Layer ('model', 'decoder', 'layers', '0', 'fc2', 'kernel') has PT grad norm 8.751346588134766 and flax grad norm 8.71875.

...

❌ Layer ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel') has PT grad norm 18.60892105102539 and flax grad norm 18.59375.

...

❌ Layer ('model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel') has PT grad norm 96.85579681396484 and flax grad norm 96.8125.

...

❌ Layer ('model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel') has PT grad norm 199.41278076171875 and flax grad norm 199.25.

...

--------------------------Checking rel gradients match--------------------------

❌ Layer ('final_logits_bias',) has PT grad norm 0.0 and flax grad norm 0.09944064915180206.

❌ Layer ('model', 'decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 1.4212106691502413e-07 and flax grad norm 0.0.

❌ Layer ('model', 'decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias') has PT grad norm 2.0100719311244575e-08 and flax grad norm 0.0.

...

=========================================
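The logits and loss checks above come down to one operation: take the maximum absolute element-wise difference between the Flax and PyTorch outputs and compare it against a tolerance (0.01 in this log). A minimal sketch of that check, not the actual script — `max_abs_diff` and `check_close` are hypothetical names, and plain Python lists stand in for the flattened tensors:

```python
# Hypothetical sketch of the "< 0.01" tolerance check reported above.
# Plain lists stand in for the flattened Flax / PyTorch logit tensors.

def max_abs_diff(a, b):
    """Largest absolute element-wise difference between two flat sequences."""
    return max(abs(x - y) for x, y in zip(a, b))

def check_close(flax_vals, pt_vals, tol=1e-2):
    """Print a pass/fail line in the log's format; return whether it passed."""
    diff = max_abs_diff(flax_vals, pt_vals)
    mark = "✅" if diff < tol else "❌"
    print(f"{mark} Difference between Flax and PyTorch is {diff} (< {tol})")
    return diff < tol

# Toy values standing in for flattened logits; max diff ≈ 3e-4, well under 0.01.
flax_logits = [0.1000, -2.5000, 3.1400]
pt_logits   = [0.1003, -2.5001, 3.1400]
check_close(flax_logits, pt_logits)
```

Note the asymmetry in the log: a loss diff of 1.9e-06 still passes, while per-layer gradients get their own, stricter per-norm comparison below.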

Check facebook/bart-large-cnn ...

--------------------------Checking logits match--------------------------

Flax logits shape: (2, 64, 50264), PyTorch logits shape: torch.Size([2, 64, 50264])

✅ Difference between Flax and PyTorch is 0.0001919269561767578 (< 0.01)

--------------------------Checking losses match--------------------------

Flax loss: 13.262251853942871, PyTorch loss: 13.262249946594238

✅ Difference between Flax and PyTorch is 1.9073486328125e-06 (< 0.01)

--------------------------Checking gradients match--------------------------

❌ Layer ('final_logits_bias',) has PT grad norm 0.0 and flax grad norm 0.09764379262924194.

--------------------------Checking rel gradients match--------------------------

❌ Layer ('final_logits_bias',) has PT grad norm 0.0 and flax grad norm 0.09764379262924194.

❌ Layer ('model', 'decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 2.1513474735002092e-07 and flax grad norm 1.5481474235912174e-07.

❌ Layer ('model', 'decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias') has PT grad norm 3.8047311079481005e-08 and flax grad norm 3.508952062247772e-08.

...

=========================================
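The two gradient sections above compare, for each parameter (keyed by its flattened tree path, exactly as printed in the log), the L2 norms of the PyTorch and Flax gradients — once on the raw difference, and once relative to the norms' magnitude, which is why tiny absolute diffs like 2e-08 on bias terms can still be flagged in the "rel gradients" pass. A hypothetical sketch of that comparison (`compare_grads`, the toy gradient dicts, and the 0.01 tolerance are assumptions, not the real script):

```python
import math

# Hypothetical sketch of the per-layer gradient check reported above: dicts
# keyed by flattened parameter paths stand in for the two gradient trees.

def l2_norm(values):
    return math.sqrt(sum(v * v for v in values))

def compare_grads(pt_grads, flax_grads, tol=1e-2, relative=False):
    """Report layers whose PT/Flax gradient norms differ by more than `tol`."""
    bad = []
    for path, pt_vals in pt_grads.items():
        pt_norm = l2_norm(pt_vals)
        flax_norm = l2_norm(flax_grads[path])
        diff = abs(pt_norm - flax_norm)
        if relative:
            # "rel" pass: scale by the larger norm, so a 2e-08 vs 0.0 bias
            # mismatch is flagged even though its absolute diff is tiny
            diff /= max(pt_norm, flax_norm, 1e-12)
        if diff > tol:
            bad.append(path)
            print(f"❌ Layer {path} has PT grad norm {pt_norm} "
                  f"and flax grad norm {flax_norm}.")
    return bad

# Toy gradient trees: the bias norms disagree, the fc1 kernel norms match.
pt = {("final_logits_bias",): [0.0, 0.0],
      ("model", "decoder", "fc1"): [3.0, 4.0]}
fl = {("final_logits_bias",): [0.07, 0.07],
      ("model", "decoder", "fc1"): [3.0, 4.0]}
compare_grads(pt, fl)            # flags only ('final_logits_bias',)
compare_grads(pt, fl, relative=True)
```

On this reading, the recurring `('final_logits_bias',)` mismatch (PT norm exactly 0.0, Flax norm ~0.1) suggests a structural difference — one framework treats that buffer as trainable and the other does not — rather than a numerical-precision gap like the kernel norms, which agree to ~0.1%.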
|
|