# audio_to_phonome/model.py — mel-spectrogram to phoneme-sequence model
# Provenance: Hugging Face upload by "hash-map" ("Upload 6 files", commit 4ffa9fc, verified)
# model_img2ph.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class CNNEncoder(nn.Module):
    """Convolutional front-end for mel-spectrogram input.

    Four Conv2d/BatchNorm/ReLU/Dropout stages, each with stride (2, 1):
    the frequency axis shrinks by roughly half per stage while the time
    axis is left untouched. The surviving (channels, frequency) pair is
    then flattened into one feature vector per time frame.
    """

    def __init__(self, in_channels=1, hidden_dim=256, dropout=0.2):
        super().__init__()
        # Channel widths through the four stages; the final stage
        # projects to hidden_dim.
        widths = [in_channels, 64, 128, 256, hidden_dim]
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=(2, 1), padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Encode a batch of spectrograms.

        Args:
            x: tensor of shape (B, n_mels, T).

        Returns:
            Tensor of shape (B, T, hidden_dim * H'), where H' is the
            frequency size remaining after the four stride-2 reductions.
        """
        feats = self.conv(x.unsqueeze(1))                # (B, C, H', T)
        batch, channels, freq, frames = feats.size()
        # Bring time to the front of the feature axes, then fold
        # (channels, frequency) into a single per-frame feature vector.
        feats = feats.permute(0, 3, 1, 2).contiguous()   # (B, T, C, H')
        return feats.view(batch, frames, channels * freq)
class PhonemeDecoder(nn.Module):
    """GRU decoder mapping per-frame encoder features to phoneme logits.

    A unidirectional multi-layer GRU runs over the time axis; each hidden
    state then passes through a 256-wide linear projection, LayerNorm,
    ReLU and dropout before the final projection to the vocabulary.
    """

    def __init__(self, vocab_size, enc_dim=128 * 5, rnn_hidden=128, num_layers=2, dropout=0.3):
        super().__init__()
        # Unidirectional GRU, batch-first layout (B, T, enc_dim).
        self.rnn = nn.GRU(
            enc_dim,
            rnn_hidden,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=False,
        )
        self.proj = nn.Linear(rnn_hidden, 256)   # intermediate projection
        self.norm = nn.LayerNorm(256)            # normalize before activation
        self.fc_out = nn.Linear(256, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, enc_out):
        """Map features (B, T, enc_dim) to logits (B, T, vocab_size)."""
        hidden_seq, _ = self.rnn(enc_out)            # (B, T, rnn_hidden)
        projected = self.norm(self.proj(hidden_seq)) # (B, T, 256)
        activated = self.dropout(F.relu(projected))  # dropout after activation
        return self.fc_out(activated)                # (B, T, vocab_size)
class Image2Phoneme(nn.Module):
    """End-to-end mel-spectrogram -> per-frame phoneme-logit model.

    Chains CNNEncoder (frequency-reducing conv stack) with PhonemeDecoder
    (GRU + projections). The decoder input width is now derived from the
    actual number of mel bins instead of being hard-coded to H' = 5, which
    silently assumed n_mels == 80 and broke for any other input height.
    """

    def __init__(self, vocab_size, in_channels=1, enc_hidden=128, rnn_hidden=128, n_mels=80):
        """
        Args:
            vocab_size: number of phoneme classes (logits per frame).
            in_channels: channels of the input spectrogram "image".
            enc_hidden: channel width of the final encoder conv stage.
            rnn_hidden: hidden size of the decoder GRU.
            n_mels: mel bins in the input; defaults to 80 for backward
                compatibility (80 -> H' = 5, enc_dim = enc_hidden * 5,
                identical to the previous hard-coded behavior).
        """
        super().__init__()
        self.encoder = CNNEncoder(in_channels=in_channels, hidden_dim=enc_hidden)
        # Each of the encoder's 4 stages uses kernel 3, stride (2,1),
        # padding 1, so frequency maps H -> (H - 1) // 2 + 1 per stage.
        freq = n_mels
        for _ in range(4):
            freq = (freq - 1) // 2 + 1
        enc_dim = enc_hidden * freq
        self.decoder = PhonemeDecoder(vocab_size, enc_dim=enc_dim, rnn_hidden=rnn_hidden)

    def forward(self, mels):
        """Map mels (B, n_mels, T) to logits (B, T, vocab_size)."""
        enc_out = self.encoder(mels)     # (B, T, enc_dim)
        return self.decoder(enc_out)     # (B, T, vocab_size)