Commit
•
e56f0fc
1
Parent(s):
44066b7
minor fixes in model to avoid downloading pretrained weights for resnet backbone model
Browse files
iam_line_recognition/model_main.py
CHANGED
@@ -87,6 +87,7 @@ class CRNN(nn.Module):
|
|
87 |
image_height,
|
88 |
num_feats_mapped_seq_hidden=128,
|
89 |
num_feats_seq_hidden=256,
|
|
|
90 |
):
|
91 |
"""
|
92 |
---------
|
@@ -102,7 +103,7 @@ class CRNN(nn.Module):
|
|
102 |
number of features to be used in the LSTM for sequence modeling (default: 256)
|
103 |
"""
|
104 |
super().__init__()
|
105 |
-
self.visual_feature_extractor = ResNetFeatureExtractor()
|
106 |
self.rnn_seq2seq_module = HW_RNN_Seq2Seq(
|
107 |
num_classes,
|
108 |
image_height,
|
@@ -134,6 +135,7 @@ class STN_CRNN(nn.Module):
|
|
134 |
image_width,
|
135 |
num_feats_mapped_seq_hidden=128,
|
136 |
num_feats_seq_hidden=256,
|
|
|
137 |
):
|
138 |
"""
|
139 |
---------
|
@@ -157,7 +159,7 @@ class STN_CRNN(nn.Module):
|
|
157 |
(image_height, image_width),
|
158 |
I_channel_num=3,
|
159 |
)
|
160 |
-
self.visual_feature_extractor = ResNetFeatureExtractor()
|
161 |
self.rnn_seq2seq_module = HW_RNN_Seq2Seq(
|
162 |
num_classes,
|
163 |
image_height,
|
|
|
87 |
image_height,
|
88 |
num_feats_mapped_seq_hidden=128,
|
89 |
num_feats_seq_hidden=256,
|
90 |
+
pretrained=False,
|
91 |
):
|
92 |
"""
|
93 |
---------
|
|
|
103 |
number of features to be used in the LSTM for sequence modeling (default: 256)
|
104 |
"""
|
105 |
super().__init__()
|
106 |
+
self.visual_feature_extractor = ResNetFeatureExtractor(pretrained=pretrained)
|
107 |
self.rnn_seq2seq_module = HW_RNN_Seq2Seq(
|
108 |
num_classes,
|
109 |
image_height,
|
|
|
135 |
image_width,
|
136 |
num_feats_mapped_seq_hidden=128,
|
137 |
num_feats_seq_hidden=256,
|
138 |
+
pretrained=False,
|
139 |
):
|
140 |
"""
|
141 |
---------
|
|
|
159 |
(image_height, image_width),
|
160 |
I_channel_num=3,
|
161 |
)
|
162 |
+
self.visual_feature_extractor = ResNetFeatureExtractor(pretrained=pretrained)
|
163 |
self.rnn_seq2seq_module = HW_RNN_Seq2Seq(
|
164 |
num_classes,
|
165 |
image_height,
|