Hecheng0625 commited on
Commit
f6cb138
1 Parent(s): b38e3ae

Update Amphion/models/ns3_codec/transformer.py

Browse files
Amphion/models/ns3_codec/transformer.py CHANGED
@@ -14,6 +14,18 @@ class StyleAdaptiveLayerNorm(nn.Module):
14
  self.style.bias.data[: self.in_dim] = 1
15
  self.style.bias.data[self.in_dim :] = 0
16
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  class PositionalEncoding(nn.Module):
19
  def __init__(self, d_model, dropout, max_len=5000):
 
14
  self.style.bias.data[: self.in_dim] = 1
15
  self.style.bias.data[self.in_dim :] = 0
16
 
17
+ def forward(self, x, condition):
18
+ # x: (B, T, d); condition: (B, T, d)
19
+
20
+ style = self.style(torch.mean(condition, dim=1, keepdim=True))
21
+
22
+ gamma, beta = style.chunk(2, -1)
23
+
24
+ out = self.norm(x)
25
+
26
+ out = gamma * out + beta
27
+ return out
28
+
29
 
30
  class PositionalEncoding(nn.Module):
31
  def __init__(self, d_model, dropout, max_len=5000):