import torch
import torch.nn as nn
import torch.nn.functional as F

from model_visual_features import ResNetFeatureExtractor, TPS_SpatialTransformerNetwork


class HW_RNN_Seq2Seq(nn.Module):
"""
Visual Seq2Seq model using BiLSTM
"""
def __init__(
self,
num_classes,
image_height,
cnn_output_channels=512,
num_feats_mapped_seq_hidden=128,
num_feats_seq_hidden=256,
):
"""
---------
Arguments
---------
        num_classes : int
            number of distinct characters (classes) in the dataset
image_height : int
image height
cnn_output_channels : int
number of channels output from the CNN visual feature extractor (default: 512)
num_feats_mapped_seq_hidden : int
number of features to be used in the mapped visual features as sequences (default: 128)
num_feats_seq_hidden : int
number of features to be used in the LSTM for sequence modeling (default: 256)
"""
        super().__init__()
        # feature-map height after the CNN backbone's /32 downsampling
        self.output_height = image_height // 32
        self.dropout = nn.Dropout(p=0.25)
        # flatten each vertical slice (C * H/32) of the feature map into one timestep
        self.map_visual_to_seq = nn.Linear(
            cnn_output_channels * self.output_height, num_feats_mapped_seq_hidden
        )
        self.b_lstm_1 = nn.LSTM(
            num_feats_mapped_seq_hidden, num_feats_seq_hidden, bidirectional=True
        )
        self.b_lstm_2 = nn.LSTM(
            2 * num_feats_seq_hidden, num_feats_seq_hidden, bidirectional=True
        )
        self.final_dense = nn.Linear(2 * num_feats_seq_hidden, num_classes)

    def forward(self, visual_feats):
        # visual_feats: [B, C, H/32, W/32] from the CNN feature extractor
        visual_feats = visual_feats.permute(3, 0, 1, 2)
        # [W/32, B, C, H/32] - the sequence runs along the image width,
        # one vertical slice per timestep (like characters in a sentence)
        visual_feats = visual_feats.contiguous().view(
            visual_feats.shape[0], visual_feats.shape[1], -1
        )
        # [W/32, B, C * H/32]
        seq = self.map_visual_to_seq(visual_feats)
        seq = self.dropout(seq)

        lstm_1, _ = self.b_lstm_1(seq)
        lstm_2, _ = self.b_lstm_2(lstm_1)
        lstm_2 = self.dropout(lstm_2)

        dense_output = self.final_dense(lstm_2)
        # [seq_len, B, num_classes] with seq_len = W/32
        log_probs = F.log_softmax(dense_output, dim=2)
        return log_probs
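
# A minimal shape-trace sketch for HW_RNN_Seq2Seq (illustrative numbers only;
# assumes the CNN backbone outputs 512 channels at 1/32 resolution):
def _demo_hw_rnn_seq2seq():
    model = HW_RNN_Seq2Seq(num_classes=80, image_height=64)
    # fake CNN features for 2 images of size 64 x 256: [B, C, H/32, W/32]
    feats = torch.randn(2, 512, 64 // 32, 256 // 32)
    log_probs = model(feats)
    print(log_probs.shape)  # torch.Size([8, 2, 80])
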
class CRNN(nn.Module):
"""
Hybrid CNN - RNN model
CNN - Modified ResNet34 for visual features
RNN - BiLSTM for seq2seq modeling
"""
def __init__(
self,
num_classes,
image_height,
num_feats_mapped_seq_hidden=128,
num_feats_seq_hidden=256,
pretrained=False,
):
"""
---------
Arguments
---------
        num_classes : int
            number of distinct characters (classes) in the dataset
        image_height : int
            image height
        num_feats_mapped_seq_hidden : int
            number of features to be used in the mapped visual features as sequences (default: 128)
        num_feats_seq_hidden : int
            number of features to be used in the LSTM for sequence modeling (default: 256)
        pretrained : bool
            whether to load ImageNet-pretrained weights for the ResNet backbone (default: False)
        """
super().__init__()
self.visual_feature_extractor = ResNetFeatureExtractor(pretrained=pretrained)
self.rnn_seq2seq_module = HW_RNN_Seq2Seq(
num_classes,
image_height,
self.visual_feature_extractor.output_channels,
num_feats_mapped_seq_hidden,
num_feats_seq_hidden,
)

    def forward(self, x):
        visual_feats = self.visual_feature_extractor(x)
        # [B, 512, H/32, W/32]
        log_probs = self.rnn_seq2seq_module(visual_feats)
        # [W/32, B, num_classes]
        return log_probs
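
# A minimal usage sketch for CRNN (illustrative sizes; assumes 3-channel input
# images with height and width divisible by 32):
def _demo_crnn():
    model = CRNN(num_classes=80, image_height=64)
    images = torch.randn(2, 3, 64, 256)  # [B, C, H, W]
    print(model(images).shape)  # torch.Size([8, 2, 80])
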
class STN_CRNN(nn.Module):
"""
STN + CNN + RNN model
STN - Spatial Transformer Network for learning variable handwriting
CNN - Modified ResNet34 for visual features
RNN - BiLSTM for seq2seq modeling
"""
def __init__(
self,
num_classes,
image_height,
image_width,
num_feats_mapped_seq_hidden=128,
num_feats_seq_hidden=256,
pretrained=False,
):
"""
---------
Arguments
---------
        num_classes : int
            number of distinct characters (classes) in the dataset
        image_height : int
            image height
        image_width : int
            image width
        num_feats_mapped_seq_hidden : int
            number of features to be used in the mapped visual features as sequences (default: 128)
        num_feats_seq_hidden : int
            number of features to be used in the LSTM for sequence modeling (default: 256)
        pretrained : bool
            whether to load ImageNet-pretrained weights for the ResNet backbone (default: False)
        """
super().__init__()
        # TPS-based rectification of the input image; the first argument (80)
        # is the number of fiducial control points
        self.stn = TPS_SpatialTransformerNetwork(
            80,
            (image_height, image_width),
            (image_height, image_width),
            I_channel_num=3,
        )
self.visual_feature_extractor = ResNetFeatureExtractor(pretrained=pretrained)
self.rnn_seq2seq_module = HW_RNN_Seq2Seq(
num_classes,
image_height,
self.visual_feature_extractor.output_channels,
num_feats_mapped_seq_hidden,
num_feats_seq_hidden,
)

    def forward(self, x):
        # rectify the handwriting image before feature extraction
        stn_output = self.stn(x)
        visual_feats = self.visual_feature_extractor(stn_output)
        # [B, 512, H/32, W/32]
        log_probs = self.rnn_seq2seq_module(visual_feats)
        return log_probs

"""
class STN_PP_CRNN(nn.Module):
def __init__(self, num_classes, image_height, image_width, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256):
super().__init__()
self.stn = TPS_SpatialTransformerNetwork(
20,
(image_height, image_width),
(image_height, image_width),
I_channel_num=3,
)
self.visual_feature_extractor = ResNetFeatureExtractor()
self.pp_block = PyramidPoolBlock(num_channels=self.visual_feature_extractor.output_channels)
self.rnn_seq2seq_module = HW_RNN_Seq2Seq(num_classes, image_height, self.visual_feature_extractor.output_channels, num_feats_mapped_seq_hidden, num_feats_seq_hidden)
def forward(self, x):
stn_output = self.stn(x)
visual_feats = self.visual_feature_extractor(stn_output)
pp_feats = self.pp_block(visual_feats)
log_probs = self.rnn_seq2seq_module(pp_feats)
return log_probs
"""