fix: use cls_bias
Browse files- modeling_bert.py +1 -1
modeling_bert.py
CHANGED
@@ -747,7 +747,7 @@ class JinaBertEncoder(nn.Module):
|
|
747 |
alibi = alibi.unsqueeze(0)
|
748 |
assert alibi.shape == torch.Size([1, n_heads, size, size])
|
749 |
|
750 |
-
if
|
751 |
alibi[:, :, 0, :] = cls_bias
|
752 |
alibi[:, :, :, 0] = cls_bias
|
753 |
|
|
|
747 |
alibi = alibi.unsqueeze(0)
|
748 |
assert alibi.shape == torch.Size([1, n_heads, size, size])
|
749 |
|
750 |
+
if cls_bias is not None:
|
751 |
alibi[:, :, 0, :] = cls_bias
|
752 |
alibi[:, :, :, 0] = cls_bias
|
753 |
|