import torch.nn as nn
import torch.nn.functional as F

from model_visual_features import ResNetFeatureExtractor, TPS_SpatialTransformerNetwork

class HW_RNN_Seq2Seq(nn.Module):
    """
    Visual sequence-to-sequence head: a 2-layer BiLSTM that maps CNN visual
    features to per-timestep character log-probabilities
    """

    def __init__(self, num_classes, image_height, cnn_output_channels=512, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256):
        """
        ---------
        Arguments
        ---------
        num_classes : int
            number of distinct characters (classes) in the dataset
        image_height : int
            height of the input image in pixels
        cnn_output_channels : int
            number of channels output by the CNN visual feature extractor (default: 512)
        num_feats_mapped_seq_hidden : int
            number of features in the visual-feature-to-sequence mapping (default: 128)
        num_feats_seq_hidden : int
            number of hidden features per direction in each BiLSTM layer (default: 256)
        """
        super().__init__()
        # the CNN feature extractor downsamples each spatial dimension by a factor of 32
        self.output_height = image_height // 32
        self.dropout = nn.Dropout(p=0.25)
        # flatten the (channels x height) slice at each horizontal position into one feature vector
        self.map_visual_to_seq = nn.Linear(cnn_output_channels * self.output_height, num_feats_mapped_seq_hidden)
        self.b_lstm_1 = nn.LSTM(num_feats_mapped_seq_hidden, num_feats_seq_hidden, bidirectional=True)
        self.b_lstm_2 = nn.LSTM(2 * num_feats_seq_hidden, num_feats_seq_hidden, bidirectional=True)
        self.final_dense = nn.Linear(2 * num_feats_seq_hidden, num_classes)

    def forward(self, visual_feats):
        # [B, C, H, W] -> [W, B, C, H]: treat the image width as the sequence
        # dimension, like words in a sentence
        visual_feats = visual_feats.permute(3, 0, 1, 2)
        # flatten channels and height: [W, B, C, H] -> [W, B, C * H]
        visual_feats = visual_feats.contiguous().view(visual_feats.shape[0], visual_feats.shape[1], -1)
        seq = self.map_visual_to_seq(visual_feats)
        seq = self.dropout(seq)
        lstm_1, _ = self.b_lstm_1(seq)
        lstm_2, _ = self.b_lstm_2(lstm_1)
        lstm_2 = self.dropout(lstm_2)
        dense_output = self.final_dense(lstm_2)
        # [seq_len, B, num_classes]
        log_probs = F.log_softmax(dense_output, dim=2)
        return log_probs
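
# A minimal usage sketch for the BiLSTM head (hypothetical numbers; assumes
# CNN features of shape [B, 512, image_height // 32, W // 32]):
#
#   import torch
#   rnn_head = HW_RNN_Seq2Seq(num_classes=80, image_height=64)
#   feats = torch.randn(4, 512, 2, 16)   # [B, C, H/32, W/32]
#   log_probs = rnn_head(feats)          # [16, 4, 80] = [seq_len, B, num_classes]
#
# The [seq_len, B, num_classes] log-softmax layout matches what nn.CTCLoss
# expects for its log_probs argument.
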
class CRNN(nn.Module):
    """
    Hybrid CNN-RNN model
    CNN - modified ResNet34 for visual feature extraction
    RNN - BiLSTM for sequence-to-sequence modeling
    """

    def __init__(self, num_classes, image_height, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256):
        """
        ---------
        Arguments
        ---------
        num_classes : int
            number of distinct characters (classes) in the dataset
        image_height : int
            height of the input image in pixels
        num_feats_mapped_seq_hidden : int
            number of features in the visual-feature-to-sequence mapping (default: 128)
        num_feats_seq_hidden : int
            number of hidden features per direction in each BiLSTM layer (default: 256)
        """
        super().__init__()
        self.visual_feature_extractor = ResNetFeatureExtractor()
        self.rnn_seq2seq_module = HW_RNN_Seq2Seq(num_classes, image_height, self.visual_feature_extractor.output_channels, num_feats_mapped_seq_hidden, num_feats_seq_hidden)

    def forward(self, x):
        visual_feats = self.visual_feature_extractor(x)
        # visual_feats: [B, 512, H/32, W/32]
        log_probs = self.rnn_seq2seq_module(visual_feats)
        return log_probs
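
# A minimal smoke-test sketch (hypothetical sizes; assumes 3-channel inputs
# whose height and width are divisible by 32, matching the ResNet stride):
#
#   import torch
#   model = CRNN(num_classes=80, image_height=64)
#   log_probs = model(torch.randn(4, 3, 64, 512))   # [512 // 32, 4, 80] = [16, 4, 80]
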
class STN_CRNN(nn.Module):
    """
    STN + CNN + RNN model
    STN - Spatial Transformer Network (TPS) for handling handwriting variability
    CNN - modified ResNet34 for visual feature extraction
    RNN - BiLSTM for sequence-to-sequence modeling
    """

    def __init__(self, num_classes, image_height, image_width, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256):
        """
        ---------
        Arguments
        ---------
        num_classes : int
            number of distinct characters (classes) in the dataset
        image_height : int
            height of the input image in pixels
        image_width : int
            width of the input image in pixels
        num_feats_mapped_seq_hidden : int
            number of features in the visual-feature-to-sequence mapping (default: 128)
        num_feats_seq_hidden : int
            number of hidden features per direction in each BiLSTM layer (default: 256)
        """
        super().__init__()
        self.stn = TPS_SpatialTransformerNetwork(
            80,  # number of fiducial control points for the TPS transformation
            (image_height, image_width),
            (image_height, image_width),
            I_channel_num=3,
        )
        self.visual_feature_extractor = ResNetFeatureExtractor()
        self.rnn_seq2seq_module = HW_RNN_Seq2Seq(num_classes, image_height, self.visual_feature_extractor.output_channels, num_feats_mapped_seq_hidden, num_feats_seq_hidden)

    def forward(self, x):
        # rectify the input with the TPS-STN before extracting visual features
        stn_output = self.stn(x)
        visual_feats = self.visual_feature_extractor(stn_output)
        log_probs = self.rnn_seq2seq_module(visual_feats)
        return log_probs
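
# Construction sketch (hypothetical sizes; image_width is needed here because
# the TPS-STN is built with a fixed input and rectified output size):
#
#   import torch
#   model = STN_CRNN(num_classes=80, image_height=64, image_width=512)
#   log_probs = model(torch.randn(4, 3, 64, 512))   # [16, 4, 80]
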
"""
class STN_PP_CRNN(nn.Module):
def __init__(self, num_classes, image_height, image_width, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256):
super().__init__()
self.stn = TPS_SpatialTransformerNetwork(
20,
(image_height, image_width),
(image_height, image_width),
I_channel_num=3,
)
self.visual_feature_extractor = ResNetFeatureExtractor()
self.pp_block = PyramidPoolBlock(num_channels=self.visual_feature_extractor.output_channels)
self.rnn_seq2seq_module = HW_RNN_Seq2Seq(num_classes, image_height, self.visual_feature_extractor.output_channels, num_feats_mapped_seq_hidden, num_feats_seq_hidden)
def forward(self, x):
stn_output = self.stn(x)
visual_feats = self.visual_feature_extractor(stn_output)
pp_feats = self.pp_block(visual_feats)
log_probs = self.rnn_seq2seq_module(pp_feats)
return log_probs
"""