# Extraction residue (Hugging Face "Spaces" / "Sleeping" status text, not code),
# commented out so the module parses:
# Spaces:
# Sleeping
# Sleeping
from typing import List, Optional

import torch
from torch import nn
class BaseLineModel(nn.Module):
    """Stacked bidirectional-LSTM sequence tagger.

    Maps a batch of input token ids to per-position logits over the target
    vocabulary: Embedding -> N x BiLSTM (each optionally followed by
    BatchNorm1d over the feature dimension) -> Linear projection.
    """

    def __init__(
        self,
        inp_vocab_size: int,
        targ_vocab_size: int,
        embedding_dim: int = 512,
        layers_units: Optional[List[int]] = None,
        use_batch_norm: bool = False,
    ):
        """
        Args:
            inp_vocab_size: size of the input token vocabulary.
            targ_vocab_size: size of the output vocabulary.
            embedding_dim: embedding size; the first BiLSTM consumes
                (embedding_dim // 2) * 2 == embedding_dim features.
            layers_units: hidden size of each BiLSTM layer; defaults to
                [256, 256, 256]. (None default avoids the shared
                mutable-default-argument pitfall.)
            use_batch_norm: insert a BatchNorm1d after every BiLSTM layer.
        """
        super().__init__()
        if layers_units is None:
            layers_units = [256, 256, 256]
        self.targ_vocab_size = targ_vocab_size
        self.embedding = nn.Embedding(inp_vocab_size, embedding_dim)

        # Prepend a pseudo-layer of embedding_dim // 2 units so that the
        # generic "previous units * 2" input-size rule below yields exactly
        # embedding_dim for the first BiLSTM.
        layers_units = [embedding_dim // 2] + layers_units

        layers = []
        for i in range(1, len(layers_units)):
            layers.append(
                nn.LSTM(
                    layers_units[i - 1] * 2,  # *2: previous layer is bidirectional
                    layers_units[i],
                    bidirectional=True,
                    batch_first=True,
                )
            )
            if use_batch_norm:
                # BiLSTM output features = units * 2 (forward + backward).
                layers.append(nn.BatchNorm1d(layers_units[i] * 2))

        self.layers = nn.ModuleList(layers)
        self.projections = nn.Linear(layers_units[-1] * 2, targ_vocab_size)
        self.layers_units = layers_units
        self.use_batch_norm = use_batch_norm

    def forward(self, src: torch.Tensor, lengths: torch.Tensor, target=None):
        """Compute per-position logits for a batch of sequences.

        Args:
            src: LongTensor of token ids, shape (batch, seq_len).
            lengths: per-sequence lengths; currently unused — kept for
                interface compatibility (e.g. packed-sequence variants).
            target: unused; kept for interface compatibility.

        Returns:
            dict with key "diacritics" mapping to logits of shape
            (batch, seq_len, targ_vocab_size).
        """
        outputs = self.embedding(src)
        # outputs: (batch, seq_len, embedding_dim)
        hidden_state = None  # (h_n, c_n) carried between consecutive LSTMs
        for layer in self.layers:
            if isinstance(layer, nn.BatchNorm1d):
                # BatchNorm1d normalizes dim 1, so move the feature
                # dimension there and back.
                outputs = layer(outputs.permute(0, 2, 1)).permute(0, 2, 1)
                continue
            # NOTE(review): handing the previous LSTM's final (h_n, c_n) to
            # the next layer only type-checks when consecutive layers_units
            # are equal (true for the default config) — confirm intent.
            if hidden_state is None:
                outputs, hidden_state = layer(outputs)
            else:
                outputs, hidden_state = layer(outputs, hidden_state)
        predictions = self.projections(outputs)
        return {"diacritics": predictions}