import torch
import torch.nn as nn
from transformers.modeling_utils import ModuleUtilsMixin
from transformers.models.t5.modeling_t5 import T5Block, T5Config, T5LayerNorm

from ....configuration_utils import ConfigMixin, register_to_config
from ....models import ModelMixin


class SpectrogramNotesEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
    """Encodes tokenized note sequences into hidden states using a stack of T5 encoder blocks."""

    @register_to_config
    def __init__(
        self,
        max_length: int,
        vocab_size: int,
        d_model: int,
        dropout_rate: float,
        num_layers: int,
        num_heads: int,
        d_kv: int,
        d_ff: int,
        feed_forward_proj: str,
        is_decoder: bool = False,
    ):
        super().__init__()

        self.token_embedder = nn.Embedding(vocab_size, d_model)

        # Absolute position embeddings are kept fixed (not trained).
        self.position_encoding = nn.Embedding(max_length, d_model)
        self.position_encoding.weight.requires_grad = False

        self.dropout_pre = nn.Dropout(p=dropout_rate)

        t5config = T5Config(
            vocab_size=vocab_size,
            d_model=d_model,
            num_heads=num_heads,
            d_kv=d_kv,
            d_ff=d_ff,
            dropout_rate=dropout_rate,
            feed_forward_proj=feed_forward_proj,
            is_decoder=is_decoder,
            is_encoder_decoder=False,
        )

        # Stack of standard T5 encoder blocks (self-attention + feed-forward).
        self.encoders = nn.ModuleList()
        for lyr_num in range(num_layers):
            lyr = T5Block(t5config)
            self.encoders.append(lyr)

        self.layer_norm = T5LayerNorm(d_model)
        self.dropout_post = nn.Dropout(p=dropout_rate)

    def forward(self, encoder_input_tokens, encoder_inputs_mask):
        x = self.token_embedder(encoder_input_tokens)

        # Add the (frozen) absolute position embeddings for each sequence position.
        seq_length = encoder_input_tokens.shape[1]
        inputs_positions = torch.arange(seq_length, device=encoder_input_tokens.device)
        x += self.position_encoding(inputs_positions)

        x = self.dropout_pre(x)

        # Invert the attention mask and broadcast it to the shape the T5 blocks expect.
        input_shape = encoder_input_tokens.size()
        extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, input_shape)

        for lyr in self.encoders:
            x = lyr(x, extended_attention_mask)[0]
        x = self.layer_norm(x)

        return self.dropout_post(x), encoder_inputs_mask
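

# Minimal usage sketch (illustrative only, not part of the library API). The
# hyperparameter values below are placeholder assumptions, not the configuration
# of any released checkpoint:
#
#     encoder = SpectrogramNotesEncoder(
#         max_length=2048,
#         vocab_size=1536,
#         d_model=768,
#         dropout_rate=0.1,
#         num_layers=12,
#         num_heads=12,
#         d_kv=64,
#         d_ff=2048,
#         feed_forward_proj="gated-gelu",
#     )
#     tokens = torch.randint(0, 1536, (1, 2048))    # (batch, seq_len) note token ids
#     mask = torch.ones(1, 2048, dtype=torch.long)  # 1 = attend, 0 = padding
#     hidden_states, mask = encoder(tokens, mask)   # hidden_states: (1, 2048, 768)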