# model_img2ph.py
"""Mel-spectrogram to per-frame phoneme logits: a CNN encoder followed by a GRU decoder."""
import torch
import torch.nn as nn
import torch.nn.functional as F


class CNNEncoder(nn.Module):
    """Convolutional front end that downsamples frequency while preserving time.

    Four Conv2d stages (kernel 3, stride (2, 1), padding 1), each followed by
    BatchNorm2d, ReLU and Dropout, roughly halve the frequency axis and leave
    the time axis length unchanged. The remaining frequency bins are then folded
    into the per-frame feature dimension.

    Args:
        in_channels: Number of channels expected by the first convolution.
        hidden_dim: Output channels of the last convolution stage.
        dropout: Dropout probability applied after every stage.
    """

    # Conv geometry, shared by __init__ and the output-size helpers below so the
    # two can never disagree.
    KERNEL_SIZE = 3
    FREQ_STRIDE = 2
    PADDING = 1
    _HIDDEN_CHANNELS = (64, 128, 256)  # channels of the intermediate stages

    def __init__(self, in_channels=1, hidden_dim=256, dropout=0.2):
        super().__init__()
        self.hidden_dim = hidden_dim
        channels = [in_channels, *self._HIDDEN_CHANNELS, hidden_dim]
        layers = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            layers += [
                nn.Conv2d(
                    c_in,
                    c_out,
                    kernel_size=self.KERNEL_SIZE,
                    stride=(self.FREQ_STRIDE, 1),  # downsample frequency only
                    padding=self.PADDING,
                ),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
        # Same layer order as a hand-written Sequential, so state_dict keys
        # (conv.0.weight, conv.1.running_mean, ...) are stable.
        self.conv = nn.Sequential(*layers)

    @classmethod
    def freq_out(cls, n_mels):
        """Return the frequency-axis length left after all conv stages.

        Raises:
            ValueError: if ``n_mels`` is smaller than 1.
        """
        if n_mels < 1:
            raise ValueError(f"n_mels must be >= 1, got {n_mels}")
        h = n_mels
        for _ in range(len(cls._HIDDEN_CHANNELS) + 1):
            # Standard Conv2d output-size formula (dilation 1).
            h = (h + 2 * cls.PADDING - cls.KERNEL_SIZE) // cls.FREQ_STRIDE + 1
        return h

    def output_dim(self, n_mels):
        """Return the per-frame feature size produced for ``n_mels`` input bins."""
        return self.hidden_dim * self.freq_out(n_mels)

    def forward(self, x):
        """Encode spectrograms into a per-frame feature sequence.

        Args:
            x: (B, n_mels, T) for single-channel input, or
               (B, in_channels, n_mels, T).

        Returns:
            Tensor of shape (B, T, hidden_dim * freq_out(n_mels)).

        Raises:
            ValueError: if ``x`` is neither 3-D nor 4-D.
        """
        if x.dim() == 3:
            x = x.unsqueeze(1)  # (B, 1, n_mels, T)
        elif x.dim() != 4:
            raise ValueError(f"expected a 3-D or 4-D input, got shape {tuple(x.shape)}")
        feat = self.conv(x)  # (B, C, H', T)
        B, C, H, T = feat.shape
        # Put time next to batch, then fold channels and frequency into features.
        return feat.permute(0, 3, 1, 2).reshape(B, T, C * H)


class PhonemeDecoder(nn.Module):
    """Unidirectional GRU followed by a projection head producing per-frame logits.

    Args:
        vocab_size: Number of output classes per frame.
        enc_dim: Feature size of each encoder frame (GRU input size).
        rnn_hidden: GRU hidden size.
        num_layers: Number of stacked GRU layers.
        dropout: Dropout between GRU layers and after the projection head.
    """

    def __init__(self, vocab_size, enc_dim=128 * 5, rnn_hidden=128, num_layers=2, dropout=0.3):
        super().__init__()
        self.rnn = nn.GRU(
            enc_dim,
            rnn_hidden,
            num_layers=num_layers,
            batch_first=True,
            # nn.GRU applies dropout only between stacked layers; with a single
            # layer it has no effect and PyTorch emits a warning, so pass 0.
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=False,
        )
        self.proj = nn.Linear(rnn_hidden, 256)
        self.norm = nn.LayerNorm(256)
        self.fc_out = nn.Linear(256, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, enc_out):
        """Map encoder features (B, T, enc_dim) to logits (B, T, vocab_size)."""
        rnn_out, _ = self.rnn(enc_out)  # (B, T, rnn_hidden)
        dense_out = self.proj(rnn_out)  # (B, T, 256)
        dense_out = self.norm(dense_out)
        dense_out = F.relu(dense_out)
        dense_out = self.dropout(dense_out)
        return self.fc_out(dense_out)  # (B, T, vocab_size)


class Image2Phoneme(nn.Module):
    """End-to-end model: mel spectrogram (B, n_mels, T) -> logits (B, T, vocab_size).

    Args:
        vocab_size: Number of output classes per frame.
        in_channels: Input channels of the CNN encoder.
        enc_hidden: Output channels of the last encoder conv stage.
        rnn_hidden: GRU hidden size of the decoder.
        n_mels: Number of mel bins of the input; determines the decoder input
            size. Defaults to 80.
    """

    def __init__(self, vocab_size, in_channels=1, enc_hidden=128, rnn_hidden=128, n_mels=80):
        super().__init__()
        self.encoder = CNNEncoder(in_channels=in_channels, hidden_dim=enc_hidden)
        # Derive the GRU input size from the conv geometry rather than assuming
        # 5 remaining frequency bins, which only holds for 65 <= n_mels <= 80.
        enc_dim = self.encoder.output_dim(n_mels)
        self.decoder = PhonemeDecoder(vocab_size, enc_dim=enc_dim, rnn_hidden=rnn_hidden)

    def forward(self, mels):
        """Return per-frame logits (B, T, vocab_size) for mels (B, n_mels, T)."""
        enc_out = self.encoder(mels)  # (B, T, enc_dim)
        return self.decoder(enc_out)  # (B, T, vocab_size)