abhishekrs4 committed on
Commit 44066b7
1 Parent(s): bd1dc81

code formatting

app.py CHANGED
@@ -26,14 +26,10 @@ hw_recog_model = CRNN(num_classes, image_height)
 
try:
    logging.info(f"loading model from {file_model_local}")
-    hw_recog_model.load_state_dict(
-        torch.load(file_model_local, map_location=device)
-    )
+    hw_recog_model.load_state_dict(torch.load(file_model_local, map_location=device))
except:
    logging.info(f"loading model from {file_model_cont}")
-    hw_recog_model.load_state_dict(
-        torch.load(file_model_cont, map_location=device)
-    )
+    hw_recog_model.load_state_dict(torch.load(file_model_cont, map_location=device))
hw_recog_model.to(device)
hw_recog_model.eval()
 
@@ -51,6 +47,7 @@ def predict_hw(img_test: np.ndarray) -> str:
    str_pred = "".join(str_pred)
    return str_pred
 
+
@app.route("/predict", methods=["POST"])
def predict() -> Response:
    logging.info("IAM Handwriting recognition app")
@@ -62,7 +59,7 @@ def predict() -> Response:
    img_dec = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
    img_dec = cv2.cvtColor(img_dec, cv2.COLOR_BGR2RGB)
 
-    img_dec = cv2.resize(img_dec, (768, 32), interpolation = cv2.INTER_LINEAR)
+    img_dec = cv2.resize(img_dec, (768, 32), interpolation=cv2.INTER_LINEAR)
 
    str_pred = predict_hw(img_dec)
 
@@ -77,5 +74,6 @@ def predict() -> Response:
        json_pred = jsonify({"error": str(e)})
        return json_pred
 
+
if __name__ == "__main__":
    app.run(host="0.0.0.0", debug=True, port=7860)
iam_line_recognition/dataset.py CHANGED
@@ -8,6 +8,7 @@ import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
 
+
def read_IAM_label_txt_file(file_txt_labels):
    """
    ---------
@@ -42,15 +43,25 @@ def read_IAM_label_txt_file(file_txt_labels):
 
    return all_image_files, all_labels
 
+
class HWRecogIAMDataset(Dataset):
    """
    Main dataset class to be used only for training, validation and internal testing
    """
-    CHAR_SET = ' !"#&\'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
+
+    CHAR_SET = " !\"#&'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
    CHAR_2_LABEL = {char: i + 1 for i, char in enumerate(CHAR_SET)}
    LABEL_2_CHAR = {label: char for char, label in CHAR_2_LABEL.items()}
 
-    def __init__(self, list_image_files, list_labels, dir_images, image_height=32, image_width=768, which_set="train"):
+    def __init__(
+        self,
+        list_image_files,
+        list_labels,
+        dir_images,
+        image_height=32,
+        image_width=768,
+        which_set="train",
+    ):
        """
        ---------
        Arguments
@@ -77,28 +88,41 @@ class HWRecogIAMDataset(Dataset):
 
        if self.which_set == "train":
            # apply data augmentation only for train set
-            self.transform = transforms.Compose([
-                transforms.ToPILImage(),
-                transforms.Resize((self.image_height, self.image_width), Image.BILINEAR),
-                transforms.RandomAffine(degrees=[-0.75, 0.75], translate=[0, 0.05], scale=[0.75, 1],
-                    shear=[-35, 35], interpolation=transforms.InterpolationMode.BILINEAR, fill=255,
-                ),
-                transforms.ToTensor(),
-                transforms.Normalize(
-                    mean=[0.485, 0.456, 0.406],
-                    std=[0.229, 0.224, 0.225],
-                ),
-            ])
+            self.transform = transforms.Compose(
+                [
+                    transforms.ToPILImage(),
+                    transforms.Resize(
+                        (self.image_height, self.image_width), Image.BILINEAR
+                    ),
+                    transforms.RandomAffine(
+                        degrees=[-0.75, 0.75],
+                        translate=[0, 0.05],
+                        scale=[0.75, 1],
+                        shear=[-35, 35],
+                        interpolation=transforms.InterpolationMode.BILINEAR,
+                        fill=255,
+                    ),
+                    transforms.ToTensor(),
+                    transforms.Normalize(
+                        mean=[0.485, 0.456, 0.406],
+                        std=[0.229, 0.224, 0.225],
+                    ),
+                ]
+            )
        else:
-            self.transform = transforms.Compose([
-                transforms.ToPILImage(),
-                transforms.Resize((self.image_height, self.image_width), Image.BILINEAR),
-                transforms.ToTensor(),
-                transforms.Normalize(
-                    mean=[0.485, 0.456, 0.406],
-                    std=[0.229, 0.224, 0.225],
-                ),
-            ])
+            self.transform = transforms.Compose(
+                [
+                    transforms.ToPILImage(),
+                    transforms.Resize(
+                        (self.image_height, self.image_width), Image.BILINEAR
+                    ),
+                    transforms.ToTensor(),
+                    transforms.Normalize(
+                        mean=[0.485, 0.456, 0.406],
+                        std=[0.229, 0.224, 0.225],
+                    ),
+                ]
+            )
 
    def __len__(self):
        return len(self.list_image_files)
@@ -118,6 +142,7 @@ class HWRecogIAMDataset(Dataset):
 
        return image_3_channel, label_encoded, label_length
 
+
def IAM_collate_fn(batch):
    """
    collate function
@@ -145,6 +170,7 @@ def IAM_collate_fn(batch):
    label_lengths = torch.cat(label_lengths, 0)
    return images, labels, label_lengths
 
+
def split_dataset(file_txt_labels, for_train=True):
    """
    ---------
@@ -161,14 +187,28 @@ def split_dataset(file_txt_labels, for_train=True):
        a tuple of files depending for train or internal testing
    """
    all_image_files, all_labels = read_IAM_label_txt_file(file_txt_labels)
-    train_image_files, test_image_files, train_labels, test_labels = train_test_split(all_image_files, all_labels, test_size=0.1, random_state=4)
-    train_image_files, valid_image_files, train_labels, valid_labels = train_test_split(train_image_files, train_labels, test_size=0.1, random_state=4)
+    train_image_files, test_image_files, train_labels, test_labels = train_test_split(
+        all_image_files, all_labels, test_size=0.1, random_state=4
+    )
+    train_image_files, valid_image_files, train_labels, valid_labels = train_test_split(
+        train_image_files, train_labels, test_size=0.1, random_state=4
+    )
    if for_train:
        return train_image_files, valid_image_files, train_labels, valid_labels
    else:
        return test_image_files, test_labels
 
-def get_dataloaders_for_training(train_x, train_y, valid_x, valid_y, dir_images, image_height=32, image_width=768, batch_size=8):
+
+def get_dataloaders_for_training(
+    train_x,
+    train_y,
+    valid_x,
+    valid_y,
+    dir_images,
+    image_height=32,
+    image_width=768,
+    batch_size=8,
+):
    """
    ---------
    Arguments
@@ -199,8 +239,22 @@ def get_dataloaders_for_training(train_x, train_y, valid_x, valid_y, dir_images,
    valid_loader : object
        object of validation set dataloader
    """
-    train_dataset = HWRecogIAMDataset(train_x, train_y, dir_images, image_height=image_height, image_width=image_width, which_set="train")
-    valid_dataset = HWRecogIAMDataset(valid_x, valid_y, dir_images, image_height=image_height, image_width=image_width, which_set="valid")
+    train_dataset = HWRecogIAMDataset(
+        train_x,
+        train_y,
+        dir_images,
+        image_height=image_height,
+        image_width=image_width,
+        which_set="train",
+    )
+    valid_dataset = HWRecogIAMDataset(
+        valid_x,
+        valid_y,
+        dir_images,
+        image_height=image_height,
+        image_width=image_width,
+        which_set="valid",
+    )
 
    train_loader = DataLoader(
        train_dataset,
@@ -218,7 +272,10 @@ def get_dataloaders_for_training(train_x, train_y, valid_x, valid_y, dir_images,
    )
    return train_loader, valid_loader
 
-def get_dataloader_for_testing(test_x, test_y, dir_images, image_height=32, image_width=768, batch_size=1):
+
+def get_dataloader_for_testing(
+    test_x, test_y, dir_images, image_height=32, image_width=768, batch_size=1
+):
    """
    ---------
    Arguments
@@ -242,7 +299,14 @@ def get_dataloader_for_testing(test_x, test_y, dir_images, image_height=32, imag
    test_loader : object
        object of test set dataloader
    """
-    test_dataset = HWRecogIAMDataset(test_x, test_y, dir_images=dir_images, image_height=image_height, image_width=image_width, which_set="test")
+    test_dataset = HWRecogIAMDataset(
+        test_x,
+        test_y,
+        dir_images=dir_images,
+        image_height=image_height,
+        image_width=image_width,
+        which_set="test",
+    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
iam_line_recognition/final_iam_line_recognizer.py CHANGED
@@ -21,6 +21,7 @@ class DatasetFinalEval(HWRecogIAMDataset):
    """
    Dataset class for final evaluation - inherits main dataset class
    """
+
    def __init__(self, dir_images, image_height=32, image_width=768):
        """
        ---------
@@ -34,18 +35,24 @@ class DatasetFinalEval(HWRecogIAMDataset):
            image width (default: 768)
        """
        self.dir_images = dir_images
-        self.image_files = [f for f in os.listdir(self.dir_images) if f.endswith(".png")]
+        self.image_files = [
+            f for f in os.listdir(self.dir_images) if f.endswith(".png")
+        ]
        self.image_width = image_width
        self.image_height = image_height
-        self.transform = transforms.Compose([
-            transforms.ToPILImage(),
-            transforms.Resize((self.image_height, self.image_width), Image.BILINEAR),
-            transforms.ToTensor(),
-            transforms.Normalize(
-                mean=[0.485, 0.456, 0.406],
-                std=[0.229, 0.224, 0.225],
-            ),
-        ])
+        self.transform = transforms.Compose(
+            [
+                transforms.ToPILImage(),
+                transforms.Resize(
+                    (self.image_height, self.image_width), Image.BILINEAR
+                ),
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[0.485, 0.456, 0.406],
+                    std=[0.229, 0.224, 0.225],
+                ),
+            ]
+        )
 
    def __len__(self):
        return len(self.image_files)
@@ -57,7 +64,10 @@ class DatasetFinalEval(HWRecogIAMDataset):
        image_3_channel = self.transform(image_3_channel)
        return image_3_channel
 
-def get_dataloader_for_evaluation(dir_images, image_height=32, image_width=768, batch_size=1):
+
+def get_dataloader_for_evaluation(
+    dir_images, image_height=32, image_width=768, batch_size=1
+):
    """
    ---------
    Arguments
@@ -77,7 +87,9 @@ def get_dataloader_for_evaluation(dir_images, image_height=32, image_width=768,
    test_loader : object
        dataset loader object for final evaluation
    """
-    test_dataset = DatasetFinalEval(dir_images=dir_images, image_height=image_height, image_width=image_width)
+    test_dataset = DatasetFinalEval(
+        dir_images=dir_images, image_height=image_height, image_width=image_width
+    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
@@ -86,6 +98,7 @@ def get_dataloader_for_evaluation(dir_images, image_height=32, image_width=768,
    )
    return test_loader
 
+
def final_eval(hw_model, device, test_loader, dir_images, dir_results):
    """
    ---------
@@ -126,14 +139,22 @@ def final_eval(hw_model, device, test_loader, dir_images, dir_results):
        str_pred = [DatasetFinalEval.LABEL_2_CHAR[i] for i in pred_labels[0]]
        str_pred = "".join(str_pred)
 
-        with open(os.path.join(dir_results, file_test+".txt"), "w", encoding="utf-8", newline="\n") as fh_pred:
+        with open(
+            os.path.join(dir_results, file_test + ".txt"),
+            "w",
+            encoding="utf-8",
+            newline="\n",
+        ) as fh_pred:
            fh_pred.write(str_pred)
 
-        print(f"progress: {count}/{num_test_samples}, test file: {list_test_files[count-1]}")
+        print(
+            f"progress: {count}/{num_test_samples}, test file: {list_test_files[count-1]}"
+        )
        print(f"{str_pred}\n")
    print(f"predictions saved in directory: ./{dir_results}\n")
    return
 
+
def test_hw_recognizer(FLAGS):
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
 
@@ -163,15 +184,20 @@ def test_hw_recognizer(FLAGS):
 
    # get test set dataloader
    test_loader = get_dataloader_for_evaluation(
-        dir_images=FLAGS.dir_images, image_height=FLAGS.image_height, image_width=FLAGS.image_width,
+        dir_images=FLAGS.dir_images,
+        image_height=FLAGS.image_height,
+        image_width=FLAGS.image_width,
    )
 
    # start the evaluation on the final test set
-    print(f"final evaluation of handwriting recognition model {FLAGS.which_hw_model} started\n")
+    print(
+        f"final evaluation of handwriting recognition model {FLAGS.which_hw_model} started\n"
+    )
    final_eval(hw_model, device, test_loader, FLAGS.dir_images, dir_results)
    print(f"final evaluation of handwriting recognition model completed!!!!")
    return
 
+
def main():
    image_height = 32
    image_width = 768
@@ -184,22 +210,49 @@ def main():
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
 
-    parser.add_argument("--image_height", default=image_height,
-        type=int, help="image height to be used to predict with the model")
-    parser.add_argument("--image_width", default=image_width,
-        type=int, help="image width to be used to predict with the model")
-    parser.add_argument("--dir_images", default=dir_images,
-        type=str, help="full directory path to directory containing images")
-    parser.add_argument("--which_hw_model", default=which_hw_model,
-        type=str, choices=["crnn", "stn_crnn"], help="which model to be used for prediction")
-    parser.add_argument("--file_model", default=file_model,
-        type=str, help="full path to trained model file (.pth)")
-    parser.add_argument("--save_predictions", default=save_predictions,
-        type=int, choices=[0, 1], help="save or do not save the predictions (1 - save, 0 - do not save)")
+    parser.add_argument(
+        "--image_height",
+        default=image_height,
+        type=int,
+        help="image height to be used to predict with the model",
+    )
+    parser.add_argument(
+        "--image_width",
+        default=image_width,
+        type=int,
+        help="image width to be used to predict with the model",
+    )
+    parser.add_argument(
+        "--dir_images",
+        default=dir_images,
+        type=str,
+        help="full directory path to directory containing images",
+    )
+    parser.add_argument(
+        "--which_hw_model",
+        default=which_hw_model,
+        type=str,
+        choices=["crnn", "stn_crnn"],
+        help="which model to be used for prediction",
+    )
+    parser.add_argument(
+        "--file_model",
+        default=file_model,
+        type=str,
+        help="full path to trained model file (.pth)",
+    )
+    parser.add_argument(
+        "--save_predictions",
+        default=save_predictions,
+        type=int,
+        choices=[0, 1],
+        help="save or do not save the predictions (1 - save, 0 - do not save)",
+    )
 
    FLAGS, unparsed = parser.parse_known_args()
    test_hw_recognizer(FLAGS)
    return
 
+
if __name__ == "__main__":
    main()
iam_line_recognition/logger_utils.py CHANGED
@@ -1,6 +1,7 @@
import csv
import json
 
+
def write_json_file(file_json, dict_data):
    """
    ---------
@@ -15,10 +16,12 @@ def write_json_file(file_json, dict_data):
        fh.write(json.dumps(dict_data, indent=4))
    return
 
+
class CSVWriter:
    """
    for writing tabular data to a csv file
    """
+
    def __init__(self, file_name, column_names):
        """
        ---------
iam_line_recognition/model_main.py CHANGED
@@ -4,11 +4,20 @@ import torch.nn.functional as F
 
from model_visual_features import ResNetFeatureExtractor, TPS_SpatialTransformerNetwork
 
+
class HW_RNN_Seq2Seq(nn.Module):
    """
    Visual Seq2Seq model using BiLSTM
    """
-    def __init__(self, num_classes, image_height, cnn_output_channels=512, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256):
+
+    def __init__(
+        self,
+        num_classes,
+        image_height,
+        cnn_output_channels=512,
+        num_feats_mapped_seq_hidden=128,
+        num_feats_seq_hidden=256,
+    ):
        """
        ---------
        Arguments
@@ -28,10 +37,16 @@ class HW_RNN_Seq2Seq(nn.Module):
        self.output_height = image_height // 32
 
        self.dropout = nn.Dropout(p=0.25)
-        self.map_visual_to_seq = nn.Linear(cnn_output_channels * self.output_height, num_feats_mapped_seq_hidden)
+        self.map_visual_to_seq = nn.Linear(
+            cnn_output_channels * self.output_height, num_feats_mapped_seq_hidden
+        )
 
-        self.b_lstm_1 = nn.LSTM(num_feats_mapped_seq_hidden, num_feats_seq_hidden, bidirectional=True)
-        self.b_lstm_2 = nn.LSTM(2 * num_feats_seq_hidden, num_feats_seq_hidden, bidirectional=True)
+        self.b_lstm_1 = nn.LSTM(
+            num_feats_mapped_seq_hidden, num_feats_seq_hidden, bidirectional=True
+        )
+        self.b_lstm_2 = nn.LSTM(
+            2 * num_feats_seq_hidden, num_feats_seq_hidden, bidirectional=True
+        )
 
        self.final_dense = nn.Linear(2 * num_feats_seq_hidden, num_classes)
 
@@ -40,7 +55,9 @@ class HW_RNN_Seq2Seq(nn.Module):
        # WBCH
        # the sequence is along the width of the image as a sentence
 
-        visual_feats = visual_feats.contiguous().view(visual_feats.shape[0], visual_feats.shape[1], -1)
+        visual_feats = visual_feats.contiguous().view(
+            visual_feats.shape[0], visual_feats.shape[1], -1
+        )
        # WBC
 
        seq = self.map_visual_to_seq(visual_feats)
@@ -63,7 +80,14 @@ class CRNN(nn.Module):
    CNN - Modified ResNet34 for visual features
    RNN - BiLSTM for seq2seq modeling
    """
-    def __init__(self, num_classes, image_height, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256):
+
+    def __init__(
+        self,
+        num_classes,
+        image_height,
+        num_feats_mapped_seq_hidden=128,
+        num_feats_seq_hidden=256,
+    ):
        """
        ---------
        Arguments
@@ -79,7 +103,13 @@ class CRNN(nn.Module):
        """
        super().__init__()
        self.visual_feature_extractor = ResNetFeatureExtractor()
-        self.rnn_seq2seq_module = HW_RNN_Seq2Seq(num_classes, image_height, self.visual_feature_extractor.output_channels, num_feats_mapped_seq_hidden, num_feats_seq_hidden)
+        self.rnn_seq2seq_module = HW_RNN_Seq2Seq(
+            num_classes,
+            image_height,
+            self.visual_feature_extractor.output_channels,
+            num_feats_mapped_seq_hidden,
+            num_feats_seq_hidden,
+        )
 
    def forward(self, x):
        visual_feats = self.visual_feature_extractor(x)
@@ -96,7 +126,15 @@ class STN_CRNN(nn.Module):
    CNN - Modified ResNet34 for visual features
    RNN - BiLSTM for seq2seq modeling
    """
-    def __init__(self, num_classes, image_height, image_width, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256):
+
+    def __init__(
+        self,
+        num_classes,
+        image_height,
+        image_width,
+        num_feats_mapped_seq_hidden=128,
+        num_feats_seq_hidden=256,
+    ):
        """
        ---------
        Arguments
@@ -120,7 +158,13 @@ class STN_CRNN(nn.Module):
            I_channel_num=3,
        )
        self.visual_feature_extractor = ResNetFeatureExtractor()
-        self.rnn_seq2seq_module = HW_RNN_Seq2Seq(num_classes, image_height, self.visual_feature_extractor.output_channels, num_feats_mapped_seq_hidden, num_feats_seq_hidden)
+        self.rnn_seq2seq_module = HW_RNN_Seq2Seq(
+            num_classes,
+            image_height,
+            self.visual_feature_extractor.output_channels,
+            num_feats_mapped_seq_hidden,
+            num_feats_seq_hidden,
+        )
 
    def forward(self, x):
        stn_output = self.stn(x)
@@ -128,6 +172,7 @@ class STN_CRNN(nn.Module):
        log_probs = self.rnn_seq2seq_module(visual_feats)
        return log_probs
 
+
"""
class STN_PP_CRNN(nn.Module):
    def __init__(self, num_classes, image_height, image_width, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256):
iam_line_recognition/model_visual_features.py CHANGED
@@ -4,10 +4,17 @@ import torch.nn as nn
from typing import List
from torch import Tensor
import torch.nn.functional as F
-from torchvision.models.resnet import BasicBlock, model_urls, load_state_dict_from_url, conv1x1, conv3x3
+from torchvision.models.resnet import (
+    BasicBlock,
+    model_urls,
+    load_state_dict_from_url,
+    conv1x1,
+    conv3x3,
+)
 
device = torch.device("cuda")
 
+
class CustomResNet(nn.Module):
    def __init__(
        self,
@@ -43,14 +50,22 @@ class CustomResNet(nn.Module):
        self.groups = groups
        self.base_width = width_per_group
 
-        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
+        self.conv1 = nn.Conv2d(
+            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
+        )
        self.bn1 = self._norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=(2, 1), padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
-        self.layer2 = self._make_layer(block, 128, layers[1], stride=(2, 1), dilate=replace_stride_with_dilation[0])
-        self.layer3 = self._make_layer(block, 256, layers[2], stride=(2, 2), dilate=replace_stride_with_dilation[1])
-        self.layer4 = self._make_layer(block, 512, layers[3], stride=(2, 1), dilate=replace_stride_with_dilation[2])
+        self.layer2 = self._make_layer(
+            block, 128, layers[1], stride=(2, 1), dilate=replace_stride_with_dilation[0]
+        )
+        self.layer3 = self._make_layer(
+            block, 256, layers[2], stride=(2, 2), dilate=replace_stride_with_dilation[1]
+        )
+        self.layer4 = self._make_layer(
+            block, 512, layers[3], stride=(2, 1), dilate=replace_stride_with_dilation[2]
+        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
 
@@ -92,7 +107,14 @@ class CustomResNet(nn.Module):
        layers = []
        layers.append(
            block(
-                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
+                self.inplanes,
+                planes,
+                stride,
+                downsample,
+                self.groups,
+                self.base_width,
+                previous_dilation,
+                norm_layer,
            )
        )
        self.inplanes = planes * block.expansion
@@ -126,6 +148,7 @@ class CustomResNet(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
 
+
def _resnet(layers: List[int], pretrained=True) -> CustomResNet:
    model = CustomResNet(layers)
 
@@ -134,6 +157,7 @@ def _resnet(layers: List[int], pretrained=True) -> CustomResNet:
 
    return model
 
+
def resnet34(*, pretrained=True) -> CustomResNet:
    """ResNet-34 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
    Args:
@@ -159,6 +183,7 @@ class ResNetFeatureExtractor(nn.Module):
    """
    Defines Base ResNet-34 feature extractor
    """
+
    def __init__(self, pretrained=True):
        """
        ---------
@@ -174,7 +199,7 @@ class ResNetFeatureExtractor(nn.Module):
    def forward(self, x):
        block1 = self.resnet34.conv1(x)
        block1 = self.resnet34.bn1(block1)
-        block1 = self.resnet34.relu(block1) # [64, H/2, W/2]
+        block1 = self.resnet34.relu(block1)  # [64, H/2, W/2]
 
        block2 = self.resnet34.maxpool(block1)
        block2 = self.resnet34.layer1(block2)  # [64, H/4, W/4]
@@ -190,10 +215,10 @@ class ResNetFeatureExtractor(nn.Module):
### STN - Spatial Transformer Network ###
#########################################
class TPS_SpatialTransformerNetwork(nn.Module):
-    """ Rectification Network of RARE, namely TPS based STN """
+    """Rectification Network of RARE, namely TPS based STN"""
 
    def __init__(self, num_fiducial_points, I_size, I_r_size, I_channel_num=1):
-        """ Based on RARE TPS
+        """Based on RARE TPS
        input:
            batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width]
            I_size : (height, width) of the input image I
@@ -207,39 +232,66 @@ class TPS_SpatialTransformerNetwork(nn.Module):
        self.I_size = I_size
        self.I_r_size = I_r_size  # = (I_r_height, I_r_width)
        self.I_channel_num = I_channel_num
-        self.LocalizationNetwork = LocalizationNetwork(self.num_fiducial_points, self.I_channel_num)
+        self.LocalizationNetwork = LocalizationNetwork(
+            self.num_fiducial_points, self.I_channel_num
+        )
        self.GridGenerator = GridGenerator(self.num_fiducial_points, self.I_r_size)
 
    def forward(self, batch_I):
        batch_C_prime = self.LocalizationNetwork(batch_I)  # batch_size x K x 2
-        build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2
-        build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2])
+        build_P_prime = self.GridGenerator.build_P_prime(
+            batch_C_prime
+        )  # batch_size x n (= I_r_width x I_r_height) x 2
+        build_P_prime_reshape = build_P_prime.reshape(
+            [build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]
+        )
 
        if torch.__version__ > "1.2.0":
-            batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True)
+            batch_I_r = F.grid_sample(
+                batch_I,
+                build_P_prime_reshape,
+                padding_mode="border",
+                align_corners=True,
+            )
        else:
-            batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border')
+            batch_I_r = F.grid_sample(
+                batch_I, build_P_prime_reshape, padding_mode="border"
+            )
 
        return batch_I_r


class LocalizationNetwork(nn.Module):
-    """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """
+    """Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height)"""
 
    def __init__(self, num_fiducial_points, I_channel_num):
        super(LocalizationNetwork, self).__init__()
        self.num_fiducial_points = num_fiducial_points
        self.I_channel_num = I_channel_num
        self.conv = nn.Sequential(
-            nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1,
-                bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
+            nn.Conv2d(
+                in_channels=self.I_channel_num,
+                out_channels=64,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                bias=False,
+            ),
+            nn.BatchNorm2d(64),
+            nn.ReLU(True),
            nn.MaxPool2d(2, 2),  # batch_size x 64 x I_height/2 x I_width/2
-            nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True),
+            nn.Conv2d(64, 128, 3, 1, 1, bias=False),
+            nn.BatchNorm2d(128),
+            nn.ReLU(True),
            nn.MaxPool2d(2, 2),  # batch_size x 128 x I_height/4 x I_width/4
-            nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True),
+            nn.Conv2d(128, 256, 3, 1, 1, bias=False),
+            nn.BatchNorm2d(256),
+            nn.ReLU(True),
            nn.MaxPool2d(2, 2),  # batch_size x 256 x I_height/8 x I_width/8
-            nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True),
-            nn.AdaptiveAvgPool2d(1) # batch_size x 512
+            nn.Conv2d(256, 512, 3, 1, 1, bias=False),
+            nn.BatchNorm2d(512),
+            nn.ReLU(True),
+            nn.AdaptiveAvgPool2d(1),  # batch_size x 512
        )
 
        self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True))
@@ -254,7 +306,9 @@ class LocalizationNetwork(nn.Module):
        ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
        ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
        initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
-        self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1)
+        self.localization_fc2.bias.data = (
+            torch.from_numpy(initial_bias).float().view(-1)
+        )
 
    def forward(self, batch_I):
        """
@@ -263,15 +317,17 @@ class LocalizationNetwork(nn.Module):
        """
        batch_size = batch_I.size(0)
        features = self.conv(batch_I).view(batch_size, -1)
-        batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.num_fiducial_points, 2)
+        batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(
+            batch_size, self.num_fiducial_points, 2
+        )
        return batch_C_prime


class GridGenerator(nn.Module):
-    """ Grid Generator of RARE, which produces P_prime by multipling T with P """
+    """Grid Generator of RARE, which produces P_prime by multipling T with P"""
 
    def __init__(self, num_fiducial_points, I_r_size):
-        """ Generate P_hat and inv_delta_C for later """
+        """Generate P_hat and inv_delta_C for later"""
        super(GridGenerator, self).__init__()
        self.eps = 1e-6
        self.I_r_height, self.I_r_width = I_r_size
@@ -279,14 +335,24 @@ class GridGenerator(nn.Module):
        self.C = self._build_C(self.num_fiducial_points)  # F x 2
        self.P = self._build_P(self.I_r_width, self.I_r_height)
        ## for multi-gpu, you need register buffer
-        self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.num_fiducial_points, self.C)).float()) # F+3 x F+3
-        self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.num_fiducial_points, self.C, self.P)).float()) # n x F+3
+        self.register_buffer(
+            "inv_delta_C",
+            torch.tensor(
+                self._build_inv_delta_C(self.num_fiducial_points, self.C)
+            ).float(),
+        )  # F+3 x F+3
+        self.register_buffer(
+            "P_hat",
+            torch.tensor(
+                self._build_P_hat(self.num_fiducial_points, self.C, self.P)
+            ).float(),
+        )  # n x F+3
        ## for fine-tuning with different image width, you may use below instead of self.register_buffer
-        #self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.num_fiducial_points, self.C)).float().cuda() # F+3 x F+3
-        #self.P_hat = torch.tensor(self._build_P_hat(self.num_fiducial_points, self.C, self.P)).float().cuda() # n x F+3
+        # self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.num_fiducial_points, self.C)).float().cuda() # F+3 x F+3
+        # self.P_hat = torch.tensor(self._build_P_hat(self.num_fiducial_points, self.C, self.P)).float().cuda() # n x F+3
 
    def _build_C(self, F):
-        """ Return coordinates of fiducial points in I_r; C """
+        """Return coordinates of fiducial points in I_r; C"""
        ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
        ctrl_pts_y_top = -1 * np.ones(int(F / 2))
        ctrl_pts_y_bottom = np.ones(int(F / 2))
@@ -296,7 +362,7 @@ class GridGenerator(nn.Module):
        return C  # F x 2
 
    def _build_inv_delta_C(self, F, C):
-        """ Return inv_delta_C which is needed to calculate T """
+        """Return inv_delta_C which is needed to calculate T"""
        hat_C = np.zeros((F, F), dtype=float)  # F x F
        for i in range(0, F):
            for j in range(i, F):
@@ -304,31 +370,36 @@ class GridGenerator(nn.Module):
                hat_C[i, j] = r
                hat_C[j, i] = r
        np.fill_diagonal(hat_C, 1)
-        hat_C = (hat_C ** 2) * np.log(hat_C)
+        hat_C = (hat_C**2) * np.log(hat_C)
        # print(C.shape, hat_C.shape)
        delta_C = np.concatenate(  # F+3 x F+3
            [
                np.concatenate([np.ones((F, 1)), C, hat_C], axis=1),  # F x F+3
                np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1),  # 2 x F+3
-                np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3
+                np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1),  # 1 x F+3
            ],
-            axis=0
+            axis=0,
        )
        inv_delta_C = np.linalg.inv(delta_C)
        return inv_delta_C  # F+3 x F+3
 
    def _build_P(self, I_r_width, I_r_height):
-        I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width # self.I_r_width
-        I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height # self.I_r_height
+        I_r_grid_x = (
+            np.arange(-I_r_width, I_r_width, 2) + 1.0
+        ) / I_r_width  # self.I_r_width
+        I_r_grid_y = (
+            np.arange(-I_r_height, I_r_height, 2) + 1.0
+        ) / I_r_height  # self.I_r_height
        P = np.stack(  # self.I_r_width x self.I_r_height x 2
-            np.meshgrid(I_r_grid_x, I_r_grid_y),
-            axis=2
+            np.meshgrid(I_r_grid_x, I_r_grid_y), axis=2
        )
        return P.reshape([-1, 2])  # n (= self.I_r_width x self.I_r_height) x 2
 
    def _build_P_hat(self, F, C, P):
        n = P.shape[0]  # n (= self.I_r_width x self.I_r_height)
-        P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1)) # n x 2 -> n x 1 x 2 -> n x F x 2
+        P_tile = np.tile(
+            np.expand_dims(P, axis=1), (1, F, 1)
+        )  # n x 2 -> n x 1 x 2 -> n x F x 2
        C_tile = np.expand_dims(C, axis=0)  # 1 x F x 2
        P_diff = P_tile - C_tile  # n x F x 2
        rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False)  # n x F
@@ -337,13 +408,16 @@ class GridGenerator(nn.Module):
        return P_hat  # n x F+3
 
    def build_P_prime(self, batch_C_prime):
-        """ Generate Grid from batch_C_prime [batch_size x F x 2] """
+        """Generate Grid from batch_C_prime [batch_size x F x 2]"""
        batch_size = batch_C_prime.size(0)
        batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1)
        batch_P_hat = self.P_hat.repeat(batch_size, 1, 1)
-        batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros(
-            batch_size, 3, 2).float().to(device)), dim=1) # batch_size x F+3 x 2
-        batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2
+        batch_C_prime_with_zeros = torch.cat(
+            (batch_C_prime, torch.zeros(batch_size, 3, 2).float().to(device)), dim=1
+        )  # batch_size x F+3 x 2
+        batch_T = torch.bmm(
+            batch_inv_delta_C, batch_C_prime_with_zeros
+        )  # batch_size x F+3 x 2
        batch_P_prime = torch.bmm(batch_P_hat, batch_T)  # batch_size x n x 2
        return batch_P_prime  # batch_size x n x 2
 
iam_line_recognition/test_internal.py CHANGED
@@ -14,7 +14,14 @@ from utils import ctc_decode, compute_wer_and_cer_for_sample
from dataset import HWRecogIAMDataset, split_dataset, get_dataloader_for_testing


-def test(hw_model, test_loader, device, list_test_files, which_ctc_decoder="beam_search", save_prediction_stats=False):
+def test(
+    hw_model,
+    test_loader,
+    device,
+    list_test_files,
+    which_ctc_decoder="beam_search",
+    save_prediction_stats=False,
+):
    """
    ---------
    Arguments
@@ -42,7 +49,7 @@ def test(hw_model, test_loader, device, list_test_files, which_ctc_decoder="beam
    if save_prediction_stats:
        csv_writer = CSVWriter(
            file_name="pred_stats.csv",
-            column_names=["file_name", "num_chars", "num_words", "cer", "wer"]
+            column_names=["file_name", "num_chars", "num_words", "cer", "wer"],
        )
 
    with torch.no_grad():
@@ -62,19 +69,23 @@ def test(hw_model, test_loader, device, list_test_files, which_ctc_decoder="beam
            list_test_cers.append(cer_sample)
            list_test_wers.append(wer_sample)
 
-            print(f"progress: {count}/{num_test_samples}, test file: {list_test_files[count-1]}")
+            print(
+                f"progress: {count}/{num_test_samples}, test file: {list_test_files[count-1]}"
+            )
            print(f"{str_label} - label")
            print(f"{str_pred} - prediction")
            print(f"cer: {cer_sample:.3f}, wer: {wer_sample:.3f}\n")
 
            if save_prediction_stats:
-                csv_writer.write_row([
-                    list_test_files[count-1],
-                    len(str_label),
-                    len(str_label.split(" ")),
-                    cer_sample,
-                    wer_sample,
-                ])
+                csv_writer.write_row(
+                    [
+                        list_test_files[count - 1],
+                        len(str_label),
+                        len(str_label.split(" ")),
+                        cer_sample,
+                        wer_sample,
+                    ]
+                )
    list_test_cers = np.array(list_test_cers)
    list_test_wers = np.array(list_test_wers)
    mean_test_cer = np.mean(list_test_cers)
@@ -85,6 +96,7 @@ def test(hw_model, test_loader, device, list_test_files, which_ctc_decoder="beam
        csv_writer.close()
    return
 
+
def test_hw_recognizer(FLAGS):
    file_txt_labels = os.path.join(FLAGS.dir_dataset, "iam_lines_gt.txt")
    dir_images = os.path.join(FLAGS.dir_dataset, "img")
@@ -101,8 +113,11 @@ def test_hw_recognizer(FLAGS):
    num_test_samples = len(test_x)
    # get the internal test set dataloader
    test_loader = get_dataloader_for_testing(
-        test_x, test_y,
-        dir_images=dir_images, image_height=FLAGS.image_height, image_width=FLAGS.image_width,
+        test_x,
+        test_y,
+        dir_images=dir_images,
+        image_height=FLAGS.image_height,
+        image_width=FLAGS.image_width,
    )
 
    num_classes = len(HWRecogIAMDataset.LABEL_2_CHAR) + 1
@@ -124,10 +139,18 @@ def test_hw_recognizer(FLAGS):
 
    # start testing of the model on the internal set
    print(f"testing of handwriting recognition model {FLAGS.which_hw_model} started\n")
-    test(hw_model, test_loader, device, test_x, FLAGS.which_ctc_decoder, bool(FLAGS.save_prediction_stats))
+    test(
+        hw_model,
+        test_loader,
+        device,
+        test_x,
+        FLAGS.which_ctc_decoder,
+        bool(FLAGS.save_prediction_stats),
+    )
    print(f"testing handwriting recognition model completed!!!!")
    return
 
+
def main():
    image_height = 32
    image_width = 768
@@ -141,24 +164,56 @@ def main():
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
 
-    parser.add_argument("--image_height", default=image_height,
-        type=int, help="image height to be used to predict with the model")
-    parser.add_argument("--image_width", default=image_width,
-        type=int, help="image width to be used to predict with the model")
-    parser.add_argument("--dir_dataset", default=dir_dataset,
-        type=str, help="full directory path to the dataset")
-    parser.add_argument("--which_hw_model", default=which_hw_model,
-        type=str, choices=["crnn", "stn_crnn"], help="which model to be used for prediction")
-    parser.add_argument("--which_ctc_decoder", default=which_ctc_decoder,
-        type=str, choices=["beam_search", "greedy"], help="which ctc decoder to use")
-    parser.add_argument("--file_model", default=file_model,
-        type=str, help="full path to trained model file (.pth)")
-    parser.add_argument("--save_prediction_stats", default=save_prediction_stats,
-        type=int, choices=[0, 1], help="save prediction stats (1 - yes, 0 - no)")
+    parser.add_argument(
+        "--image_height",
+        default=image_height,
+        type=int,
+        help="image height to be used to predict with the model",
+    )
+    parser.add_argument(
+        "--image_width",
+        default=image_width,
+        type=int,
+        help="image width to be used to predict with the model",
+    )
+    parser.add_argument(
+        "--dir_dataset",
+        default=dir_dataset,
+        type=str,
+        help="full directory path to the dataset",
+    )
+    parser.add_argument(
+        "--which_hw_model",
+        default=which_hw_model,
+        type=str,
+        choices=["crnn", "stn_crnn"],
+        help="which model to be used for prediction",
+    )
+    parser.add_argument(
+        "--which_ctc_decoder",
+        default=which_ctc_decoder,
+        type=str,
+        choices=["beam_search", "greedy"],
+        help="which ctc decoder to use",
+    )
+    parser.add_argument(
+        "--file_model",
+        default=file_model,
+        type=str,
+        help="full path to trained model file (.pth)",
+    )
+    parser.add_argument(
+        "--save_prediction_stats",
+        default=save_prediction_stats,
+        type=int,
+        choices=[0, 1],
+        help="save prediction stats (1 - yes, 0 - no)",
+    )
 
    FLAGS, unparsed = parser.parse_known_args()
    test_hw_recognizer(FLAGS)
    return
 
+
if __name__ == "__main__":
    main()
iam_line_recognition/train.py CHANGED
@@ -55,12 +55,15 @@ def train(hw_model, optimizer, criterion, train_loader, device):
        loss = criterion(log_probs, labels, lengths_preds, lengths_labels)
        train_running_loss += loss.item()
        loss.backward()
-        torch.nn.utils.clip_grad_norm_(hw_model.parameters(), 5) # gradient clipping with 5
        optimizer.step()
 
    train_loss = train_running_loss / num_train_batches
    return train_loss
 
def validate(hw_model, criterion, valid_loader, device):
    """
    ---------
@@ -114,19 +117,28 @@ def validate(hw_model, criterion, valid_loader, device):
        final_labels_for_eval = []
        length_label_counter = 0
        for pred_label, length_label in zip(pred_labels, lengths_labels_for_eval):
-            label = labels_for_eval[length_label_counter:length_label_counter+length_label]
            length_label_counter += length_label
 
            final_labels_for_eval.append(label)
 
        for i in range(len(final_labels_for_eval)):
            if len(pred_labels[i]) != 0:
-                str_label = [HWRecogIAMDataset.LABEL_2_CHAR[i] for i in final_labels_for_eval[i]]
                str_label = "".join(str_label)
-                str_pred = [HWRecogIAMDataset.LABEL_2_CHAR[i] for i in pred_labels[i]]
                str_pred = "".join(str_pred)
 
-                cer_sample, wer_sample = compute_wer_and_cer_for_sample(str_pred, str_label)
            else:
                cer_sample, wer_sample = 100, 100
 
@@ -138,6 +150,7 @@ def validate(hw_model, criterion, valid_loader, device):
    valid_wer = valid_running_wer / num_valid_samples
    return valid_loss, valid_cer, valid_wer
 
def train_hw_recognizer(FLAGS):
    file_txt_labels = os.path.join(FLAGS.dir_dataset, "iam_lines_gt.txt")
    dir_images = os.path.join(FLAGS.dir_dataset, "img")
@@ -156,8 +169,13 @@ def train_hw_recognizer(FLAGS):
    num_valid_samples = len(valid_x)
    # get dataloaders for train and validation sets
    train_loader, valid_loader = get_dataloaders_for_training(
-        train_x, train_y, valid_x, valid_y,
-        dir_images=dir_images, image_height=FLAGS.image_height, image_width=FLAGS.image_width,
        batch_size=FLAGS.batch_size,
    )
 
@@ -171,7 +189,7 @@ def train_hw_recognizer(FLAGS):
    file_logger_train = os.path.join(dir_model, "train_metrics.csv")
    csv_writer = CSVWriter(
        file_name=file_logger_train,
-        column_names=["epoch", "loss_train", "loss_valid", "cer_valid", "wer_valid"]
    )
 
    file_params = os.path.join(dir_model, "params.json")
@@ -180,9 +198,15 @@ def train_hw_recognizer(FLAGS):
    num_classes = len(HWRecogIAMDataset.LABEL_2_CHAR) + 1
    print(f"task - handwriting recognition")
    print(f"model: {FLAGS.which_hw_model}")
-    print(f"optimizer: {FLAGS.which_optimizer}, learning rate: {FLAGS.learning_rate:.6f}, weight decay: {FLAGS.weight_decay:.8f}")
-    print(f"batch size: {FLAGS.batch_size}, image height: {FLAGS.image_height}, image width: {FLAGS.image_width}")
-    print(f"num train samples: {num_train_samples}, num validation samples: {num_valid_samples}\n")
 
    # load the right model
    if FLAGS.which_hw_model == "crnn":
@@ -196,9 +220,19 @@ def train_hw_recognizer(FLAGS):
 
    # load the right optimizer based on user option
    if FLAGS.which_optimizer == "adam":
-        optimizer = torch.optim.Adam(hw_model.parameters(), lr=FLAGS.learning_rate, weight_decay=FLAGS.weight_decay)
    elif FLAGS.which_optimizer == "adadelta":
-        optimizer = torch.optim.Adadelta(hw_model.parameters(), lr=FLAGS.learning_rate, rho=0.95, eps=1e-8, weight_decay=FLAGS.weight_decay)
    else:
        print(f"unidentified option: {FLAGS.which_optimizer}")
        sys.exit(0)
@@ -207,13 +241,19 @@ def train_hw_recognizer(FLAGS):
 
    # start training the model
    print(f"training of handwriting recognition model {FLAGS.which_hw_model} started\n")
-    for epoch in range(1, FLAGS.num_epochs+1):
        time_start = time.time()
        train_loss = train(hw_model, optimizer, criterion, train_loader, device)
-        valid_loss, valid_cer, valid_wer = validate(hw_model, criterion, valid_loader, device)
        time_end = time.time()
-        print(f"epoch: {epoch}/{FLAGS.num_epochs}, time: {time_end-time_start:.3f} sec.")
-        print(f"train loss: {train_loss:.6f}, validation loss: {valid_loss:.6f}, validation cer: {valid_cer:.4f}, validation wer: {valid_wer:.4f}\n")
 
        csv_writer.write_row(
            [
@@ -224,12 +264,21 @@ def train_hw_recognizer(FLAGS):
                round(valid_wer, 4),
            ]
        )
-        torch.save(hw_model.state_dict(), os.path.join(dir_model, f"{FLAGS.which_hw_model}_H_{FLAGS.image_height}_W_{FLAGS.image_width}_E_{epoch}.pth"))
-    print(f"Training of handwriting recognition model {FLAGS.which_hw_model} complete!!!!")
    # close the csv file
    csv_writer.close()
    return
 
def main():
    learning_rate = 1
    # 3e-4 for Adam, 1 for Adadelta
@@ -248,28 +297,67 @@ def main():
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
 
-    parser.add_argument("--learning_rate", default=learning_rate,
-        type=float, help="learning rate to use for training")
-    parser.add_argument("--weight_decay", default=weight_decay,
-        type=float, help="weight decay to use for training")
-    parser.add_argument("--batch_size", default=batch_size,
-        type=int, help="batch size to use for training")
-    parser.add_argument("--num_epochs", default=num_epochs,
-        type=int, help="num epochs to train the model")
-    parser.add_argument("--image_height", default=image_height,
-        type=int, help="image height to be used to train the model")
-    parser.add_argument("--image_width", default=image_width,
-        type=int, help="image width to be used to train the model")
-    parser.add_argument("--dir_dataset", default=dir_dataset,
-        type=str, help="full directory path to the dataset")
-    parser.add_argument("--which_optimizer", default=which_optimizer,
-        type=str, choices=["adadelta", "adam"], help="which optimizer to use to train")
-    parser.add_argument("--which_hw_model", default=which_hw_model,
-        type=str, choices=["crnn", "stn_crnn", "stn_pp_crnn"], help="which model to train")
 
    FLAGS, unparsed = parser.parse_known_args()
    train_hw_recognizer(FLAGS)
    return
 
if __name__ == "__main__":
    main()
 
55
  loss = criterion(log_probs, labels, lengths_preds, lengths_labels)
56
  train_running_loss += loss.item()
57
  loss.backward()
58
+ torch.nn.utils.clip_grad_norm_(
59
+ hw_model.parameters(), 5
60
+ ) # gradient clipping with 5
61
  optimizer.step()
62
 
63
  train_loss = train_running_loss / num_train_batches
64
  return train_loss
65
 
66
+
67
  def validate(hw_model, criterion, valid_loader, device):
68
  """
69
  ---------
 
117
  final_labels_for_eval = []
118
  length_label_counter = 0
119
  for pred_label, length_label in zip(pred_labels, lengths_labels_for_eval):
120
+ label = labels_for_eval[
121
+ length_label_counter : length_label_counter + length_label
122
+ ]
123
  length_label_counter += length_label
124
 
125
  final_labels_for_eval.append(label)
126
 
127
  for i in range(len(final_labels_for_eval)):
128
  if len(pred_labels[i]) != 0:
129
+ str_label = [
130
+ HWRecogIAMDataset.LABEL_2_CHAR[i]
131
+ for i in final_labels_for_eval[i]
132
+ ]
133
  str_label = "".join(str_label)
134
+ str_pred = [
135
+ HWRecogIAMDataset.LABEL_2_CHAR[i] for i in pred_labels[i]
136
+ ]
137
  str_pred = "".join(str_pred)
138
 
139
+ cer_sample, wer_sample = compute_wer_and_cer_for_sample(
140
+ str_pred, str_label
141
+ )
142
  else:
143
  cer_sample, wer_sample = 100, 100
144
 
 
150
  valid_wer = valid_running_wer / num_valid_samples
151
  return valid_loss, valid_cer, valid_wer
152
 
153
+
154
  def train_hw_recognizer(FLAGS):
155
  file_txt_labels = os.path.join(FLAGS.dir_dataset, "iam_lines_gt.txt")
156
  dir_images = os.path.join(FLAGS.dir_dataset, "img")
 
169
  num_valid_samples = len(valid_x)
170
  # get dataloaders for train and validation sets
171
  train_loader, valid_loader = get_dataloaders_for_training(
172
+ train_x,
173
+ train_y,
174
+ valid_x,
175
+ valid_y,
176
+ dir_images=dir_images,
177
+ image_height=FLAGS.image_height,
178
+ image_width=FLAGS.image_width,
179
  batch_size=FLAGS.batch_size,
180
  )
181
 
 
189
  file_logger_train = os.path.join(dir_model, "train_metrics.csv")
190
  csv_writer = CSVWriter(
191
  file_name=file_logger_train,
192
+ column_names=["epoch", "loss_train", "loss_valid", "cer_valid", "wer_valid"],
193
  )
194
 
195
  file_params = os.path.join(dir_model, "params.json")
 
198
  num_classes = len(HWRecogIAMDataset.LABEL_2_CHAR) + 1
199
  print(f"task - handwriting recognition")
200
  print(f"model: {FLAGS.which_hw_model}")
201
+ print(
202
+ f"optimizer: {FLAGS.which_optimizer}, learning rate: {FLAGS.learning_rate:.6f}, weight decay: {FLAGS.weight_decay:.8f}"
203
+ )
204
+ print(
205
+ f"batch size: {FLAGS.batch_size}, image height: {FLAGS.image_height}, image width: {FLAGS.image_width}"
206
+ )
207
+ print(
208
+ f"num train samples: {num_train_samples}, num validation samples: {num_valid_samples}\n"
209
+ )
210
 
211
  # load the right model
212
  if FLAGS.which_hw_model == "crnn":
 
220
 
221
  # load the right optimizer based on user option
222
  if FLAGS.which_optimizer == "adam":
223
+ optimizer = torch.optim.Adam(
224
+ hw_model.parameters(),
225
+ lr=FLAGS.learning_rate,
226
+ weight_decay=FLAGS.weight_decay,
227
+ )
228
  elif FLAGS.which_optimizer == "adadelta":
229
+ optimizer = torch.optim.Adadelta(
230
+ hw_model.parameters(),
231
+ lr=FLAGS.learning_rate,
232
+ rho=0.95,
233
+ eps=1e-8,
234
+ weight_decay=FLAGS.weight_decay,
235
+ )
236
  else:
237
  print(f"unidentified option: {FLAGS.which_optimizer}")
238
  sys.exit(0)
 
241
 
242
  # start training the model
243
  print(f"training of handwriting recognition model {FLAGS.which_hw_model} started\n")
244
+ for epoch in range(1, FLAGS.num_epochs + 1):
245
  time_start = time.time()
246
  train_loss = train(hw_model, optimizer, criterion, train_loader, device)
247
+ valid_loss, valid_cer, valid_wer = validate(
248
+ hw_model, criterion, valid_loader, device
249
+ )
250
  time_end = time.time()
251
+ print(
252
+ f"epoch: {epoch}/{FLAGS.num_epochs}, time: {time_end-time_start:.3f} sec."
253
+ )
254
+ print(
255
+ f"train loss: {train_loss:.6f}, validation loss: {valid_loss:.6f}, validation cer: {valid_cer:.4f}, validation wer: {valid_wer:.4f}\n"
256
+ )
257
 
258
  csv_writer.write_row(
259
  [
 
264
  round(valid_wer, 4),
265
  ]
266
  )
267
+ torch.save(
268
+ hw_model.state_dict(),
269
+ os.path.join(
270
+ dir_model,
271
+ f"{FLAGS.which_hw_model}_H_{FLAGS.image_height}_W_{FLAGS.image_width}_E_{epoch}.pth",
272
+ ),
273
+ )
274
+ print(
275
+ f"Training of handwriting recognition model {FLAGS.which_hw_model} complete!!!!"
276
+ )
277
  # close the csv file
278
  csv_writer.close()
279
  return
280
 
281
+
282
  def main():
283
  learning_rate = 1
284
  # 3e-4 for Adam, 1 for Adadelta
 
297
  formatter_class=argparse.ArgumentDefaultsHelpFormatter
298
  )
299
 
300
+ parser.add_argument(
301
+ "--learning_rate",
302
+ default=learning_rate,
303
+ type=float,
304
+ help="learning rate to use for training",
305
+ )
306
+ parser.add_argument(
307
+ "--weight_decay",
308
+ default=weight_decay,
309
+ type=float,
310
+ help="weight decay to use for training",
311
+ )
312
+ parser.add_argument(
313
+ "--batch_size",
314
+ default=batch_size,
315
+ type=int,
316
+ help="batch size to use for training",
317
+ )
318
+ parser.add_argument(
319
+ "--num_epochs",
320
+ default=num_epochs,
321
+ type=int,
322
+ help="num epochs to train the model",
323
+ )
324
+ parser.add_argument(
325
+ "--image_height",
326
+ default=image_height,
327
+ type=int,
328
+ help="image height to be used to train the model",
329
+ )
330
+ parser.add_argument(
331
+ "--image_width",
332
+ default=image_width,
333
+ type=int,
334
+ help="image width to be used to train the model",
335
+ )
336
+ parser.add_argument(
337
+ "--dir_dataset",
338
+ default=dir_dataset,
339
+ type=str,
340
+ help="full directory path to the dataset",
341
+ )
342
+ parser.add_argument(
343
+ "--which_optimizer",
344
+ default=which_optimizer,
345
+ type=str,
346
+ choices=["adadelta", "adam"],
347
+ help="which optimizer to use to train",
348
+ )
349
+ parser.add_argument(
350
+ "--which_hw_model",
351
+ default=which_hw_model,
352
+ type=str,
353
+ choices=["crnn", "stn_crnn", "stn_pp_crnn"],
354
+ help="which model to train",
355
+ )
356
 
357
  FLAGS, unparsed = parser.parse_known_args()
358
  train_hw_recognizer(FLAGS)
359
  return
360
 
361
+
362
  if __name__ == "__main__":
363
  main()
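For reference, a minimal sketch of launching the reformatted training script, assuming a local copy of the IAM line dataset; the dataset path and the hyper-parameter values below are illustrative placeholders, not part of this commit, and only the flag names come from the parser shown above:

    # illustrative sketch: drive iam_line_recognition/train.py through its CLI flags
    import subprocess

    subprocess.run(
        [
            "python", "iam_line_recognition/train.py",
            "--dir_dataset", "/path/to/IAM-data",  # must contain iam_lines_gt.txt and img/
            "--which_hw_model", "crnn",            # or stn_crnn / stn_pp_crnn
            "--which_optimizer", "adadelta",       # script defaults lr to 1 for Adadelta; ~3e-4 suits Adam
            "--image_height", "32",
            "--image_width", "768",
            "--batch_size", "64",                  # placeholder value
            "--num_epochs", "100",                 # placeholder value
        ],
        check=True,
    )

Per-epoch train/validation losses, CER and WER are written to train_metrics.csv alongside the per-epoch .pth checkpoints saved in the model directory.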
iam_line_recognition/utils.py CHANGED
@@ -13,6 +13,7 @@ from scipy.special import logsumexp
13
  NINF = -1 * float("inf")
14
  DEFAULT_EMISSION_THRESHOLD = 0.01
15
 
 
16
  def _reconstruct(labels, blank=0):
17
  new_labels = []
18
  # merge same labels
@@ -25,9 +26,12 @@ def _reconstruct(labels, blank=0):
25
  new_labels = [l for l in new_labels if l != blank]
26
  return new_labels
27
 
 
28
  def beam_search_decode(emission_log_prob, blank=0, **kwargs):
29
  beam_size = kwargs["beam_size"]
30
- emission_threshold = kwargs.get("emission_threshold", np.log(DEFAULT_EMISSION_THRESHOLD))
 
 
31
 
32
  length, class_count = emission_log_prob.shape
33
 
@@ -53,29 +57,38 @@ def beam_search_decode(emission_log_prob, blank=0, **kwargs):
53
  for prefix, accu_log_prob in beams:
54
  labels = tuple(_reconstruct(prefix, blank))
55
  # log(p1 + p2) = logsumexp([log_p1, log_p2])
56
- total_accu_log_prob[labels] = \
57
- logsumexp([accu_log_prob, total_accu_log_prob.get(labels, NINF)])
58
-
59
- labels_beams = [(list(labels), accu_log_prob)
60
- for labels, accu_log_prob in total_accu_log_prob.items()]
 
 
 
61
  labels_beams.sort(key=lambda x: x[1], reverse=True)
62
  labels = labels_beams[0][0]
63
 
64
  return labels
65
 
 
66
  def greedy_decode(emission_log_prob, blank=0):
67
  labels = np.argmax(emission_log_prob, axis=-1)
68
  labels = _reconstruct(labels, blank=blank)
69
  return labels
70
 
71
- def ctc_decode(log_probs, which_ctc_decoder="beam_search", label_2_char=None, blank=0, beam_size=25):
 
 
 
72
  emission_log_probs = np.transpose(log_probs.cpu().numpy(), (1, 0, 2))
73
  # size of emission_log_probs: (batch, length, class)
74
 
75
  decoded_list = []
76
  for emission_log_prob in emission_log_probs:
77
  if which_ctc_decoder == "beam_search":
78
- decoded = beam_search_decode(emission_log_prob, blank=blank, beam_size=beam_size)
 
 
79
  elif which_ctc_decoder == "greedy":
80
  decoded = greedy_decode(emission_log_prob, blank=blank)
81
  else:
@@ -87,16 +100,20 @@ def ctc_decode(log_probs, which_ctc_decoder="beam_search", label_2_char=None, bl
87
  decoded_list.append(decoded)
88
  return decoded_list
89
 
 
90
  """
91
  --------------------
92
  Evaluation Metrics
93
  --------------------
94
  """
 
 
95
  def compute_wer_and_cer_for_batch(batch_preds, batch_gts):
96
  cer_batch = fastwer.score(batch_preds, batch_gts, char_level=True)
97
  wer_batch = fastwer.score(batch_preds, batch_gts)
98
  return cer_batch, wer_batch
99
 
 
100
  def compute_wer_and_cer_for_sample(str_pred, str_gt):
101
  cer_sample = fastwer.score_sent(str_pred, str_gt, char_level=True)
102
  wer_sample = fastwer.score_sent(str_pred, str_gt)
 
13
  NINF = -1 * float("inf")
14
  DEFAULT_EMISSION_THRESHOLD = 0.01
15
 
16
+
17
  def _reconstruct(labels, blank=0):
18
  new_labels = []
19
  # merge same labels
 
26
  new_labels = [l for l in new_labels if l != blank]
27
  return new_labels
28
 
29
+
30
  def beam_search_decode(emission_log_prob, blank=0, **kwargs):
31
  beam_size = kwargs["beam_size"]
32
+ emission_threshold = kwargs.get(
33
+ "emission_threshold", np.log(DEFAULT_EMISSION_THRESHOLD)
34
+ )
35
 
36
  length, class_count = emission_log_prob.shape
37
 
 
57
  for prefix, accu_log_prob in beams:
58
  labels = tuple(_reconstruct(prefix, blank))
59
  # log(p1 + p2) = logsumexp([log_p1, log_p2])
60
+ total_accu_log_prob[labels] = logsumexp(
61
+ [accu_log_prob, total_accu_log_prob.get(labels, NINF)]
62
+ )
63
+
64
+ labels_beams = [
65
+ (list(labels), accu_log_prob)
66
+ for labels, accu_log_prob in total_accu_log_prob.items()
67
+ ]
68
  labels_beams.sort(key=lambda x: x[1], reverse=True)
69
  labels = labels_beams[0][0]
70
 
71
  return labels
72
 
73
+
74
  def greedy_decode(emission_log_prob, blank=0):
75
  labels = np.argmax(emission_log_prob, axis=-1)
76
  labels = _reconstruct(labels, blank=blank)
77
  return labels
78
 
79
+
80
+ def ctc_decode(
81
+ log_probs, which_ctc_decoder="beam_search", label_2_char=None, blank=0, beam_size=25
82
+ ):
83
  emission_log_probs = np.transpose(log_probs.cpu().numpy(), (1, 0, 2))
84
  # size of emission_log_probs: (batch, length, class)
85
 
86
  decoded_list = []
87
  for emission_log_prob in emission_log_probs:
88
  if which_ctc_decoder == "beam_search":
89
+ decoded = beam_search_decode(
90
+ emission_log_prob, blank=blank, beam_size=beam_size
91
+ )
92
  elif which_ctc_decoder == "greedy":
93
  decoded = greedy_decode(emission_log_prob, blank=blank)
94
  else:
 
100
  decoded_list.append(decoded)
101
  return decoded_list
102
 
103
+
104
  """
105
  --------------------
106
  Evaluation Metrics
107
  --------------------
108
  """
109
+
110
+
111
  def compute_wer_and_cer_for_batch(batch_preds, batch_gts):
112
  cer_batch = fastwer.score(batch_preds, batch_gts, char_level=True)
113
  wer_batch = fastwer.score(batch_preds, batch_gts)
114
  return cer_batch, wer_batch
115
 
116
+
117
  def compute_wer_and_cer_for_sample(str_pred, str_gt):
118
  cer_sample = fastwer.score_sent(str_pred, str_gt, char_level=True)
119
  wer_sample = fastwer.score_sent(str_pred, str_gt)
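For reference, a minimal sketch of the decoding utilities above applied to a dummy emission map, assuming it is run from inside iam_line_recognition/ so the flat imports resolve; the random emissions and sequence length are purely illustrative:

    # illustrative sketch: greedy CTC decoding of fake per-timestep log-probabilities
    import numpy as np

    from dataset import HWRecogIAMDataset  # flat imports, matching the repo's own style
    from utils import greedy_decode

    num_classes = len(HWRecogIAMDataset.LABEL_2_CHAR) + 1  # +1 for the CTC blank at index 0
    rng = np.random.default_rng(0)

    # fake emission map of shape (sequence_length, num_classes); each row is a valid log-prob vector
    emission_log_prob = np.log(rng.dirichlet(np.ones(num_classes), size=192))

    label_ids = greedy_decode(emission_log_prob, blank=0)  # merges repeats, drops blanks
    print("".join(HWRecogIAMDataset.LABEL_2_CHAR[i] for i in label_ids))

ctc_decode applies the same per-sample decoders to a whole batch of (length, batch, class) log-probabilities and can switch to beam_search_decode via which_ctc_decoder and beam_size.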
iam_line_recognition/utils_unique_chars.py CHANGED
@@ -3,6 +3,7 @@ import numpy as np
3
 
4
  from dataset import read_IAM_label_txt_file
5
 
 
6
  def list_unique_characters_in_IAM_dataset(FLAGS):
7
  _, all_labels = read_IAM_label_txt_file(FLAGS.file_txt_labels)
8
 
@@ -16,8 +17,8 @@ def list_unique_characters_in_IAM_dataset(FLAGS):
16
  unique_chars = sorted(unique_chars)
17
  unique_chars = np.array(unique_chars)
18
  unique_chars = np.unique(unique_chars)
19
- unique_chars = ''.join(unique_chars)
20
-
21
  # prints all unique chars in the IAM dataset
22
  print(unique_chars)
23
 
@@ -25,19 +26,27 @@ def list_unique_characters_in_IAM_dataset(FLAGS):
25
  print(f"Number of unique characters : {len(unique_chars)}")
26
  return
27
 
 
28
  def main():
29
- file_txt_labels = "/home/abhishek/Desktop/RUG/hw_recognition/IAM-data/iam_lines_gt.txt"
 
 
30
 
31
  parser = argparse.ArgumentParser(
32
  formatter_class=argparse.ArgumentDefaultsHelpFormatter
33
  )
34
 
35
- parser.add_argument("--file_txt_labels", default=file_txt_labels,
36
- type=str, help="full path to label text file")
 
 
 
 
37
 
38
  FLAGS, unparsed = parser.parse_known_args()
39
  list_unique_characters_in_IAM_dataset(FLAGS)
40
  return
41
 
 
42
  if __name__ == "__main__":
43
  main()
 
3
 
4
  from dataset import read_IAM_label_txt_file
5
 
6
+
7
  def list_unique_characters_in_IAM_dataset(FLAGS):
8
  _, all_labels = read_IAM_label_txt_file(FLAGS.file_txt_labels)
9
 
 
17
  unique_chars = sorted(unique_chars)
18
  unique_chars = np.array(unique_chars)
19
  unique_chars = np.unique(unique_chars)
20
+ unique_chars = "".join(unique_chars)
21
+
22
  # prints all unique chars in the IAM dataset
23
  print(unique_chars)
24
 
 
26
  print(f"Number of unique characters : {len(unique_chars)}")
27
  return
28
 
29
+
30
  def main():
31
+ file_txt_labels = (
32
+ "/home/abhishek/Desktop/RUG/hw_recognition/IAM-data/iam_lines_gt.txt"
33
+ )
34
 
35
  parser = argparse.ArgumentParser(
36
  formatter_class=argparse.ArgumentDefaultsHelpFormatter
37
  )
38
 
39
+ parser.add_argument(
40
+ "--file_txt_labels",
41
+ default=file_txt_labels,
42
+ type=str,
43
+ help="full path to label text file",
44
+ )
45
 
46
  FLAGS, unparsed = parser.parse_known_args()
47
  list_unique_characters_in_IAM_dataset(FLAGS)
48
  return
49
 
50
+
51
  if __name__ == "__main__":
52
  main()
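For reference, the unique characters this script prints are presumably what HWRecogIAMDataset.CHAR_SET was derived from; a minimal round-trip sketch, assuming it is run from inside iam_line_recognition/ (the sample line below is made up, not read from the dataset):

    # illustrative sketch: encode a line with CHAR_2_LABEL and map it back with LABEL_2_CHAR
    from dataset import HWRecogIAMDataset  # flat import, as in this script

    text = "A MOVE to stop Mr. Gaitskell"  # sample IAM-style line, not taken from iam_lines_gt.txt
    label_ids = [HWRecogIAMDataset.CHAR_2_LABEL[ch] for ch in text]  # labels start at 1
    assert "".join(HWRecogIAMDataset.LABEL_2_CHAR[i] for i in label_ids) == text
    print(f"{len(HWRecogIAMDataset.CHAR_SET)} characters in the set; label 0 is kept for the CTC blank")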