| import torch.nn as nn |
|
|
| from model.attention import Attention |
|
|
| |
|
|
|
|
|
|
class TransformerLayer(nn.Module):
    """A single transformer block with adaptive layer norm (adaLN-Zero) conditioning.

    Structure (pre-norm, residual):
        x -> LN -> scale/shift -> Attention -> gated residual add
          -> LN -> scale/shift -> MLP       -> gated residual add

    The six per-sample modulation vectors (shift/scale before each sub-block,
    plus a gate after each) are regressed from the conditioning signal by
    ``adaptive_norm_layer``, which is zero-initialized so the block starts as
    the identity function (adaLN-Zero).
    """

    def __init__(self, config) -> None:
        """Build the block.

        Args:
            config: mapping providing at least ``"hidden_dim"``; it is also
                forwarded verbatim to ``Attention``.
        """
        super().__init__()
        self.hidden_dim = config["hidden_dim"]
        # Conventional 4x expansion for the feed-forward inner dimension.
        self.ff_dim = 4 * self.hidden_dim

        # BUGFIX: the second positional argument of nn.LayerNorm is `eps`,
        # so the previous `nn.LayerNorm(hidden_dim, ff_dim, ...)` silently set
        # eps = 4 * hidden_dim, crushing the normalized output toward zero.
        # Affine params are disabled because adaLN supplies scale/shift.
        self.attn_norm = nn.LayerNorm(self.hidden_dim, elementwise_affine=False)

        self.attn_block = Attention(config)

        self.ff_norm = nn.LayerNorm(self.hidden_dim, elementwise_affine=False)

        self.mlp_block = nn.Sequential(
            nn.Linear(self.hidden_dim, self.ff_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(self.ff_dim, self.hidden_dim),
        )

        # Projects the conditioning vector to 6 modulation vectors:
        # (pre-attn shift/scale, post-attn gate, pre-mlp shift/scale, post-mlp gate).
        self.adaptive_norm_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(self.hidden_dim, 6 * self.hidden_dim, bias=True),
        )

        # Xavier init for the MLP's linear layers, zero bias.
        nn.init.xavier_uniform_(self.mlp_block[0].weight)
        nn.init.constant_(self.mlp_block[0].bias, 0)
        nn.init.xavier_uniform_(self.mlp_block[-1].weight)
        nn.init.constant_(self.mlp_block[-1].bias, 0)

        # adaLN-Zero: zero-init the modulation projection so every gate and
        # scale/shift starts at 0 and the block is initially the identity.
        nn.init.constant_(self.adaptive_norm_layer[-1].weight, 0)
        nn.init.constant_(self.adaptive_norm_layer[-1].bias, 0)

    def forward(self, x, condition):
        """Apply the conditioned transformer block.

        Args:
            x: token tensor — assumed shape (batch, seq_len, hidden_dim);
                TODO confirm against callers.
            condition: conditioning tensor of shape (batch, hidden_dim).

        Returns:
            Tensor with the same shape as ``x``.
        """
        scale_shift_params = self.adaptive_norm_layer(condition).chunk(6, dim=1)
        (
            pre_attn_shift,
            pre_attn_scale,
            post_attn_scale,
            pre_mlp_shift,
            pre_mlp_scale,
            post_mlp_scale,
        ) = scale_shift_params

        out = x
        # Attention sub-block: modulated pre-norm, gated residual.
        attn_norm = self.attn_norm(out) * (
            1 + pre_attn_scale.unsqueeze(1)
        ) + pre_attn_shift.unsqueeze(1)
        out = out + self.attn_block(attn_norm) * post_attn_scale.unsqueeze(1)

        # MLP sub-block. BUGFIX: this previously reused self.attn_norm,
        # leaving self.ff_norm dead; harmless only because affine is off.
        mlp_norm = self.ff_norm(out) * (
            1 + pre_mlp_scale.unsqueeze(1)
        ) + pre_mlp_shift.unsqueeze(1)
        out = out + self.mlp_block(mlp_norm) * post_mlp_scale.unsqueeze(1)
        return out
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|