Santiago Hincapie-Potes
commited on
Commit
•
249d3f7
1
Parent(s):
d8bd706
feat: improve onnx support
Browse files- modelling_codegen.py +4 -1
modelling_codegen.py
CHANGED
@@ -40,7 +40,10 @@ def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
|
|
40 |
if seq_len is None:
|
41 |
seq_len = x.shape[seq_dim]
|
42 |
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
|
43 |
-
|
|
|
|
|
|
|
44 |
return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
|
45 |
|
46 |
|
|
|
40 |
if seq_len is None:
|
41 |
seq_len = x.shape[seq_dim]
|
42 |
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
|
43 |
+
# original
|
44 |
+
# sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(seq_len), inv_freq).to(x.device).float()
|
45 |
+
# QHD fix onnx error by https://github.com/microsoft/onnxruntime/discussions/10121#discussioncomment-1987845
|
46 |
+
sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(seq_len).float(), inv_freq).to(x.device).float()
|
47 |
return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
|
48 |
|
49 |
|