import torch
import torch.nn as nn

from comfy.ldm.modules.attention import optimized_attention_masked
class LayerNormConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
        super().__init__()
        self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
        self.layer_norm = operations.LayerNorm(out_channels, elementwise_affine=True, device=device, dtype=dtype)

    def forward(self, x):
        x = self.conv(x)
        # LayerNorm normalizes the last dim, so transpose (B, C, T) -> (B, T, C) and back.
        return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1))
class ConvFeatureEncoder(nn.Module):
    def __init__(self, conv_dim, dtype=None, device=None, operations=None):
        super().__init__()
        self.conv_layers = nn.ModuleList([
            LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
            LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
            LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
        ])

    def forward(self, x):
        x = x.unsqueeze(1)
        for conv in self.conv_layers:
            x = conv(x)
        return x.transpose(1, 2)
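
# Illustrative sanity check, not part of the model: the seven conv layers
# above have strides 5*2*2*2*2*2*2 = 320, so 16 kHz audio is downsampled to
# roughly 50 feature frames per second. The helper name _conv_output_length
# is an assumption for illustration; it just applies the standard conv
# arithmetic (L_out = (L_in - kernel) // stride + 1) layer by layer.
def _conv_output_length(num_samples):
    for kernel, stride in [(10, 5), (3, 2), (3, 2), (3, 2), (3, 2), (2, 2), (2, 2)]:
        num_samples = (num_samples - kernel) // stride + 1
    return num_samples
# e.g. _conv_output_length(16000) == 49 frames for one second of 16 kHz audio.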
class FeatureProjection(nn.Module):
    def __init__(self, conv_dim, embed_dim, dtype=None, device=None, operations=None):
        super().__init__()
        self.layer_norm = operations.LayerNorm(conv_dim, eps=1e-05, device=device, dtype=dtype)
        self.projection = operations.Linear(conv_dim, embed_dim, device=device, dtype=dtype)

    def forward(self, x):
        x = self.layer_norm(x)
        x = self.projection(x)
        return x
class PositionalConvEmbedding(nn.Module):
    def __init__(self, embed_dim=768, kernel_size=128, groups=16):
        super().__init__()
        self.conv = nn.Conv1d(
            embed_dim,
            embed_dim,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=groups,
        )
        self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
        self.activation = nn.GELU()

    def forward(self, x):
        x = x.transpose(1, 2)
        x = self.conv(x)[:, :, :-1]
        x = self.activation(x)
        x = x.transpose(1, 2)
        return x
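
# Note on the `[:, :, :-1]` slice above: with kernel_size=128 and
# padding=kernel_size // 2 = 64, Conv1d produces L + 2*64 - 128 + 1 = L + 1
# output frames, one more than the input, so the trailing frame is dropped to
# keep the positional embedding aligned with x for the residual add in
# TransformerEncoder.forward.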
class TransformerEncoder(nn.Module):
    def __init__(
        self,
        embed_dim=768,
        num_heads=12,
        num_layers=12,
        mlp_ratio=4.0,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim)
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(
                embed_dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                device=device, dtype=dtype, operations=operations
            )
            for _ in range(num_layers)
        ])
        self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype)

    def forward(self, x, mask=None):
        x = x + self.pos_conv_embed(x)

        all_x = ()
        for layer in self.layers:
            all_x += (x,)
            x = layer(x, mask)
        x = self.layer_norm(x)
        all_x += (x,)
        return x, all_x
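
# all_x collects the hidden state fed into each of the num_layers blocks plus
# the final layer-normalized output, i.e. num_layers + 1 tensors in total,
# which callers can index to pick out an intermediate representation.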
class Attention(nn.Module):
    def __init__(self, embed_dim, num_heads, bias=True, dtype=None, device=None, operations=None):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.k_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
        self.v_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
        self.q_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
        self.out_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)

    def forward(self, x, mask=None):
        assert mask is None  # TODO?
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        out = optimized_attention_masked(q, k, v, self.num_heads)
        return self.out_proj(out)
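
# As I understand ComfyUI's attention helpers, optimized_attention_masked
# takes q, k, v as (batch, seq, heads * head_dim) with the head count passed
# separately; it handles the head split/merge internally and dispatches to
# whichever attention backend is available (e.g. PyTorch SDPA or xformers).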
class FeedForward(nn.Module):
    def __init__(self, embed_dim, mlp_ratio, dtype=None, device=None, operations=None):
        super().__init__()
        self.intermediate_dense = operations.Linear(embed_dim, int(embed_dim * mlp_ratio), device=device, dtype=dtype)
        self.output_dense = operations.Linear(int(embed_dim * mlp_ratio), embed_dim, device=device, dtype=dtype)

    def forward(self, x):
        x = self.intermediate_dense(x)
        x = torch.nn.functional.gelu(x)
        x = self.output_dense(x)
        return x
class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        embed_dim=768,
        num_heads=12,
        mlp_ratio=4.0,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        self.attention = Attention(embed_dim, num_heads, device=device, dtype=dtype, operations=operations)
        self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
        self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
        self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)

    def forward(self, x, mask=None):
        residual = x
        x = self.layer_norm(x)
        x = self.attention(x, mask=mask)
        x = residual + x
        x = x + self.feed_forward(self.final_layer_norm(x))
        return x
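
# This is the pre-norm ("layer-norm-first") transformer block used by the
# large Wav2Vec 2.0 configuration: LayerNorm is applied before attention and
# before the feed-forward, with residual connections around both sublayers.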
class Wav2Vec2Model(nn.Module):
    """Complete Wav2Vec 2.0 model."""

    def __init__(
        self,
        embed_dim=1024,
        final_dim=256,
        num_heads=16,
        num_layers=24,
        dtype=None, device=None, operations=None
    ):
        super().__init__()

        conv_dim = 512
        self.feature_extractor = ConvFeatureEncoder(conv_dim, device=device, dtype=dtype, operations=operations)
        self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations)

        self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype))

        self.encoder = TransformerEncoder(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            device=device, dtype=dtype, operations=operations
        )

    def forward(self, x, mask_time_indices=None, return_dict=False):
        # x: raw waveform of shape (batch, channels, samples); average the
        # channels to mono, then normalize to zero mean and unit variance.
        x = torch.mean(x, dim=1)
        x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)

        features = self.feature_extractor(x)
        features = self.feature_projection(features)

        batch_size, seq_len, _ = features.shape

        x, all_x = self.encoder(features)
        return x, all_x
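
# Minimal usage sketch, not from the original source: it assumes plain
# torch.nn can stand in for the `operations` factory (ComfyUI normally passes
# a comfy.ops class that builds weight-casting layers) and that the comfy
# package is importable for optimized_attention_masked.
if __name__ == "__main__":
    model = Wav2Vec2Model(operations=torch.nn)
    waveform = torch.randn(1, 2, 16000)  # (batch, channels, samples), 1 s at 16 kHz
    out, hidden_states = model(waveform)
    print(out.shape)           # torch.Size([1, 49, 1024])
    print(len(hidden_states))  # 25: inputs to the 24 layers + final output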