abhishekrs4 commited on
Commit
e56f0fc
1 Parent(s): 44066b7

minor fixes in model to avoid downloading pretrained weights for resnet backbone model

Browse files
Files changed (1) hide show
  1. iam_line_recognition/model_main.py +4 -2
iam_line_recognition/model_main.py CHANGED
@@ -87,6 +87,7 @@ class CRNN(nn.Module):
87
  image_height,
88
  num_feats_mapped_seq_hidden=128,
89
  num_feats_seq_hidden=256,
 
90
  ):
91
  """
92
  ---------
@@ -102,7 +103,7 @@ class CRNN(nn.Module):
102
  number of features to be used in the LSTM for sequence modeling (default: 256)
103
  """
104
  super().__init__()
105
- self.visual_feature_extractor = ResNetFeatureExtractor()
106
  self.rnn_seq2seq_module = HW_RNN_Seq2Seq(
107
  num_classes,
108
  image_height,
@@ -134,6 +135,7 @@ class STN_CRNN(nn.Module):
134
  image_width,
135
  num_feats_mapped_seq_hidden=128,
136
  num_feats_seq_hidden=256,
 
137
  ):
138
  """
139
  ---------
@@ -157,7 +159,7 @@ class STN_CRNN(nn.Module):
157
  (image_height, image_width),
158
  I_channel_num=3,
159
  )
160
- self.visual_feature_extractor = ResNetFeatureExtractor()
161
  self.rnn_seq2seq_module = HW_RNN_Seq2Seq(
162
  num_classes,
163
  image_height,
 
87
  image_height,
88
  num_feats_mapped_seq_hidden=128,
89
  num_feats_seq_hidden=256,
90
+ pretrained=False,
91
  ):
92
  """
93
  ---------
 
103
  number of features to be used in the LSTM for sequence modeling (default: 256)
104
  """
105
  super().__init__()
106
+ self.visual_feature_extractor = ResNetFeatureExtractor(pretrained=pretrained)
107
  self.rnn_seq2seq_module = HW_RNN_Seq2Seq(
108
  num_classes,
109
  image_height,
 
135
  image_width,
136
  num_feats_mapped_seq_hidden=128,
137
  num_feats_seq_hidden=256,
138
+ pretrained=False,
139
  ):
140
  """
141
  ---------
 
159
  (image_height, image_width),
160
  I_channel_num=3,
161
  )
162
+ self.visual_feature_extractor = ResNetFeatureExtractor(pretrained=pretrained)
163
  self.rnn_seq2seq_module = HW_RNN_Seq2Seq(
164
  num_classes,
165
  image_height,