File size: 6,063 Bytes
bd421ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import torchvision
import torch.nn as nn
import torch.nn.functional as F

from model_visual_features import ResNetFeatureExtractor, TPS_SpatialTransformerNetwork

class HW_RNN_Seq2Seq(nn.Module):
    """Sequence-to-sequence head over CNN feature maps using a 2-layer BiLSTM.

    Treats the width axis of the visual feature map as the time dimension,
    collapses channels and height into one feature vector per width step, and
    emits per-timestep class log-probabilities of shape
    ``[seq_len, batch, num_classes]``.

    ---------
    Arguments
    ---------
    num_classes : int
        num of distinct characters (classes) in the dataset
    image_height : int
        image height
    cnn_output_channels : int
        number of channels output from the CNN visual feature extractor (default: 512)
    num_feats_mapped_seq_hidden : int
        number of features to be used in the mapped visual features as sequences (default: 128)
    num_feats_seq_hidden : int
        number of features to be used in the LSTM for sequence modeling (default: 256)
    """

    def __init__(self, num_classes, image_height, cnn_output_channels=512, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256):
        super().__init__()
        # The backbone downsamples by 32x, so the feature map height is H/32.
        self.output_height = image_height // 32

        self.dropout = nn.Dropout(p=0.25)
        # Project the flattened (channels * height) column at each width step
        # down to the LSTM input size.
        self.map_visual_to_seq = nn.Linear(cnn_output_channels * self.output_height, num_feats_mapped_seq_hidden)

        # Two stacked bidirectional LSTMs; each direction contributes
        # num_feats_seq_hidden features, hence the 2x input on layer 2.
        self.b_lstm_1 = nn.LSTM(num_feats_mapped_seq_hidden, num_feats_seq_hidden, bidirectional=True)
        self.b_lstm_2 = nn.LSTM(2 * num_feats_seq_hidden, num_feats_seq_hidden, bidirectional=True)

        self.final_dense = nn.Linear(2 * num_feats_seq_hidden, num_classes)

    def forward(self, visual_feats):
        # [B, C, H, W] -> [W, B, C, H]: width becomes the sequence (time) axis,
        # reading the image left-to-right like a sentence.
        seq_first = visual_feats.permute(3, 0, 1, 2)

        width, batch = seq_first.shape[0], seq_first.shape[1]
        # Flatten channel and height dims: [W, B, C * H]
        flattened = seq_first.contiguous().view(width, batch, -1)

        mapped = self.dropout(self.map_visual_to_seq(flattened))

        out_1, _ = self.b_lstm_1(mapped)
        out_2, _ = self.b_lstm_2(out_1)
        out_2 = self.dropout(out_2)

        # [seq_len, B, num_classes]
        logits = self.final_dense(out_2)
        return F.log_softmax(logits, dim=2)


class CRNN(nn.Module):
    """Hybrid CNN - RNN model.

    CNN - Modified ResNet34 (``ResNetFeatureExtractor``) for visual features.
    RNN - ``HW_RNN_Seq2Seq`` BiLSTM head for seq2seq modeling; the output is
    per-timestep class log-probabilities.

    ---------
    Arguments
    ---------
    num_classes : int
        num of distinct characters (classes) in the dataset
    image_height : int
        image height
    num_feats_mapped_seq_hidden : int
        number of features to be used in the mapped visual features as sequences (default: 128)
    num_feats_seq_hidden : int
        number of features to be used in the LSTM for sequence modeling (default: 256)
    """

    def __init__(self, num_classes, image_height, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256):
        super().__init__()
        self.visual_feature_extractor = ResNetFeatureExtractor()
        # The RNN head is sized from the backbone's reported channel count so
        # the two stay consistent.
        self.rnn_seq2seq_module = HW_RNN_Seq2Seq(
            num_classes,
            image_height,
            self.visual_feature_extractor.output_channels,
            num_feats_mapped_seq_hidden,
            num_feats_seq_hidden,
        )

    def forward(self, x):
        # Backbone yields [B, 512, H/32, W/32]; the head decodes along width.
        feature_maps = self.visual_feature_extractor(x)
        return self.rnn_seq2seq_module(feature_maps)


class STN_CRNN(nn.Module):
    """STN + CNN + RNN model.

    STN - TPS Spatial Transformer Network that rectifies the input image
    (same spatial size in and out) to cope with variable handwriting.
    CNN - Modified ResNet34 (``ResNetFeatureExtractor``) for visual features.
    RNN - ``HW_RNN_Seq2Seq`` BiLSTM head for seq2seq modeling; the output is
    per-timestep class log-probabilities.

    ---------
    Arguments
    ---------
    num_classes : int
        num of distinct characters (classes) in the dataset
    image_height : int
        image height
    image_width : int
        image width
    num_feats_mapped_seq_hidden : int
        number of features to be used in the mapped visual features as sequences (default: 128)
    num_feats_seq_hidden : int
        number of features to be used in the LSTM for sequence modeling (default: 256)
    num_fiducial : int
        number of TPS fiducial control points used by the STN (default: 80,
        matching the previous hard-coded value; exposed as a parameter because
        this constant is the main STN capacity knob)
    """

    def __init__(self, num_classes, image_height, image_width, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256, num_fiducial=80):
        super().__init__()
        self.stn = TPS_SpatialTransformerNetwork(
            num_fiducial,
            (image_height, image_width),   # input size
            (image_height, image_width),   # rectified output size (unchanged)
            I_channel_num=3,
        )
        self.visual_feature_extractor = ResNetFeatureExtractor()
        # Size the RNN head from the backbone's reported channel count.
        self.rnn_seq2seq_module = HW_RNN_Seq2Seq(
            num_classes,
            image_height,
            self.visual_feature_extractor.output_channels,
            num_feats_mapped_seq_hidden,
            num_feats_seq_hidden,
        )

    def forward(self, x):
        # Rectify, extract features ([B, 512, H/32, W/32]), then decode.
        stn_output = self.stn(x)
        visual_feats = self.visual_feature_extractor(stn_output)
        log_probs = self.rnn_seq2seq_module(visual_feats)
        return log_probs

"""

class STN_PP_CRNN(nn.Module):

    def __init__(self, num_classes, image_height, image_width, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256):

        super().__init__()

        self.stn = TPS_SpatialTransformerNetwork(

            20,

            (image_height, image_width),

            (image_height, image_width),

            I_channel_num=3,

        )

        self.visual_feature_extractor = ResNetFeatureExtractor()

        self.pp_block = PyramidPoolBlock(num_channels=self.visual_feature_extractor.output_channels)

        self.rnn_seq2seq_module = HW_RNN_Seq2Seq(num_classes, image_height, self.visual_feature_extractor.output_channels, num_feats_mapped_seq_hidden, num_feats_seq_hidden)



    def forward(self, x):

        stn_output = self.stn(x)

        visual_feats = self.visual_feature_extractor(stn_output)

        pp_feats = self.pp_block(visual_feats)

        log_probs = self.rnn_seq2seq_module(pp_feats)

        return log_probs

"""