Spaces:
Sleeping
Sleeping
File size: 6,596 Bytes
bd421ea 44066b7 bd421ea 44066b7 bd421ea 44066b7 bd421ea 44066b7 bd421ea 44066b7 bd421ea 44066b7 e56f0fc 44066b7 bd421ea e56f0fc 44066b7 bd421ea 44066b7 e56f0fc 44066b7 bd421ea e56f0fc 44066b7 bd421ea 44066b7 bd421ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from model_visual_features import ResNetFeatureExtractor, TPS_SpatialTransformerNetwork
class HW_RNN_Seq2Seq(nn.Module):
    """
    Visual Seq2Seq model using BiLSTM

    Flattens CNN feature maps into a width-major sequence, runs two
    stacked bidirectional LSTMs over it, and emits per-timestep
    log-probabilities over the character classes.
    """
    def __init__(
        self,
        num_classes,
        image_height,
        cnn_output_channels=512,
        num_feats_mapped_seq_hidden=128,
        num_feats_seq_hidden=256,
    ):
        """
        ---------
        Arguments
        ---------
        num_classes : int
            num of distinct characters (classes) in the dataset
        image_height : int
            image height
        cnn_output_channels : int
            number of channels output from the CNN visual feature extractor (default: 512)
        num_feats_mapped_seq_hidden : int
            number of features to be used in the mapped visual features as sequences (default: 128)
        num_feats_seq_hidden : int
            number of features to be used in the LSTM for sequence modeling (default: 256)
        """
        super().__init__()
        # CNN backbone downsamples height by 32, so the feature-map
        # height is image_height // 32.
        self.output_height = image_height // 32
        self.dropout = nn.Dropout(p=0.25)
        # Project the flattened (channels * height) column features down
        # to the LSTM input size.
        self.map_visual_to_seq = nn.Linear(
            cnn_output_channels * self.output_height, num_feats_mapped_seq_hidden
        )
        self.b_lstm_1 = nn.LSTM(
            num_feats_mapped_seq_hidden, num_feats_seq_hidden, bidirectional=True
        )
        # Second BiLSTM consumes the concatenated fwd+bwd states of the first.
        self.b_lstm_2 = nn.LSTM(
            2 * num_feats_seq_hidden, num_feats_seq_hidden, bidirectional=True
        )
        self.final_dense = nn.Linear(2 * num_feats_seq_hidden, num_classes)

    def forward(self, visual_feats):
        # [B, C, H, W] -> [W, B, C, H]: the width axis becomes the
        # sequence axis (image columns read left to right).
        seq_first = visual_feats.permute(3, 0, 1, 2)
        width, batch_size = seq_first.shape[0], seq_first.shape[1]
        # Collapse channels and height into one feature vector per step: [W, B, C*H]
        seq_first = seq_first.contiguous().view(width, batch_size, -1)
        mapped = self.dropout(self.map_visual_to_seq(seq_first))
        recurrent, _ = self.b_lstm_1(mapped)
        recurrent, _ = self.b_lstm_2(recurrent)
        logits = self.final_dense(self.dropout(recurrent))
        # [seq_len, B, num_classes] log-probabilities (e.g. for CTC loss).
        return F.log_softmax(logits, dim=2)
class CRNN(nn.Module):
    """
    Hybrid CNN - RNN model
    CNN - Modified ResNet34 for visual features
    RNN - BiLSTM for seq2seq modeling
    """
    def __init__(
        self,
        num_classes,
        image_height,
        num_feats_mapped_seq_hidden=128,
        num_feats_seq_hidden=256,
        pretrained=False,
    ):
        """
        ---------
        Arguments
        ---------
        num_classes : int
            num of distinct characters (classes) in the dataset
        image_height : int
            image height
        num_feats_mapped_seq_hidden : int
            number of features to be used in the mapped visual features as sequences (default: 128)
        num_feats_seq_hidden : int
            number of features to be used in the LSTM for sequence modeling (default: 256)
        """
        super().__init__()
        # CNN backbone producing the visual feature maps.
        self.visual_feature_extractor = ResNetFeatureExtractor(pretrained=pretrained)
        # Seq2seq head sized to match the backbone's output channels.
        self.rnn_seq2seq_module = HW_RNN_Seq2Seq(
            num_classes,
            image_height,
            self.visual_feature_extractor.output_channels,
            num_feats_mapped_seq_hidden,
            num_feats_seq_hidden,
        )

    def forward(self, x):
        # Backbone output is [B, 512, H/32, W/32]; the RNN head turns it
        # into [seq_len, B, num_classes] log-probabilities.
        return self.rnn_seq2seq_module(self.visual_feature_extractor(x))
class STN_CRNN(nn.Module):
    """
    STN + CNN + RNN model
    STN - Spatial Transformer Network for learning variable handwriting
    CNN - Modified ResNet34 for visual features
    RNN - BiLSTM for seq2seq modeling
    """
    def __init__(
        self,
        num_classes,
        image_height,
        image_width,
        num_feats_mapped_seq_hidden=128,
        num_feats_seq_hidden=256,
        pretrained=False,
    ):
        """
        ---------
        Arguments
        ---------
        num_classes : int
            num of distinct characters (classes) in the dataset
        image_height : int
            image height
        image_width : int
            image width
        num_feats_mapped_seq_hidden : int
            number of features to be used in the mapped visual features as sequences (default: 128)
        num_feats_seq_hidden : int
            number of features to be used in the LSTM for sequence modeling (default: 256)
        """
        super().__init__()
        # TPS-based spatial transformer that rectifies the input image
        # before feature extraction; input and rectified sizes are equal.
        # NOTE(review): 80 is presumably the number of fiducial points —
        # confirm against TPS_SpatialTransformerNetwork's signature.
        self.stn = TPS_SpatialTransformerNetwork(
            80,
            (image_height, image_width),
            (image_height, image_width),
            I_channel_num=3,
        )
        self.visual_feature_extractor = ResNetFeatureExtractor(pretrained=pretrained)
        self.rnn_seq2seq_module = HW_RNN_Seq2Seq(
            num_classes,
            image_height,
            self.visual_feature_extractor.output_channels,
            num_feats_mapped_seq_hidden,
            num_feats_seq_hidden,
        )

    def forward(self, x):
        # Rectify -> extract visual features -> decode to log-probabilities.
        rectified = self.stn(x)
        features = self.visual_feature_extractor(rectified)
        return self.rnn_seq2seq_module(features)
"""
class STN_PP_CRNN(nn.Module):
def __init__(self, num_classes, image_height, image_width, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256):
super().__init__()
self.stn = TPS_SpatialTransformerNetwork(
20,
(image_height, image_width),
(image_height, image_width),
I_channel_num=3,
)
self.visual_feature_extractor = ResNetFeatureExtractor()
self.pp_block = PyramidPoolBlock(num_channels=self.visual_feature_extractor.output_channels)
self.rnn_seq2seq_module = HW_RNN_Seq2Seq(num_classes, image_height, self.visual_feature_extractor.output_channels, num_feats_mapped_seq_hidden, num_feats_seq_hidden)
def forward(self, x):
stn_output = self.stn(x)
visual_feats = self.visual_feature_extractor(stn_output)
pp_feats = self.pp_block(visual_feats)
log_probs = self.rnn_seq2seq_module(pp_feats)
return log_probs
"""
|