Spaces:
Sleeping
Sleeping
| # model_img2ph.py | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class CNNEncoder(nn.Module):
    """Convolutional front-end that turns a mel spectrogram into a
    per-frame feature sequence.

    Every conv stage uses stride (2, 1): the frequency axis is halved
    four times while the time axis is left untouched, so the output
    keeps exactly one feature vector per input frame.
    """

    def __init__(self, in_channels=1, hidden_dim=256, dropout=0.2):
        super().__init__()
        # Build the four identical stages in a loop instead of spelling
        # each one out; channel progression: in -> 64 -> 128 -> 256 -> hidden_dim.
        stages = []
        n_in = in_channels
        for n_out in (64, 128, 256, hidden_dim):
            stages.extend([
                nn.Conv2d(n_in, n_out, kernel_size=3, stride=(2, 1), padding=1),
                nn.BatchNorm2d(n_out),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            n_in = n_out
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Map (B, n_mels, T) -> (B, T, hidden_dim * reduced_mels)."""
        feats = self.conv(x.unsqueeze(1))        # add channel dim -> (B, C, H', T)
        batch, channels, freq, frames = feats.size()
        # Fold the shrunken frequency axis into the feature dimension,
        # keeping one vector per time step.
        feats = feats.permute(0, 3, 1, 2).contiguous()  # (B, T, C, H')
        return feats.view(batch, frames, channels * freq)
class PhonemeDecoder(nn.Module):
    """Recurrent decoder that maps encoder features to per-frame
    phoneme logits.

    Pipeline: unidirectional GRU -> linear projection to 256 ->
    LayerNorm -> ReLU -> dropout -> output projection to vocab_size.
    """

    def __init__(self, vocab_size, enc_dim=128*5, rnn_hidden=128, num_layers=2, dropout=0.3):
        super().__init__()
        self.rnn = nn.GRU(
            enc_dim,
            rnn_hidden,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=False,  # causal / unidirectional decoding
        )
        self.proj = nn.Linear(rnn_hidden, 256)
        self.norm = nn.LayerNorm(256)
        self.fc_out = nn.Linear(256, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, enc_out):
        """Map (B, T, enc_dim) -> (B, T, vocab_size) logits."""
        hidden, _ = self.rnn(enc_out)                     # (B, T, rnn_hidden)
        # Project, normalize, activate, and regularize in one chain.
        activated = self.dropout(F.relu(self.norm(self.proj(hidden))))
        return self.fc_out(activated)                     # (B, T, vocab_size)
class Image2Phoneme(nn.Module):
    """End-to-end model: mel spectrogram -> per-frame phoneme logits.

    A CNN encoder collapses the frequency axis into features (one
    vector per frame); a GRU decoder emits phoneme logits per frame.

    Args:
        vocab_size: number of phoneme classes in the output layer.
        in_channels: input channels of the spectrogram "image".
        enc_hidden: channel count of the encoder's last conv stage.
        rnn_hidden: hidden size of the decoder GRU.
        n_mels: number of mel bins in the input (default 80, matching
            the original hard-coded setup).
    """

    def __init__(self, vocab_size, in_channels=1, enc_hidden=128, rnn_hidden=128, n_mels=80):
        super().__init__()
        self.encoder = CNNEncoder(in_channels=in_channels, hidden_dim=enc_hidden)
        # BUGFIX/generalization: the old code hard-coded the post-conv
        # frequency height as 5, which is only correct for n_mels=80.
        # CNNEncoder applies 4 convs with kernel 3, stride (2,1),
        # padding 1, each giving H_out = (H_in - 1) // 2 + 1, so
        # compute the reduced height for any n_mels (80 -> 5).
        freq = n_mels
        for _ in range(4):  # must match the number of conv stages in CNNEncoder
            freq = (freq - 1) // 2 + 1
        enc_dim = enc_hidden * freq
        self.decoder = PhonemeDecoder(vocab_size, enc_dim=enc_dim, rnn_hidden=rnn_hidden)

    def forward(self, mels):
        """Map mels (B, n_mels, T) -> logits (B, T, vocab_size)."""
        enc_out = self.encoder(mels)   # (B, T, enc_dim)
        logits = self.decoder(enc_out) # (B, T, vocab_size)
        return logits