import torch
from torch import nn

from data import Tokenizer


class ResidualBlock(nn.Module):
    """Two Conv1d -> BatchNorm1d layers with PReLU and dropout, plus a skip connection."""

    def __init__(self, num_channels, dropout=0.5):
        super().__init__()
        self.conv1 = nn.Conv1d(num_channels, num_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(num_channels)
        self.conv2 = nn.Conv1d(num_channels, num_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(num_channels)
        self.prelu = nn.PReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        x = self.prelu(self.bn1(self.conv1(x)))
        x = self.dropout(x)
        x = self.bn2(self.conv2(x))
        x = self.prelu(x)
        x = self.dropout(x)
        # Note: in the original ResNet block the final activation is applied
        # after this addition, not before; here it is kept before on purpose.
        x = x + residual
        return x


class Seq2SeqCNN(nn.Module):
    """Convolutional seq2seq model. Each source position predicts `many_to_one`
    target tokens, so the output has shape (batch, seq_len, many_to_one, dict_size_trg)."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        dict_size_src = config['dict_size_src']
        dict_size_trg = config['dict_size_trg']
        embedding_dim = config['embedding_dim']
        num_channels = config['num_channels']
        num_residual_blocks = config['num_residual_blocks']
        dropout = config['dropout']
        many_to_one = config['many_to_one']

        self.embedding = nn.Embedding(dict_size_src, embedding_dim)
        self.conv = nn.Conv1d(embedding_dim, num_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm1d(num_channels)
        # Stack as many residual blocks as required.
        self.residual_blocks = nn.Sequential(
            *(ResidualBlock(num_channels, dropout) for _ in range(num_residual_blocks))
        )
        # One linear head emits many_to_one target-token distributions per source position.
        self.fc = nn.Linear(num_channels, dict_size_trg * many_to_one)
        self.dropout = nn.Dropout(dropout)
        self.dict_size_trg = dict_size_trg

    def forward(self, src):
        # src: (batch_size, seq_len)
        batch_size = src.size(0)
        embedded = self.embedding(src).permute(0, 2, 1)  # (batch, embedding_dim, seq_len)
        conv_out0 = self.conv(embedded)                  # (batch, num_channels, seq_len)
        conv_out = self.dropout(torch.relu(self.bn(conv_out0)))
        res_out = self.residual_blocks(conv_out)
        # Global skip connection around the residual stack.
        res_out = res_out + conv_out
        # Alternative considered: res_out = torch.cat([res_out, embedded], dim=1)
        out = self.fc(self.dropout(res_out.permute(0, 2, 1)))  # back to (batch, seq_len, ...)
        out = out.view(batch_size, -1, self.config['many_to_one'], self.dict_size_trg)
        return out


def init_model(path, device="cpu"):
    """Load a checkpoint saved as {'state_dict': ..., 'config': ...} and rebuild the model."""
    d = torch.load(path, map_location=device)
    model = Seq2SeqCNN(d['config']).to(device)
    model.load_state_dict(d['state_dict'])
    return model


@torch.no_grad()
def _predict(model, src, device):
    model.eval()
    src = src.to(device)
    output = model(src)  # (batch, seq_len, many_to_one, dict_size_trg)
    # Greedy decoding: take the most likely target token at each position.
    _, pred = torch.max(output, dim=-1)
    # Sampling alternative:
    # output = torch.softmax(output, dim=3)
    # pred = torch.multinomial(output.view(-1, output.size(-1)), 1)
    # pred = pred.reshape(output.size()[:-1])
    return pred


@torch.no_grad()
def predict(model, tokenizer: Tokenizer, text: str, device):
    print('text:', text)
    if not text:
        return ''
    text_encoded = tokenizer.encode_src(text)
    batch = text_encoded.unsqueeze(0)
    prd = _predict(model, batch, device)[0]
    # Drop predictions at padded source positions.
    prd = prd[batch[0] != tokenizer.src_pad_idx, :]
    predicted_text = ''.join(tokenizer.decode_trg(prd))
    print('predicted_text:', repr(predicted_text))
    return predicted_text  # .replace('\u200c', '')
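

# ---------------------------------------------------------------------------
# Minimal shape-check sketch. The config values below are illustrative only,
# not taken from the actual training setup; the checkpoint path and Tokenizer
# construction in the commented lines are likewise assumptions based on the
# functions above, not part of the original code.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_config = {
        'dict_size_src': 100,    # illustrative vocabulary sizes
        'dict_size_trg': 120,
        'embedding_dim': 64,
        'num_channels': 128,
        'num_residual_blocks': 4,
        'dropout': 0.5,
        'many_to_one': 2,        # each source position predicts 2 target tokens
    }
    demo_model = Seq2SeqCNN(example_config)
    demo_src = torch.randint(0, example_config['dict_size_src'], (8, 50))  # (batch, seq_len)
    demo_out = demo_model(demo_src)
    # Expected: (batch, seq_len, many_to_one, dict_size_trg) -> (8, 50, 2, 120)
    print('output shape:', tuple(demo_out.shape))

    # Loading a trained checkpoint and predicting from text (hypothetical path;
    # the Tokenizer constructor arguments are unknown here):
    # device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # model = init_model('checkpoints/model.pt', device=device)
    # tokenizer = Tokenizer(...)
    # print(predict(model, tokenizer, 'example input', device))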