from torch import nn


class CTCHead(nn.Module):
    """Linear prediction head built from one or two ``nn.Linear`` layers.

    With ``mid_channels`` left as ``None`` the head is a single linear layer
    (``fc``) mapping ``in_channels`` to ``out_channels``.  Otherwise it is two
    linear layers applied back to back (``fc1`` then ``fc2``) with no
    activation in between, going through ``mid_channels`` features.

    Args:
        in_channels: Size of the last dimension of the input.
        out_channels: Size of the last dimension of the prediction.
        fc_decay: Accepted for interface compatibility; not used by this
            class.
        mid_channels: Width of the intermediate projection, or ``None`` to
            use a single linear layer.
        return_feats: When true, ``forward`` returns a dict that also holds
            the features fed to the final linear layer.
        **kwargs: Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, in_channels, out_channels=6625, fc_decay=0.0004,
                 mid_channels=None, return_feats=False, **kwargs):
        super().__init__()

        # Plain configuration attributes; ``forward`` branches on
        # ``mid_channels`` and ``return_feats``.
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.return_feats = return_feats

        if mid_channels is not None:
            # Two-stage projection: in -> mid -> out.
            self.fc1 = nn.Linear(in_channels, mid_channels, bias=True)
            self.fc2 = nn.Linear(mid_channels, out_channels, bias=True)
        else:
            # Direct projection: in -> out.
            self.fc = nn.Linear(in_channels, out_channels, bias=True)

    def forward(self, x, labels=None):
        """Project ``x`` to predictions.

        Args:
            x: Input passed to the first linear layer.
            labels: Accepted for interface compatibility; not used here.

        Returns:
            The prediction on its own when ``return_feats`` is false.
            Otherwise a dict with ``'ctc'`` (the prediction) and
            ``'ctc_neck'`` (the input of the final linear layer: the output
            of ``fc1`` in the two-layer case, or ``x`` itself in the
            single-layer case).
        """
        if self.mid_channels is None:
            neck = x
            logits = self.fc(neck)
        else:
            neck = self.fc1(x)
            logits = self.fc2(neck)

        if not self.return_feats:
            return logits
        return {'ctc': logits, 'ctc_neck': neck}