Santiago Hincapie-Potes commited on
Commit
249d3f7
1 Parent(s): d8bd706

feat: improve onnx support

Browse files
Files changed (1) hide show
  1. modelling_codegen.py +4 -1
modelling_codegen.py CHANGED
@@ -40,7 +40,10 @@ def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
40
  if seq_len is None:
41
  seq_len = x.shape[seq_dim]
42
  inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
43
- sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(seq_len), inv_freq).to(x.device).float()
 
 
 
44
  return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
45
 
46
 
 
40
  if seq_len is None:
41
  seq_len = x.shape[seq_dim]
42
  inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
43
+ # original
44
+ # sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(seq_len), inv_freq).to(x.device).float()
45
+ # QHD fix onnx error by https://github.com/microsoft/onnxruntime/discussions/10121#discussioncomment-1987845
46
+ sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(seq_len).float(), inv_freq).to(x.device).float()
47
  return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
48
 
49