""" File: model.py Author: Elena Ryumina and Dmitry Ryumin Description: This module provides model architectures. License: MIT License """ import torch import torch.nn as nn import torch.nn.functional as F import math import numpy as np from transformers.models.wav2vec2.modeling_wav2vec2 import ( Wav2Vec2Model, Wav2Vec2PreTrainedModel, ) from typing import Optional class Bottleneck(nn.Module): expansion = 4 def __init__(self, in_channels, out_channels, i_downsample=None, stride=1): super(Bottleneck, self).__init__() self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False) self.batch_norm1 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.99) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same', bias=False) self.batch_norm2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.99) self.conv3 = nn.Conv2d(out_channels, out_channels*self.expansion, kernel_size=1, stride=1, padding=0, bias=False) self.batch_norm3 = nn.BatchNorm2d(out_channels*self.expansion, eps=0.001, momentum=0.99) self.i_downsample = i_downsample self.stride = stride self.relu = nn.ReLU() def forward(self, x): identity = x.clone() x = self.relu(self.batch_norm1(self.conv1(x))) x = self.relu(self.batch_norm2(self.conv2(x))) x = self.conv3(x) x = self.batch_norm3(x) #downsample if needed if self.i_downsample is not None: identity = self.i_downsample(identity) #add identity x+=identity x=self.relu(x) return x class Conv2dSame(torch.nn.Conv2d): def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int: return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) def forward(self, x: torch.Tensor) -> torch.Tensor: ih, iw = x.size()[-2:] pad_h = self.calc_same_pad(i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0]) pad_w = self.calc_same_pad(i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1]) if pad_h > 0 or pad_w > 0: x = F.pad( x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] ) return F.conv2d( x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups, ) class ResNet(nn.Module): def __init__(self, ResBlock, layer_list, num_classes, num_channels=3): super(ResNet, self).__init__() self.in_channels = 64 self.conv_layer_s2_same = Conv2dSame(num_channels, 64, 7, stride=2, groups=1, bias=False) self.batch_norm1 = nn.BatchNorm2d(64, eps=0.001, momentum=0.99) self.relu = nn.ReLU() self.max_pool = nn.MaxPool2d(kernel_size = 3, stride=2) self.layer1 = self._make_layer(ResBlock, layer_list[0], planes=64, stride=1) self.layer2 = self._make_layer(ResBlock, layer_list[1], planes=128, stride=2) self.layer3 = self._make_layer(ResBlock, layer_list[2], planes=256, stride=2) self.layer4 = self._make_layer(ResBlock, layer_list[3], planes=512, stride=2) self.avgpool = nn.AdaptiveAvgPool2d((1,1)) self.fc1 = nn.Linear(512*ResBlock.expansion, 512) self.relu1 = nn.ReLU() self.fc2 = nn.Linear(512, num_classes) def extract_features(self, x): x = self.relu(self.batch_norm1(self.conv_layer_s2_same(x))) x = self.max_pool(x) # print(x.shape) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.avgpool(x) x = x.reshape(x.shape[0], -1) x = self.fc1(x) return x def forward(self, x): x = self.extract_features(x) x = self.relu1(x) x = self.fc2(x) return x def _make_layer(self, ResBlock, blocks, planes, stride=1): ii_downsample = None layers = [] if stride != 1 or self.in_channels != planes*ResBlock.expansion: ii_downsample = nn.Sequential( nn.Conv2d(self.in_channels, 
def ResNet50(num_classes, channels=3):
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes, channels)


class LSTMPyTorch(nn.Module):
    def __init__(self):
        super(LSTMPyTorch, self).__init__()

        self.lstm1 = nn.LSTM(
            input_size=512, hidden_size=512, batch_first=True, bidirectional=False
        )
        self.lstm2 = nn.LSTM(
            input_size=512, hidden_size=256, batch_first=True, bidirectional=False
        )
        self.fc = nn.Linear(256, 7)
        # self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x, _ = self.lstm1(x)
        x, _ = self.lstm2(x)
        x = self.fc(x[:, -1, :])
        # x = self.softmax(x)
        return x


class ExprModelV3(Wav2Vec2PreTrainedModel):
    def __init__(self, config) -> None:
        super().__init__(config)
        self.config = config

        self.wav2vec2 = Wav2Vec2Model(config)

        self.tl1 = TransformerLayer(
            input_dim=1024, num_heads=32, dropout=0.1, positional_encoding=True
        )
        self.tl2 = TransformerLayer(
            input_dim=1024, num_heads=16, dropout=0.1, positional_encoding=True
        )

        self.f_size = 1024

        self.time_downsample = torch.nn.Sequential(
            torch.nn.Conv1d(
                self.f_size, self.f_size, kernel_size=5, stride=3, dilation=2
            ),
            torch.nn.BatchNorm1d(self.f_size),
            torch.nn.MaxPool1d(5),
            torch.nn.ReLU(),
            torch.nn.Conv1d(self.f_size, self.f_size, kernel_size=3),
            torch.nn.BatchNorm1d(self.f_size),
            torch.nn.AdaptiveAvgPool1d(1),
            torch.nn.ReLU(),
        )

        self.feature_downsample = nn.Linear(self.f_size, 8)

        self.init_weights()
        self.unfreeze_last_n_blocks(4)

    def freeze_conv_only(self):
        # freeze the convolutional feature extractor of wav2vec2
        for param in self.wav2vec2.feature_extractor.conv_layers.parameters():
            param.requires_grad = False

    def unfreeze_last_n_blocks(self, num_blocks: int) -> None:
        # freeze the whole wav2vec2 model
        for param in self.wav2vec2.parameters():
            param.requires_grad = False

        # unfreeze the last n transformer blocks
        for i in range(0, num_blocks):
            for param in self.wav2vec2.encoder.layers[-1 * (i + 1)].parameters():
                param.requires_grad = True

    def forward(self, x):
        x = self.wav2vec2(x)[0]

        x = self.tl1(query=x, key=x, value=x)
        x = self.tl2(query=x, key=x, value=x)

        x = x.permute(0, 2, 1)
        x = self.time_downsample(x)

        x = x.squeeze(-1)  # [batch_size, f_size, 1] -> [batch_size, f_size]
        x = self.feature_downsample(x)
        return x


class ScaledDotProductAttention_MultiHead(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention_MultiHead, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            raise ValueError("Mask is not supported yet")

        # query, key, value shapes: [batch_size, num_heads, seq_len, dim]
        emb_dim = key.shape[-1]

        # calculate attention weights
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(
            emb_dim
        )

        # softmax
        attention_weights = self.softmax(attention_weights)

        # weight the values by the attention distribution
        value = torch.matmul(attention_weights, value)

        return value, attention_weights
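# Reference formula (standard scaled dot-product attention, stated here for
# readability; it matches the computation in the class above):
#
#   Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_head)) @ V
#
# where Q, K, V have shape [batch_size, num_heads, seq_len, d_head] and the
# softmax is taken over the last (key) dimension.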
class PositionWiseFeedForward(nn.Module):
    def __init__(self, input_dim, hidden_dim, dropout: float = 0.1):
        super().__init__()
        self.layer_1 = nn.Linear(input_dim, hidden_dim)
        self.layer_2 = nn.Linear(hidden_dim, input_dim)
        self.layer_norm = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # feed-forward network
        x = self.layer_1(x)
        x = self.dropout(x)
        x = F.relu(x)
        x = self.layer_2(x)
        return x


class Add_and_Norm(nn.Module):
    def __init__(self, input_dim, dropout: Optional[float] = 0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(input_dim)
        if dropout is not None:
            self.dropout = nn.Dropout(dropout)

    def forward(self, x1, residual):
        x = x1
        # apply dropout if needed
        if hasattr(self, "dropout"):
            x = self.dropout(x)
        # add and then norm
        x = x + residual
        x = self.layer_norm(x)
        return x


class MultiHeadAttention(nn.Module):
    def __init__(self, input_dim, num_heads, dropout: Optional[float] = 0.1):
        super().__init__()
        self.input_dim = input_dim
        self.num_heads = num_heads
        if input_dim % num_heads != 0:
            raise ValueError("input_dim must be divisible by num_heads")
        self.head_dim = input_dim // num_heads
        self.dropout = dropout

        # initialize weights
        self.query_w = nn.Linear(input_dim, self.num_heads * self.head_dim, bias=False)
        self.keys_w = nn.Linear(input_dim, self.num_heads * self.head_dim, bias=False)
        self.values_w = nn.Linear(input_dim, self.num_heads * self.head_dim, bias=False)
        self.ff_layer_after_concat = nn.Linear(
            self.num_heads * self.head_dim, input_dim, bias=False
        )

        self.attention = ScaledDotProductAttention_MultiHead()

        if self.dropout is not None:
            self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, mask=None):
        # queries, keys, values shapes: [batch_size, seq_len, input_dim]
        batch_size, len_query, len_keys, len_values = (
            queries.size(0),
            queries.size(1),
            keys.size(1),
            values.size(1),
        )

        # linear transformation before attention
        queries = (
            self.query_w(queries)
            .view(batch_size, len_query, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )  # [batch_size, num_heads, seq_len, dim]

        keys = (
            self.keys_w(keys)
            .view(batch_size, len_keys, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )  # [batch_size, num_heads, seq_len, dim]

        values = (
            self.values_w(values)
            .view(batch_size, len_values, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )  # [batch_size, num_heads, seq_len, dim]

        # attention itself
        values, attention_weights = self.attention(
            queries, keys, values, mask=mask
        )  # values shape: [batch_size, num_heads, seq_len, dim]

        # concatenate the heads
        out = (
            values.transpose(1, 2)
            .contiguous()
            .view(batch_size, len_values, self.num_heads * self.head_dim)
        )  # [batch_size, seq_len, num_heads * dim = input_dim]

        # go through the last linear layer
        out = self.ff_layer_after_concat(out)

        return out


class EncoderLayer(nn.Module):
    def __init__(
        self,
        input_dim,
        num_heads,
        dropout: Optional[float] = 0.1,
        positional_encoding: bool = True,
    ):
        super(EncoderLayer, self).__init__()
        self.positional_encoding = positional_encoding
        self.input_dim = input_dim
        self.num_heads = num_heads
        self.head_dim = input_dim // num_heads
        self.dropout = dropout

        # initialize layers
        self.self_attention = MultiHeadAttention(input_dim, num_heads, dropout=dropout)
        self.feed_forward = PositionWiseFeedForward(
            input_dim, input_dim, dropout=dropout
        )
        self.add_norm_after_attention = Add_and_Norm(input_dim, dropout=dropout)
        self.add_norm_after_ff = Add_and_Norm(input_dim, dropout=dropout)

        # calculate positional encoding
        if self.positional_encoding:
            self.positional_encoding = PositionalEncoding(input_dim)

    def forward(self, x):
        # x shape: [batch_size, seq_len, input_dim]

        # positional encoding
        if self.positional_encoding:
            x = self.positional_encoding(x)

        # multi-head self-attention
        residual = x
        x = self.self_attention(x, x, x)
        x = self.add_norm_after_attention(x, residual)

        # feed forward
        residual = x
        x = self.feed_forward(x)
        x = self.add_norm_after_ff(x, residual)

        return x
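# Usage sketch (illustrative assumption; this module itself only instantiates
# TransformerLayer below, not EncoderLayer directly): an encoder layer keeps
# the [batch_size, seq_len, input_dim] shape, so layers can be stacked
# sequentially. The sizes below are arbitrary.
#
#   layer = EncoderLayer(input_dim=1024, num_heads=16, dropout=0.1)
#   tokens = torch.randn(2, 50, 1024)
#   out = layer(tokens)  # -> [2, 50, 1024]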
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        pe = pe.permute(
            1, 0, 2
        )  # [max_len, 1, d_model] -> [1, max_len, d_model], so it broadcasts over the batch
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]
        """
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


class TransformerLayer(nn.Module):
    def __init__(
        self,
        input_dim,
        num_heads,
        dropout: Optional[float] = 0.1,
        positional_encoding: bool = True,
    ):
        super(TransformerLayer, self).__init__()
        self.positional_encoding = positional_encoding
        self.input_dim = input_dim
        self.num_heads = num_heads
        self.head_dim = input_dim // num_heads
        self.dropout = dropout

        # initialize layers
        self.self_attention = MultiHeadAttention(input_dim, num_heads, dropout=dropout)
        self.feed_forward = PositionWiseFeedForward(
            input_dim, input_dim, dropout=dropout
        )
        self.add_norm_after_attention = Add_and_Norm(input_dim, dropout=dropout)
        self.add_norm_after_ff = Add_and_Norm(input_dim, dropout=dropout)

        # calculate positional encoding
        if self.positional_encoding:
            self.positional_encoding = PositionalEncoding(input_dim)

    def forward(self, key, value, query, mask=None):
        # key, value, and query shapes: [batch_size, seq_len, input_dim]

        # positional encoding
        if self.positional_encoding:
            key = self.positional_encoding(key)
            value = self.positional_encoding(value)
            query = self.positional_encoding(query)

        # multi-head attention
        residual = query
        x = self.self_attention(queries=query, keys=key, values=value, mask=mask)
        x = self.add_norm_after_attention(x, residual)

        # feed forward
        residual = x
        x = self.feed_forward(x)
        x = self.add_norm_after_ff(x, residual)

        return x
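# Minimal smoke test (illustrative only; not part of the original module). It
# exercises the visual branch (ResNet-50 backbone + LSTM over per-frame
# features) and a single TransformerLayer applied as self-attention.
# ExprModelV3 is skipped because it needs a Wav2Vec2 config/checkpoint. All
# shapes below are arbitrary assumptions, not values prescribed by the project.
if __name__ == "__main__":
    torch.manual_seed(0)

    with torch.no_grad():
        # per-frame facial features from the ResNet-50 backbone
        backbone = ResNet50(num_classes=7, channels=3).eval()
        frames = torch.randn(4, 3, 224, 224)
        frame_features = backbone.extract_features(frames)  # -> [4, 512]

        # temporal aggregation of the per-frame features with the LSTM head
        temporal_model = LSTMPyTorch().eval()
        window = frame_features.unsqueeze(0)  # treat the 4 frames as one sequence: [1, 4, 512]
        emotion_logits = temporal_model(window)  # -> [1, 7]

        # a single transformer layer with query = key = value
        transformer = TransformerLayer(input_dim=1024, num_heads=16).eval()
        tokens = torch.randn(1, 20, 1024)
        out = transformer(query=tokens, key=tokens, value=tokens)  # -> [1, 20, 1024]

    print(frame_features.shape, emotion_logits.shape, out.shape)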