abhishekrs4 committed on
Commit
bd421ea
1 Parent(s): 2cabce6

added iam_line_recognition module

iam_line_recognition/__init__.py ADDED
@@ -0,0 +1,3 @@
+ import os, sys
+
+ sys.path.append(os.path.dirname(os.path.realpath(__file__)))
iam_line_recognition/dataset.py ADDED
@@ -0,0 +1,253 @@
+ import os
+ import torch
+ import torch.nn
+ import numpy as np
+ from PIL import Image
+ from skimage.io import imread
+ import torchvision.transforms as transforms
+ from torch.utils.data import Dataset, DataLoader
+ from sklearn.model_selection import train_test_split
+
+ def read_IAM_label_txt_file(file_txt_labels):
+     """
+     ---------
+     Arguments
+     ---------
+     file_txt_labels : str
+         full path to the text file containing labels
+
+     -------
+     Returns
+     -------
+     a tuple of
+     all_image_files : list
+         a list of all image file names
+     all_labels : list
+         a list of all labels
+     """
+     with open(file_txt_labels, mode="r") as label_file_handler:
+         all_lines = label_file_handler.readlines()
+     num_lines = len(all_lines)
+
+     all_image_files = []
+     all_labels = []
+
+     for cur_line_num in range(num_lines):
+         if cur_line_num % 3 == 0:
+             all_image_files.append(all_lines[cur_line_num].strip())
+         elif cur_line_num % 3 == 1:
+             all_labels.append(all_lines[cur_line_num].strip())
+         else:
+             continue
+
+     return all_image_files, all_labels
+
+ class HWRecogIAMDataset(Dataset):
+     """
+     Main dataset class to be used only for training, validation and internal testing
+     """
+     CHAR_SET = ' !"#&\'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
+     CHAR_2_LABEL = {char: i + 1 for i, char in enumerate(CHAR_SET)}
+     LABEL_2_CHAR = {label: char for char, label in CHAR_2_LABEL.items()}
+
+     def __init__(self, list_image_files, list_labels, dir_images, image_height=32, image_width=768, which_set="train"):
+         """
+         ---------
+         Arguments
+         ---------
+         list_image_files : list
+             list of image files
+         list_labels : list
+             list of labels
+         dir_images : str
+             full path to directory containing images
+         image_height : int
+             image height (default: 32)
+         image_width : int
+             image width (default: 768)
+         which_set : str
+             a string indicating which set is being used (default: train)
+         """
+         self.list_labels = list_labels
+         self.dir_images = dir_images
+         self.list_image_files = list_image_files
+         self.image_width = image_width
+         self.image_height = image_height
+         self.which_set = which_set
+
+         if self.which_set == "train":
+             # apply data augmentation only for the train set
+             self.transform = transforms.Compose([
+                 transforms.ToPILImage(),
+                 transforms.Resize((self.image_height, self.image_width), Image.BILINEAR),
+                 transforms.RandomAffine(degrees=[-0.75, 0.75], translate=[0, 0.05], scale=[0.75, 1],
+                     shear=[-35, 35], interpolation=transforms.InterpolationMode.BILINEAR, fill=255,
+                 ),
+                 transforms.ToTensor(),
+                 transforms.Normalize(
+                     mean=[0.485, 0.456, 0.406],
+                     std=[0.229, 0.224, 0.225],
+                 ),
+             ])
+         else:
+             self.transform = transforms.Compose([
+                 transforms.ToPILImage(),
+                 transforms.Resize((self.image_height, self.image_width), Image.BILINEAR),
+                 transforms.ToTensor(),
+                 transforms.Normalize(
+                     mean=[0.485, 0.456, 0.406],
+                     std=[0.229, 0.224, 0.225],
+                 ),
+             ])
+
+     def __len__(self):
+         return len(self.list_image_files)
+
+     def __getitem__(self, idx):
+         image_file_name = self.list_image_files[idx]
+         image_gray = imread(os.path.join(self.dir_images, image_file_name))
+         image_3_channel = np.repeat(np.expand_dims(image_gray, -1), 3, -1)
+         image_3_channel = self.transform(image_3_channel)
+
+         label_string = self.list_labels[idx]
+         label_encoded = [self.CHAR_2_LABEL[c] for c in label_string]
+         label_length = [len(label_encoded)]
+
+         label_encoded = torch.LongTensor(label_encoded)
+         label_length = torch.LongTensor(label_length)
+
+         return image_3_channel, label_encoded, label_length
+
+ def IAM_collate_fn(batch):
+     """
+     collate function
+
+     ---------
+     Arguments
+     ---------
+     batch : tuple
+         a batch of input data as a tuple
+
+     -------
+     Returns
+     -------
+     a collated tuple of
+     images : tensor
+         tensor of batch images
+     labels : tensor
+         tensor of batch labels
+     label_lengths : tensor
+         tensor of batch label lengths
+     """
+     images, labels, label_lengths = zip(*batch)
+     images = torch.stack(images, 0)
+     labels = torch.cat(labels, 0)
+     label_lengths = torch.cat(label_lengths, 0)
+     return images, labels, label_lengths
+
+ def split_dataset(file_txt_labels, for_train=True):
+     """
+     ---------
+     Arguments
+     ---------
+     file_txt_labels : str
+         full path to the text file containing labels
+     for_train : bool
+         indicating whether the split is for training or internal testing
+
+     -------
+     Returns
+     -------
+     a tuple of file lists and labels, for either training or internal testing
+     """
+     all_image_files, all_labels = read_IAM_label_txt_file(file_txt_labels)
+     train_image_files, test_image_files, train_labels, test_labels = train_test_split(all_image_files, all_labels, test_size=0.1, random_state=4)
+     train_image_files, valid_image_files, train_labels, valid_labels = train_test_split(train_image_files, train_labels, test_size=0.1, random_state=4)
+     if for_train:
+         return train_image_files, valid_image_files, train_labels, valid_labels
+     else:
+         return test_image_files, test_labels
+
+ def get_dataloaders_for_training(train_x, train_y, valid_x, valid_y, dir_images, image_height=32, image_width=768, batch_size=8):
+     """
+     ---------
+     Arguments
+     ---------
+     train_x : list
+         list of train file names
+     train_y : list
+         list of train labels
+     valid_x : list
+         list of validation file names
+     valid_y : list
+         list of validation labels
+     dir_images : str
+         full directory path containing the images
+     image_height : int
+         image height (default: 32)
+     image_width : int
+         image width (default: 768)
+     batch_size : int
+         batch size (default: 8)
+
+     -------
+     Returns
+     -------
+     a tuple of dataloader objects
+     train_loader : object
+         train set dataloader object
+     valid_loader : object
+         validation set dataloader object
+     """
+     train_dataset = HWRecogIAMDataset(train_x, train_y, dir_images, image_height=image_height, image_width=image_width, which_set="train")
+     valid_dataset = HWRecogIAMDataset(valid_x, valid_y, dir_images, image_height=image_height, image_width=image_width, which_set="valid")
+
+     train_loader = DataLoader(
+         train_dataset,
+         batch_size=batch_size,
+         shuffle=True,
+         num_workers=4,
+         collate_fn=IAM_collate_fn,
+     )
+     valid_loader = DataLoader(
+         valid_dataset,
+         batch_size=batch_size,
+         shuffle=False,
+         num_workers=4,
+         collate_fn=IAM_collate_fn,
+     )
+     return train_loader, valid_loader
+
+ def get_dataloader_for_testing(test_x, test_y, dir_images, image_height=32, image_width=768, batch_size=1):
+     """
+     ---------
+     Arguments
+     ---------
+     test_x : list
+         list of test file names
+     test_y : list
+         list of test labels
+     dir_images : str
+         full directory path containing the images
+     image_height : int
+         image height (default: 32)
+     image_width : int
+         image width (default: 768)
+     batch_size : int
+         batch size (default: 1)
+
+     -------
+     Returns
+     -------
+     test_loader : object
+         test set dataloader object
+     """
+     test_dataset = HWRecogIAMDataset(test_x, test_y, dir_images=dir_images, image_height=image_height, image_width=image_width, which_set="test")
+     test_loader = DataLoader(
+         test_dataset,
+         batch_size=batch_size,
+         shuffle=False,
+         num_workers=4,
+         collate_fn=IAM_collate_fn,
+     )
+     return test_loader
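A minimal sketch of how these helpers fit together, assuming a local IAM layout with the labels in IAM-data/iam_lines_gt.txt and the line images in IAM-data/img (both paths are illustrative, not part of this commit):

    import os

    from dataset import split_dataset, get_dataloaders_for_training, HWRecogIAMDataset

    dir_dataset = "IAM-data"  # hypothetical dataset root
    file_txt_labels = os.path.join(dir_dataset, "iam_lines_gt.txt")
    dir_images = os.path.join(dir_dataset, "img")

    # 90/10 train/test split, then 90/10 train/validation split (fixed random_state=4 above)
    train_x, valid_x, train_y, valid_y = split_dataset(file_txt_labels, for_train=True)
    train_loader, valid_loader = get_dataloaders_for_training(
        train_x, train_y, valid_x, valid_y,
        dir_images=dir_images, image_height=32, image_width=768, batch_size=8,
    )

    # one batch: images [B, 3, 32, 768], labels concatenated into a 1-D tensor, label_lengths [B]
    images, labels, label_lengths = next(iter(train_loader))
    print(images.shape, labels.shape, label_lengths.shape)
    print("num classes incl. CTC blank:", len(HWRecogIAMDataset.CHAR_SET) + 1)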
iam_line_recognition/final_iam_line_recognizer.py ADDED
@@ -0,0 +1,205 @@
+ import os
+ import sys
+ import time
+ import torch
+ import argparse
+ import torchvision
+ import numpy as np
+ import torch.nn as nn
+ from PIL import Image
+ from skimage.io import imread
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+ import torchvision.transforms as transforms
+
+ from dataset import HWRecogIAMDataset
+ from model_main import CRNN, STN_CRNN
+ from utils import ctc_decode, compute_wer_and_cer_for_sample
+
+
+ class DatasetFinalEval(HWRecogIAMDataset):
+     """
+     Dataset class for final evaluation - inherits main dataset class
+     """
+     def __init__(self, dir_images, image_height=32, image_width=768):
+         """
+         ---------
+         Arguments
+         ---------
+         dir_images : str
+             full path to directory containing images
+         image_height : int
+             image height (default: 32)
+         image_width : int
+             image width (default: 768)
+         """
+         self.dir_images = dir_images
+         self.image_files = [f for f in os.listdir(self.dir_images) if f.endswith(".png")]
+         self.image_width = image_width
+         self.image_height = image_height
+         self.transform = transforms.Compose([
+             transforms.ToPILImage(),
+             transforms.Resize((self.image_height, self.image_width), Image.BILINEAR),
+             transforms.ToTensor(),
+             transforms.Normalize(
+                 mean=[0.485, 0.456, 0.406],
+                 std=[0.229, 0.224, 0.225],
+             ),
+         ])
+
+     def __len__(self):
+         return len(self.image_files)
+
+     def __getitem__(self, idx):
+         image_file_name = self.image_files[idx]
+         image_gray = imread(os.path.join(self.dir_images, image_file_name))
+         image_3_channel = np.repeat(np.expand_dims(image_gray, -1), 3, -1)
+         image_3_channel = self.transform(image_3_channel)
+         return image_3_channel
+
+ def get_dataloader_for_evaluation(dir_images, image_height=32, image_width=768, batch_size=1):
+     """
+     ---------
+     Arguments
+     ---------
+     dir_images : str
+         full path to directory containing images
+     image_height : int
+         image height (default: 32)
+     image_width : int
+         image width (default: 768)
+     batch_size : int
+         batch size to use for final evaluation (default: 1)
+
+     -------
+     Returns
+     -------
+     test_loader : object
+         dataset loader object for final evaluation
+     """
+     test_dataset = DatasetFinalEval(dir_images=dir_images, image_height=image_height, image_width=image_width)
+     test_loader = DataLoader(
+         test_dataset,
+         batch_size=batch_size,
+         shuffle=False,
+         num_workers=4,
+     )
+     return test_loader
+
+ def final_eval(hw_model, device, test_loader, dir_images, dir_results):
+     """
+     ---------
+     Arguments
+     ---------
+     hw_model : object
+         handwriting recognition model object
+     device : str
+         device to be used for running the evaluation
+     test_loader : object
+         dataset loader object
+     dir_images : str
+         full path to directory containing test images
+     dir_results : str
+         relative path to directory to save the predictions as txt files
+     """
+     hw_model.eval()
+     count = 0
+     num_test_samples = len(test_loader.dataset)
+     list_test_files = os.listdir(dir_images)
+
+     if not os.path.isdir(dir_results):
+         print(f"creating directory: {dir_results}")
+         os.makedirs(dir_results)
+
+     with torch.no_grad():
+         for image_test in test_loader:
+             file_test = list_test_files[count]
+             count += 1
+             """
+             if count == 11:
+                 break
+             """
+             image_test = image_test.to(device, dtype=torch.float)
+
+             log_probs = hw_model(image_test)
+             pred_labels = ctc_decode(log_probs)
+             str_pred = [DatasetFinalEval.LABEL_2_CHAR[i] for i in pred_labels[0]]
+             str_pred = "".join(str_pred)
+
+             with open(os.path.join(dir_results, file_test+".txt"), "w", encoding="utf-8", newline="\n") as fh_pred:
+                 fh_pred.write(str_pred)
+
+             print(f"progress: {count}/{num_test_samples}, test file: {list_test_files[count-1]}")
+             print(f"{str_pred}\n")
+     print(f"predictions saved in directory: ./{dir_results}\n")
+     return
+
+ def test_hw_recognizer(FLAGS):
+     os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
+
+     num_classes = len(DatasetFinalEval.LABEL_2_CHAR) + 1
+     print(f"task - handwriting recognition")
+     print(f"model: {FLAGS.which_hw_model}")
+     print(f"image height: {FLAGS.image_height}, image width: {FLAGS.image_width}")
+
+     # load the right model
+     if FLAGS.which_hw_model == "crnn":
+         hw_model = CRNN(num_classes, FLAGS.image_height)
+     elif FLAGS.which_hw_model == "stn_crnn":
+         hw_model = STN_CRNN(num_classes, FLAGS.image_height, FLAGS.image_width)
+     else:
+         print(f"unidentified option : {FLAGS.which_hw_model}")
+         sys.exit(0)
+     dir_results = f"results_{FLAGS.which_hw_model}"
+
+     # choose a device for evaluation
+     if torch.cuda.is_available():
+         device = torch.device("cuda")
+     else:
+         device = torch.device("cpu")
+
+     hw_model.to(device)
+     hw_model.load_state_dict(torch.load(FLAGS.file_model))
+
+     # get test set dataloader
+     test_loader = get_dataloader_for_evaluation(
+         dir_images=FLAGS.dir_images, image_height=FLAGS.image_height, image_width=FLAGS.image_width,
+     )
+
+     # start the evaluation on the final test set
+     print(f"final evaluation of handwriting recognition model {FLAGS.which_hw_model} started\n")
+     final_eval(hw_model, device, test_loader, FLAGS.dir_images, dir_results)
+     print(f"final evaluation of handwriting recognition model completed!!!!")
+     return
+
+ def main():
+     image_height = 32
+     image_width = 768
+     which_hw_model = "crnn"
+     dir_images = "/home/abhishek/Desktop/RUG/hw_recognition/IAM-data/img/"
+     file_model = "model_crnn/crnn_H_32_W_768_E_177.pth"
+     save_predictions = 1
+
+     parser = argparse.ArgumentParser(
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter
+     )
+
+     parser.add_argument("--image_height", default=image_height,
+         type=int, help="image height to be used to predict with the model")
+     parser.add_argument("--image_width", default=image_width,
+         type=int, help="image width to be used to predict with the model")
+     parser.add_argument("--dir_images", default=dir_images,
+         type=str, help="full directory path to directory containing images")
+     parser.add_argument("--which_hw_model", default=which_hw_model,
+         type=str, choices=["crnn", "stn_crnn"], help="which model to be used for prediction")
+     parser.add_argument("--file_model", default=file_model,
+         type=str, help="full path to trained model file (.pth)")
+     parser.add_argument("--save_predictions", default=save_predictions,
+         type=int, choices=[0, 1], help="save or do not save the predictions (1 - save, 0 - do not save)")
+
+     FLAGS, unparsed = parser.parse_known_args()
+     test_hw_recognizer(FLAGS)
+     return
+
+ if __name__ == "__main__":
+     main()
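The same evaluation can also be driven programmatically instead of through the argparse entry point. A short sketch follows; the checkpoint path and image directory are placeholders, and map_location is added here only so the sketch also runs on a CPU-only machine:

    import torch

    from model_main import CRNN
    from final_iam_line_recognizer import DatasetFinalEval, get_dataloader_for_evaluation, final_eval

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = len(DatasetFinalEval.LABEL_2_CHAR) + 1  # 80 characters + 1 CTC blank

    hw_model = CRNN(num_classes, image_height=32)
    hw_model.to(device)
    hw_model.load_state_dict(torch.load("model_crnn/checkpoint.pth", map_location=device))  # placeholder path

    dir_images = "line_images/"  # placeholder: a folder of .png line images
    test_loader = get_dataloader_for_evaluation(dir_images=dir_images, image_height=32, image_width=768)
    final_eval(hw_model, device, test_loader, dir_images, dir_results="results_crnn")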
iam_line_recognition/logger_utils.py ADDED
@@ -0,0 +1,60 @@
+ import csv
+ import json
+
+ def write_json_file(file_json, dict_data):
+     """
+     ---------
+     Arguments
+     ---------
+     file_json : str
+         full path of json file to be saved
+     dict_data : dict
+         dictionary of params to be saved in the json file
+     """
+     with open(file_json, "w", encoding="utf-8") as fh:
+         fh.write(json.dumps(dict_data, indent=4))
+     return
+
+ class CSVWriter:
+     """
+     for writing tabular data to a csv file
+     """
+     def __init__(self, file_name, column_names):
+         """
+         ---------
+         Arguments
+         ---------
+         file_name : str
+             full path of csv file
+         column_names : list
+             a list of column names to be used to create the csv file
+         """
+         self.file_name = file_name
+         self.column_names = column_names
+
+         self.file_handle = open(self.file_name, "w", encoding="utf-8", newline="\n")
+         self.writer = csv.writer(self.file_handle)
+
+         self.write_header()
+         print(f"{self.file_name} created successfully with header row")
+
+     def write_header(self):
+         """
+         writes the header row into the csv file
+         """
+         self.write_row(self.column_names)
+         return
+
+     def write_row(self, row):
+         """
+         writes a row into the csv file
+         """
+         self.writer.writerow(row)
+         return
+
+     def close(self):
+         """
+         closes the file
+         """
+         self.file_handle.close()
+         return
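A short usage sketch of the two logging helpers above (file names are illustrative):

    from logger_utils import CSVWriter, write_json_file

    write_json_file("params.json", {"batch_size": 64, "image_height": 32, "image_width": 768})

    csv_writer = CSVWriter(file_name="train_metrics.csv",
                           column_names=["epoch", "loss_train", "loss_valid"])
    csv_writer.write_row([1, 3.52, 3.41])
    csv_writer.write_row([2, 2.87, 2.95])
    csv_writer.close()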
iam_line_recognition/model_main.py ADDED
@@ -0,0 +1,151 @@
+ import torchvision
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from model_visual_features import ResNetFeatureExtractor, TPS_SpatialTransformerNetwork
+
+ class HW_RNN_Seq2Seq(nn.Module):
+     """
+     Visual Seq2Seq model using BiLSTM
+     """
+     def __init__(self, num_classes, image_height, cnn_output_channels=512, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256):
+         """
+         ---------
+         Arguments
+         ---------
+         num_classes : int
+             num of distinct characters (classes) in the dataset
+         image_height : int
+             image height
+         cnn_output_channels : int
+             number of channels output from the CNN visual feature extractor (default: 512)
+         num_feats_mapped_seq_hidden : int
+             number of features to be used in the mapped visual features as sequences (default: 128)
+         num_feats_seq_hidden : int
+             number of features to be used in the LSTM for sequence modeling (default: 256)
+         """
+         super().__init__()
+         self.output_height = image_height // 32
+
+         self.dropout = nn.Dropout(p=0.25)
+         self.map_visual_to_seq = nn.Linear(cnn_output_channels * self.output_height, num_feats_mapped_seq_hidden)
+
+         self.b_lstm_1 = nn.LSTM(num_feats_mapped_seq_hidden, num_feats_seq_hidden, bidirectional=True)
+         self.b_lstm_2 = nn.LSTM(2 * num_feats_seq_hidden, num_feats_seq_hidden, bidirectional=True)
+
+         self.final_dense = nn.Linear(2 * num_feats_seq_hidden, num_classes)
+
+     def forward(self, visual_feats):
+         visual_feats = visual_feats.permute(3, 0, 1, 2)
+         # WBCH
+         # the sequence is along the width of the image as a sentence
+
+         visual_feats = visual_feats.contiguous().view(visual_feats.shape[0], visual_feats.shape[1], -1)
+         # WBC
+
+         seq = self.map_visual_to_seq(visual_feats)
+         seq = self.dropout(seq)
+         lstm_1, _ = self.b_lstm_1(seq)
+         lstm_2, _ = self.b_lstm_2(lstm_1)
+         lstm_2 = self.dropout(lstm_2)
+
+         dense_output = self.final_dense(lstm_2)
+         # [seq_len, B, num_classes]
+
+         log_probs = F.log_softmax(dense_output, dim=2)
+
+         return log_probs
+
+
+ class CRNN(nn.Module):
+     """
+     Hybrid CNN - RNN model
+     CNN - Modified ResNet34 for visual features
+     RNN - BiLSTM for seq2seq modeling
+     """
+     def __init__(self, num_classes, image_height, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256):
+         """
+         ---------
+         Arguments
+         ---------
+         num_classes : int
+             num of distinct characters (classes) in the dataset
+         image_height : int
+             image height
+         num_feats_mapped_seq_hidden : int
+             number of features to be used in the mapped visual features as sequences (default: 128)
+         num_feats_seq_hidden : int
+             number of features to be used in the LSTM for sequence modeling (default: 256)
+         """
+         super().__init__()
+         self.visual_feature_extractor = ResNetFeatureExtractor()
+         self.rnn_seq2seq_module = HW_RNN_Seq2Seq(num_classes, image_height, self.visual_feature_extractor.output_channels, num_feats_mapped_seq_hidden, num_feats_seq_hidden)
+
+     def forward(self, x):
+         visual_feats = self.visual_feature_extractor(x)
+         # [B, 512, H/32, W/32]
+
+         log_probs = self.rnn_seq2seq_module(visual_feats)
+         return log_probs
+
+
+ class STN_CRNN(nn.Module):
+     """
+     STN + CNN + RNN model
+     STN - Spatial Transformer Network for learning variable handwriting
+     CNN - Modified ResNet34 for visual features
+     RNN - BiLSTM for seq2seq modeling
+     """
+     def __init__(self, num_classes, image_height, image_width, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256):
+         """
+         ---------
+         Arguments
+         ---------
+         num_classes : int
+             num of distinct characters (classes) in the dataset
+         image_height : int
+             image height
+         image_width : int
+             image width
+         num_feats_mapped_seq_hidden : int
+             number of features to be used in the mapped visual features as sequences (default: 128)
+         num_feats_seq_hidden : int
+             number of features to be used in the LSTM for sequence modeling (default: 256)
+         """
+         super().__init__()
+         self.stn = TPS_SpatialTransformerNetwork(
+             80,
+             (image_height, image_width),
+             (image_height, image_width),
+             I_channel_num=3,
+         )
+         self.visual_feature_extractor = ResNetFeatureExtractor()
+         self.rnn_seq2seq_module = HW_RNN_Seq2Seq(num_classes, image_height, self.visual_feature_extractor.output_channels, num_feats_mapped_seq_hidden, num_feats_seq_hidden)
+
+     def forward(self, x):
+         stn_output = self.stn(x)
+         visual_feats = self.visual_feature_extractor(stn_output)
+         log_probs = self.rnn_seq2seq_module(visual_feats)
+         return log_probs
+
+ """
+ class STN_PP_CRNN(nn.Module):
+     def __init__(self, num_classes, image_height, image_width, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256):
+         super().__init__()
+         self.stn = TPS_SpatialTransformerNetwork(
+             20,
+             (image_height, image_width),
+             (image_height, image_width),
+             I_channel_num=3,
+         )
+         self.visual_feature_extractor = ResNetFeatureExtractor()
+         self.pp_block = PyramidPoolBlock(num_channels=self.visual_feature_extractor.output_channels)
+         self.rnn_seq2seq_module = HW_RNN_Seq2Seq(num_classes, image_height, self.visual_feature_extractor.output_channels, num_feats_mapped_seq_hidden, num_feats_seq_hidden)
+
+     def forward(self, x):
+         stn_output = self.stn(x)
+         visual_feats = self.visual_feature_extractor(stn_output)
+         pp_feats = self.pp_block(visual_feats)
+         log_probs = self.rnn_seq2seq_module(pp_feats)
+         return log_probs
+ """
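A quick shape check of the CRNN forward pass makes the [seq_len, B, num_classes] layout explicit, which is the layout nn.CTCLoss expects. With the stride pattern of the modified ResNet-34 in model_visual_features.py, a 32 x 768 input should yield seq_len = 192 (width reduced by a factor of 4); that number is inferred from the strides rather than stated in this commit. Constructing CRNN with the defaults also downloads the pretrained ResNet-34 weights on first use:

    import torch

    from model_main import CRNN

    num_classes = 81  # 80 characters in CHAR_SET + 1 CTC blank
    hw_model = CRNN(num_classes, image_height=32)
    hw_model.eval()

    with torch.no_grad():
        dummy_batch = torch.randn(2, 3, 32, 768)  # [B, C, H, W]
        log_probs = hw_model(dummy_batch)

    print(log_probs.shape)  # expected: torch.Size([192, 2, 81]) -> [seq_len, B, num_classes]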
iam_line_recognition/model_visual_features.py ADDED
@@ -0,0 +1,401 @@
+ import torch
+ import numpy as np
+ import torch.nn as nn
+ from typing import List
+ from torch import Tensor
+ import torch.nn.functional as F
+ from torchvision.models.resnet import BasicBlock, model_urls, load_state_dict_from_url, conv1x1, conv3x3
+
+ device = torch.device("cuda")
+
+ class CustomResNet(nn.Module):
+     def __init__(
+         self,
+         layers: List[int],
+         block=BasicBlock,
+         zero_init_residual=False,
+         groups=1,
+         num_classes=1000,
+         width_per_group=64,
+         replace_stride_with_dilation=None,
+         norm_layer=None,
+     ):
+
+         super().__init__()
+
+         # default to BatchNorm2d when no norm layer is passed in
+         if norm_layer is None:
+             norm_layer = nn.BatchNorm2d
+         self._norm_layer = norm_layer
+
+         self.inplanes = 64
+         self.dilation = 1
+
+         if replace_stride_with_dilation is None:
+             # each element in the tuple indicates if we should replace
+             # the 2x2 stride with a dilated convolution instead
+             replace_stride_with_dilation = [False, False, False]
+
+         if len(replace_stride_with_dilation) != 3:
+             raise ValueError(
+                 "replace_stride_with_dilation should be None "
+                 f"or a 3-element tuple, got {replace_stride_with_dilation}"
+             )
+
+         self.groups = groups
+         self.base_width = width_per_group
+
+         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
+         self.bn1 = self._norm_layer(self.inplanes)
+         self.relu = nn.ReLU(inplace=True)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=(2, 1), padding=1)
+         self.layer1 = self._make_layer(block, 64, layers[0])
+         self.layer2 = self._make_layer(block, 128, layers[1], stride=(2, 1), dilate=replace_stride_with_dilation[0])
+         self.layer3 = self._make_layer(block, 256, layers[2], stride=(2, 2), dilate=replace_stride_with_dilation[1])
+         self.layer4 = self._make_layer(block, 512, layers[3], stride=(2, 1), dilate=replace_stride_with_dilation[2])
+         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+         self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+
+         # Zero-initialize the last BN in each residual branch,
+         # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+         if zero_init_residual:
+             for m in self.modules():
+                 if isinstance(m, BasicBlock):
+                     nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]
+
+     def _make_layer(
+         self,
+         block,
+         planes,
+         blocks,
+         stride=1,
+         dilate=False,
+     ) -> nn.Sequential:
+         norm_layer = self._norm_layer
+         downsample = None
+         previous_dilation = self.dilation
+         if dilate:
+             self.dilation *= stride
+             stride = 1
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 conv1x1(self.inplanes, planes * block.expansion, stride),
+                 norm_layer(planes * block.expansion),
+             )
+
+         layers = []
+         layers.append(
+             block(
+                 self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
+             )
+         )
+         self.inplanes = planes * block.expansion
+         for _ in range(1, blocks):
+             layers.append(
+                 block(
+                     self.inplanes,
+                     planes,
+                     groups=self.groups,
+                     base_width=self.base_width,
+                     dilation=self.dilation,
+                     norm_layer=norm_layer,
+                 )
+             )
+
+         return nn.Sequential(*layers)
+
+     def _forward_impl(self, x: Tensor) -> Tensor:
+         # See note [TorchScript super()]
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+         x = self.maxpool(x)
+
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+         x = self.layer4(x)
+         return x
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self._forward_impl(x)
+
+ def _resnet(layers: List[int], pretrained=True) -> CustomResNet:
+     model = CustomResNet(layers)
+
+     if pretrained:
+         model.load_state_dict(load_state_dict_from_url(model_urls["resnet34"]))
+
+     return model
+
+ def resnet34(*, pretrained=True) -> CustomResNet:
+     """ResNet-34 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
+     Args:
+         pretrained (bool, optional): If True, loads the torchvision ImageNet weights
+             for ResNet-34 into this custom backbone. Default is True.
+     """
+
+     return _resnet([3, 4, 6, 3], pretrained=pretrained)
+
+
+ class ResNetFeatureExtractor(nn.Module):
+     """
+     Defines the base ResNet-34 feature extractor
+     """
+     def __init__(self, pretrained=True):
+         """
+         ---------
+         Arguments
+         ---------
+         pretrained : bool (default=True)
+             boolean to indicate whether to use a pretrained resnet model or not
+         """
+         super().__init__()
+         self.output_channels = 512
+         self.resnet34 = resnet34(pretrained=pretrained)
+
+     def forward(self, x):
+         block1 = self.resnet34.conv1(x)
+         block1 = self.resnet34.bn1(block1)
+         block1 = self.resnet34.relu(block1)  # [64, H/2, W/2]
+
+         block2 = self.resnet34.maxpool(block1)
+         block2 = self.resnet34.layer1(block2)  # [64, H/4, W/4]
+         block3 = self.resnet34.layer2(block2)  # [128, H/8, W/8]
+         block4 = self.resnet34.layer3(block3)  # [256, H/16, W/16]
+         resnet_features = self.resnet34.layer4(block4)  # [512, H/32, W/32]
+
+         # [B, 512, H/32, W/32]
+         return resnet_features
+
+
+ #########################################
+ ### STN - Spatial Transformer Network ###
+ #########################################
+ class TPS_SpatialTransformerNetwork(nn.Module):
+     """ Rectification Network of RARE, namely TPS based STN """
+
+     def __init__(self, num_fiducial_points, I_size, I_r_size, I_channel_num=1):
+         """ Based on RARE TPS
+         input:
+             batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width]
+             I_size : (height, width) of the input image I
+             I_r_size : (height, width) of the rectified image I_r
+             I_channel_num : the number of channels of the input image I
+         output:
+             batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width]
+         """
+         super(TPS_SpatialTransformerNetwork, self).__init__()
+         self.num_fiducial_points = num_fiducial_points
+         self.I_size = I_size
+         self.I_r_size = I_r_size  # = (I_r_height, I_r_width)
+         self.I_channel_num = I_channel_num
+         self.LocalizationNetwork = LocalizationNetwork(self.num_fiducial_points, self.I_channel_num)
+         self.GridGenerator = GridGenerator(self.num_fiducial_points, self.I_r_size)
+
+     def forward(self, batch_I):
+         batch_C_prime = self.LocalizationNetwork(batch_I)  # batch_size x K x 2
+         build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime)  # batch_size x n (= I_r_width x I_r_height) x 2
+         build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2])
+
+         if torch.__version__ > "1.2.0":
+             batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True)
+         else:
+             batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border')
+
+         return batch_I_r
+
+
+ class LocalizationNetwork(nn.Module):
+     """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """
+
+     def __init__(self, num_fiducial_points, I_channel_num):
+         super(LocalizationNetwork, self).__init__()
+         self.num_fiducial_points = num_fiducial_points
+         self.I_channel_num = I_channel_num
+         self.conv = nn.Sequential(
+             nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1,
+                 bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
+             nn.MaxPool2d(2, 2),  # batch_size x 64 x I_height/2 x I_width/2
+             nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True),
+             nn.MaxPool2d(2, 2),  # batch_size x 128 x I_height/4 x I_width/4
+             nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True),
+             nn.MaxPool2d(2, 2),  # batch_size x 256 x I_height/8 x I_width/8
+             nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True),
+             nn.AdaptiveAvgPool2d(1)  # batch_size x 512
+         )
+
+         self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True))
+         self.localization_fc2 = nn.Linear(256, self.num_fiducial_points * 2)
+
+         # Init fc2 in LocalizationNetwork
+         self.localization_fc2.weight.data.fill_(0)
+         """ see RARE paper Fig. 6 (a) """
+         ctrl_pts_x = np.linspace(-1.0, 1.0, int(num_fiducial_points / 2))
+         ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(num_fiducial_points / 2))
+         ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(num_fiducial_points / 2))
+         ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
+         ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
+         initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
+         self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1)
+
+     def forward(self, batch_I):
+         """
+         input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width]
+         output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2]
+         """
+         batch_size = batch_I.size(0)
+         features = self.conv(batch_I).view(batch_size, -1)
+         batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.num_fiducial_points, 2)
+         return batch_C_prime
+
+
+ class GridGenerator(nn.Module):
+     """ Grid Generator of RARE, which produces P_prime by multiplying T with P """
+
+     def __init__(self, num_fiducial_points, I_r_size):
+         """ Generate P_hat and inv_delta_C for later """
+         super(GridGenerator, self).__init__()
+         self.eps = 1e-6
+         self.I_r_height, self.I_r_width = I_r_size
+         self.num_fiducial_points = num_fiducial_points
+         self.C = self._build_C(self.num_fiducial_points)  # F x 2
+         self.P = self._build_P(self.I_r_width, self.I_r_height)
+         ## for multi-gpu, you need register buffer
+         self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.num_fiducial_points, self.C)).float())  # F+3 x F+3
+         self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.num_fiducial_points, self.C, self.P)).float())  # n x F+3
+         ## for fine-tuning with different image width, you may use below instead of self.register_buffer
+         #self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.num_fiducial_points, self.C)).float().cuda()  # F+3 x F+3
+         #self.P_hat = torch.tensor(self._build_P_hat(self.num_fiducial_points, self.C, self.P)).float().cuda()  # n x F+3
+
+     def _build_C(self, F):
+         """ Return coordinates of fiducial points in I_r; C """
+         ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
+         ctrl_pts_y_top = -1 * np.ones(int(F / 2))
+         ctrl_pts_y_bottom = np.ones(int(F / 2))
+         ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
+         ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
+         C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
+         return C  # F x 2
+
+     def _build_inv_delta_C(self, F, C):
+         """ Return inv_delta_C which is needed to calculate T """
+         hat_C = np.zeros((F, F), dtype=float)  # F x F
+         for i in range(0, F):
+             for j in range(i, F):
+                 r = np.linalg.norm(C[i] - C[j])
+                 hat_C[i, j] = r
+                 hat_C[j, i] = r
+         np.fill_diagonal(hat_C, 1)
+         hat_C = (hat_C ** 2) * np.log(hat_C)
+         # print(C.shape, hat_C.shape)
+         delta_C = np.concatenate(  # F+3 x F+3
+             [
+                 np.concatenate([np.ones((F, 1)), C, hat_C], axis=1),  # F x F+3
+                 np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1),  # 2 x F+3
+                 np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1)  # 1 x F+3
+             ],
+             axis=0
+         )
+         inv_delta_C = np.linalg.inv(delta_C)
+         return inv_delta_C  # F+3 x F+3
+
+     def _build_P(self, I_r_width, I_r_height):
+         I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width  # self.I_r_width
+         I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height  # self.I_r_height
+         P = np.stack(  # self.I_r_width x self.I_r_height x 2
+             np.meshgrid(I_r_grid_x, I_r_grid_y),
+             axis=2
+         )
+         return P.reshape([-1, 2])  # n (= self.I_r_width x self.I_r_height) x 2
+
+     def _build_P_hat(self, F, C, P):
+         n = P.shape[0]  # n (= self.I_r_width x self.I_r_height)
+         P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1))  # n x 2 -> n x 1 x 2 -> n x F x 2
+         C_tile = np.expand_dims(C, axis=0)  # 1 x F x 2
+         P_diff = P_tile - C_tile  # n x F x 2
+         rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False)  # n x F
+         rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps))  # n x F
+         P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1)
+         return P_hat  # n x F+3
+
+     def build_P_prime(self, batch_C_prime):
+         """ Generate Grid from batch_C_prime [batch_size x F x 2] """
+         batch_size = batch_C_prime.size(0)
+         batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1)
+         batch_P_hat = self.P_hat.repeat(batch_size, 1, 1)
+         batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros(
+             batch_size, 3, 2).float().to(device)), dim=1)  # batch_size x F+3 x 2
+         batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros)  # batch_size x F+3 x 2
+         batch_P_prime = torch.bmm(batch_P_hat, batch_T)  # batch_size x n x 2
+         return batch_P_prime  # batch_size x n x 2
+
+
+ """
+ ########################################
+ ######## Pyramid Pooling Block #########
+ ########################################
+ class PyramidPool(nn.Module):
+     def __init__(self, pool_kernel_size, in_channels, out_channels):
+         super().__init__()
+         self.pool_kernel_size = pool_kernel_size
+         self.avg_pool_block = nn.Sequential(
+             nn.AvgPool2d((1, self.pool_kernel_size), stride=(1, self.pool_kernel_size)),
+             nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding="same", bias=False),
+             nn.BatchNorm2d(out_channels),
+             nn.ELU(inplace=True),
+         )
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.xavier_normal_(m.weight)
+             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+
+     def forward(self, x):
+         _, _, in_height, in_width = x.size()
+         x = self.avg_pool_block(x)
+         x = F.interpolate(x, size=(in_height, in_width), mode="bilinear")
+         return x
+
+
+ class PyramidPoolBlock(nn.Module):
+     def __init__(self, pyramid_pool_kernel_sizes=[4, 8, 16, 32], num_channels=512):
+         super().__init__()
+         pp_out_channels = 256
+         self.pyramid_pool_layers = nn.ModuleList([PyramidPool(pool_kernel_size=k, in_channels=num_channels, out_channels=pp_out_channels) for k in pyramid_pool_kernel_sizes])
+         self.final_layer = nn.Sequential(
+             nn.Conv2d((num_channels + (pp_out_channels * len(self.pyramid_pool_layers))), num_channels, (1, 5), stride=1, padding="same"),
+             nn.BatchNorm2d(num_channels),
+             nn.ELU(inplace=True),
+             nn.Dropout(p=0.1),
+         )
+
+     def forward(self, input):
+         pp_outputs = []
+         for pp_layer in self.pyramid_pool_layers:
+             pp_output = pp_layer(input)
+             pp_outputs.append(pp_output)
+         pp_outputs.append(input)
+         x = torch.cat(pp_outputs, dim=1)
+         x = self.final_layer(x)
+         return x
+ """
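Two small points worth keeping in mind here. First, model_urls and load_state_dict_from_url are imported from torchvision.models.resnet, which only resolves on older torchvision releases (later versions removed those names). Second, the feature extractor collapses the height to image_height / 32 while reducing the width by a factor of 4, which can be checked with this minimal sketch (pretrained=False is used only to avoid a weight download):

    import torch

    from model_visual_features import ResNetFeatureExtractor

    extractor = ResNetFeatureExtractor(pretrained=False)
    extractor.eval()

    with torch.no_grad():
        feats = extractor(torch.randn(1, 3, 32, 768))

    print(feats.shape)  # expected: torch.Size([1, 512, 1, 192])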
iam_line_recognition/test_internal.py ADDED
@@ -0,0 +1,164 @@
+ import os
+ import sys
+ import time
+ import torch
+ import argparse
+ import torchvision
+ import numpy as np
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from logger_utils import CSVWriter
+ from model_main import CRNN, STN_CRNN
+ from utils import ctc_decode, compute_wer_and_cer_for_sample
+ from dataset import HWRecogIAMDataset, split_dataset, get_dataloader_for_testing
+
+
+ def test(hw_model, test_loader, device, list_test_files, which_ctc_decoder="beam_search", save_prediction_stats=False):
+     """
+     ---------
+     Arguments
+     ---------
+     hw_model : object
+         handwriting recognition model object
+     test_loader : object
+         dataset loader object
+     device : str
+         device to be used for running the evaluation
+     list_test_files : list
+         list of all the test files
+     which_ctc_decoder : str
+         string indicating which ctc decoder to use
+     save_prediction_stats : bool
+         whether to save prediction stats
+     """
+     hw_model.eval()
+     num_test_samples = len(test_loader.dataset)
+     num_test_batches = len(test_loader)
+
+     count = 0
+     list_test_cers, list_test_wers = [], []
+
+     if save_prediction_stats:
+         csv_writer = CSVWriter(
+             file_name="pred_stats.csv",
+             column_names=["file_name", "num_chars", "num_words", "cer", "wer"]
+         )
+
+     with torch.no_grad():
+         for images, labels, length_labels in test_loader:
+             count += 1
+             images = images.to(device, dtype=torch.float)
+             log_probs = hw_model(images)
+             pred_labels = ctc_decode(log_probs, which_ctc_decoder=which_ctc_decoder)
+             labels = labels.cpu().numpy().tolist()
+
+             str_label = [HWRecogIAMDataset.LABEL_2_CHAR[i] for i in labels]
+             str_label = "".join(str_label)
+             str_pred = [HWRecogIAMDataset.LABEL_2_CHAR[i] for i in pred_labels[0]]
+             str_pred = "".join(str_pred)
+
+             cer_sample, wer_sample = compute_wer_and_cer_for_sample(str_pred, str_label)
+             list_test_cers.append(cer_sample)
+             list_test_wers.append(wer_sample)
+
+             print(f"progress: {count}/{num_test_samples}, test file: {list_test_files[count-1]}")
+             print(f"{str_label} - label")
+             print(f"{str_pred} - prediction")
+             print(f"cer: {cer_sample:.3f}, wer: {wer_sample:.3f}\n")
+
+             if save_prediction_stats:
+                 csv_writer.write_row([
+                     list_test_files[count-1],
+                     len(str_label),
+                     len(str_label.split(" ")),
+                     cer_sample,
+                     wer_sample,
+                 ])
+     list_test_cers = np.array(list_test_cers)
+     list_test_wers = np.array(list_test_wers)
+     mean_test_cer = np.mean(list_test_cers)
+     mean_test_wer = np.mean(list_test_wers)
+     print(f"test set - mean cer: {mean_test_cer:.3f}, mean wer: {mean_test_wer:.3f}\n")
+
+     if save_prediction_stats:
+         csv_writer.close()
+     return
+
+ def test_hw_recognizer(FLAGS):
+     file_txt_labels = os.path.join(FLAGS.dir_dataset, "iam_lines_gt.txt")
+     dir_images = os.path.join(FLAGS.dir_dataset, "img")
+     os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
+
+     # choose a device for testing
+     if torch.cuda.is_available():
+         device = torch.device("cuda")
+     else:
+         device = torch.device("cpu")
+
+     # get the internal test set files
+     test_x, test_y = split_dataset(file_txt_labels, for_train=False)
+     num_test_samples = len(test_x)
+     # get the internal test set dataloader
+     test_loader = get_dataloader_for_testing(
+         test_x, test_y,
+         dir_images=dir_images, image_height=FLAGS.image_height, image_width=FLAGS.image_width,
+     )
+
+     num_classes = len(HWRecogIAMDataset.LABEL_2_CHAR) + 1
+     print(f"task - handwriting recognition")
+     print(f"model: {FLAGS.which_hw_model}, ctc decoder: {FLAGS.which_ctc_decoder}")
+     print(f"image height: {FLAGS.image_height}, image width: {FLAGS.image_width}")
+     print(f"num test samples: {num_test_samples}")
+
+     # load the right model
+     if FLAGS.which_hw_model == "crnn":
+         hw_model = CRNN(num_classes, FLAGS.image_height)
+     elif FLAGS.which_hw_model == "stn_crnn":
+         hw_model = STN_CRNN(num_classes, FLAGS.image_height, FLAGS.image_width)
+     else:
+         print(f"unidentified option : {FLAGS.which_hw_model}")
+         sys.exit(0)
+     hw_model.to(device)
+     hw_model.load_state_dict(torch.load(FLAGS.file_model))
+
+     # start testing of the model on the internal set
+     print(f"testing of handwriting recognition model {FLAGS.which_hw_model} started\n")
+     test(hw_model, test_loader, device, test_x, FLAGS.which_ctc_decoder, bool(FLAGS.save_prediction_stats))
+     print(f"testing handwriting recognition model completed!!!!")
+     return
+
+ def main():
+     image_height = 32
+     image_width = 768
+     which_hw_model = "crnn"
+     dir_dataset = "/home/abhishek/Desktop/RUG/hw_recognition/IAM-data/"
+     file_model = "model_crnn/crnn_H_32_W_768_E_177.pth"
+     which_ctc_decoder = "beam_search"
+     save_prediction_stats = 0
+
+     parser = argparse.ArgumentParser(
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter
+     )
+
+     parser.add_argument("--image_height", default=image_height,
+         type=int, help="image height to be used to predict with the model")
+     parser.add_argument("--image_width", default=image_width,
+         type=int, help="image width to be used to predict with the model")
+     parser.add_argument("--dir_dataset", default=dir_dataset,
+         type=str, help="full directory path to the dataset")
+     parser.add_argument("--which_hw_model", default=which_hw_model,
+         type=str, choices=["crnn", "stn_crnn"], help="which model to be used for prediction")
+     parser.add_argument("--which_ctc_decoder", default=which_ctc_decoder,
+         type=str, choices=["beam_search", "greedy"], help="which ctc decoder to use")
+     parser.add_argument("--file_model", default=file_model,
+         type=str, help="full path to trained model file (.pth)")
+     parser.add_argument("--save_prediction_stats", default=save_prediction_stats,
+         type=int, choices=[0, 1], help="save prediction stats (1 - yes, 0 - no)")
+
+     FLAGS, unparsed = parser.parse_known_args()
+     test_hw_recognizer(FLAGS)
+     return
+
+ if __name__ == "__main__":
+     main()
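Per-sample CER and WER come from compute_wer_and_cer_for_sample in utils.py, which is not fully visible in this commit view. Since utils.py imports fastwer, a stand-in with the same shape of interface would presumably look like the following (hypothetical helper, shown only to make the metrics concrete; fastwer returns percentages):

    import fastwer

    def compute_wer_and_cer_for_sample(str_pred, str_label):
        # char_level=True gives the character error rate, the default gives the word error rate
        cer = fastwer.score_sent(str_pred, str_label, char_level=True)
        wer = fastwer.score_sent(str_pred, str_label)
        return cer, wer

    print(compute_wer_and_cer_for_sample("a move to stop", "a move to stop mr gaitskell"))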
iam_line_recognition/train.py ADDED
@@ -0,0 +1,275 @@
1
+ import os
2
+ import sys
3
+ import time
4
+ import torch
5
+ import argparse
6
+ import torchvision
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from model_main import CRNN, STN_CRNN
11
+ from logger_utils import CSVWriter, write_json_file
12
+ from utils import compute_wer_and_cer_for_sample, ctc_decode
13
+ from dataset import HWRecogIAMDataset, split_dataset, get_dataloaders_for_training
14
+
15
+
16
+ def train(hw_model, optimizer, criterion, train_loader, device):
17
+ """
18
+ ---------
19
+ Arguments
20
+ ---------
21
+ hw_model : object
22
+ handwriting recognition model object
23
+ optimizer : object
24
+ optimizer object to be used for optimization
25
+ criterion : object
26
+ criterion or loss object to be used as the objective function for optimization
27
+ train_loader : object
28
+ train set dataloader object
29
+ device : str
30
+ device to be used for running the evaluation
31
+
32
+ -------
33
+ Returns
34
+ -------
35
+ train_loss : float
36
+ mean training loss for an epoch
37
+ """
38
+ hw_model.train()
39
+ train_running_loss = 0.0
40
+ num_train_samples = len(train_loader.dataset)
41
+ num_train_batches = len(train_loader)
42
+
43
+ for images, labels, lengths_labels in train_loader:
44
+ images = images.to(device, dtype=torch.float)
45
+ labels = labels.to(device, dtype=torch.long)
46
+ lengths_labels = lengths_labels.to(device, torch.long)
47
+
48
+ batch_size = images.size(0)
49
+ optimizer.zero_grad()
50
+ log_probs = hw_model(images)
51
+
52
+ lengths_preds = torch.LongTensor([log_probs.size(0)] * batch_size)
53
+ lengths_labels = torch.flatten(lengths_labels)
54
+
55
+ loss = criterion(log_probs, labels, lengths_preds, lengths_labels)
56
+ train_running_loss += loss.item()
57
+ loss.backward()
58
+ torch.nn.utils.clip_grad_norm_(hw_model.parameters(), 5) # gradient clipping with 5
59
+ optimizer.step()
60
+
61
+ train_loss = train_running_loss / num_train_batches
62
+ return train_loss
63
+
64
+ def validate(hw_model, criterion, valid_loader, device):
65
+ """
66
+ ---------
67
+ Arguments
68
+ ---------
69
+ hw_model : object
70
+ handwriting recognition model object
71
+ criterion : object
72
+ criterion or loss object to be used as the objective function for optimization
73
+ valid_loader : object
74
+ validation set dataloader object
75
+ device : str
76
+ device to be used for running the evaluation
77
+
78
+ -------
79
+ Returns
80
+ -------
81
+ a 3 tuple of
82
+ valid_loss : float
83
+ mean validation loss for an epoch
84
+ valid_cer : float
85
+ mean character error rate (CER) for validation set
86
+ valid_wer : float
87
+ mean word error rate (WER) for validation set
88
+ """
89
+ hw_model.eval()
90
+ valid_running_loss = 0.0
91
+ valid_running_cer = 0.0
92
+ valid_running_wer = 0.0
93
+ num_valid_samples = len(valid_loader.dataset)
94
+ num_valid_batches = len(valid_loader)
95
+
96
+ count = 0
97
+ with torch.no_grad():
98
+ for images, labels, lengths_labels in valid_loader:
99
+ images = images.to(device, dtype=torch.float)
100
+ labels = labels.to(device, dtype=torch.long)
101
+ lengths_labels = lengths_labels.to(device, torch.long)
102
+
103
+ batch_size = images.size(0)
104
+ log_probs = hw_model(images)
105
+ lengths_preds = torch.LongTensor([log_probs.size(0)] * batch_size)
106
+
107
+ loss = criterion(log_probs, labels, lengths_preds, lengths_labels)
108
+ valid_running_loss += loss.item()
109
+
110
+ pred_labels = ctc_decode(log_probs)
111
+ labels_for_eval = labels.cpu().numpy().tolist()
112
+ lengths_labels_for_eval = lengths_labels.cpu().numpy().tolist()
113
+
114
+ final_labels_for_eval = []
115
+ length_label_counter = 0
116
+ for pred_label, length_label in zip(pred_labels, lengths_labels_for_eval):
117
+ label = labels_for_eval[length_label_counter:length_label_counter+length_label]
118
+ length_label_counter += length_label
119
+
120
+ final_labels_for_eval.append(label)
121
+
122
+ for i in range(len(final_labels_for_eval)):
123
+ if len(pred_labels[i]) != 0:
124
+ str_label = [HWRecogIAMDataset.LABEL_2_CHAR[i] for i in final_labels_for_eval[i]]
125
+ str_label = "".join(str_label)
126
+ str_pred = [HWRecogIAMDataset.LABEL_2_CHAR[i] for i in pred_labels[i]]
127
+ str_pred = "".join(str_pred)
128
+
129
+ cer_sample, wer_sample = compute_wer_and_cer_for_sample(str_pred, str_label)
130
+ else:
131
+ cer_sample, wer_sample = 100, 100
132
+
133
+ valid_running_cer += cer_sample
+ valid_running_wer += wer_sample
+
+ valid_loss = valid_running_loss / num_valid_batches
+ valid_cer = valid_running_cer / num_valid_samples
+ valid_wer = valid_running_wer / num_valid_samples
+ return valid_loss, valid_cer, valid_wer
+
+ def train_hw_recognizer(FLAGS):
+ file_txt_labels = os.path.join(FLAGS.dir_dataset, "iam_lines_gt.txt")
+ dir_images = os.path.join(FLAGS.dir_dataset, "img")
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
+
+ # train only on a CUDA device (GPU)
+ if torch.cuda.is_available():
+ device = torch.device("cuda")
+ else:
+ print("CUDA device not found, so exiting....")
+ sys.exit(0)
+
+ # split dataset into train and validation sets
+ train_x, valid_x, train_y, valid_y = split_dataset(file_txt_labels, for_train=True)
+ num_train_samples = len(train_x)
+ num_valid_samples = len(valid_x)
+ # get dataloaders for train and validation sets
+ train_loader, valid_loader = get_dataloaders_for_training(
+ train_x, train_y, valid_x, valid_y,
+ dir_images=dir_images, image_height=FLAGS.image_height, image_width=FLAGS.image_width,
+ batch_size=FLAGS.batch_size,
+ )
+
+ # create a directory for saving the model
+ dir_model = f"model_{FLAGS.which_hw_model}"
+ if not os.path.isdir(dir_model):
+ print(f"creating directory: {dir_model}")
+ os.makedirs(dir_model)
+
+ # save train and validation metrics in a csv file
+ file_logger_train = os.path.join(dir_model, "train_metrics.csv")
+ csv_writer = CSVWriter(
+ file_name=file_logger_train,
+ column_names=["epoch", "loss_train", "loss_valid", "cer_valid", "wer_valid"]
+ )
+
+ file_params = os.path.join(dir_model, "params.json")
+ write_json_file(file_params, vars(FLAGS))
+
+ num_classes = len(HWRecogIAMDataset.LABEL_2_CHAR) + 1
+ print(f"task - handwriting recognition")
+ print(f"model: {FLAGS.which_hw_model}")
+ print(f"optimizer: {FLAGS.which_optimizer}, learning rate: {FLAGS.learning_rate:.6f}, weight decay: {FLAGS.weight_decay:.8f}")
+ print(f"batch size: {FLAGS.batch_size}, image height: {FLAGS.image_height}, image width: {FLAGS.image_width}")
+ print(f"num train samples: {num_train_samples}, num validation samples: {num_valid_samples}\n")
+
+ # load the right model
+ if FLAGS.which_hw_model == "crnn":
+ hw_model = CRNN(num_classes, FLAGS.image_height)
+ elif FLAGS.which_hw_model == "stn_crnn":
+ hw_model = STN_CRNN(num_classes, FLAGS.image_height, FLAGS.image_width)
+ else:
+ print(f"unidentified option: {FLAGS.which_hw_model}")
+ sys.exit(0)
+ hw_model.to(device)
+
+ # load the right optimizer based on user option
+ if FLAGS.which_optimizer == "adam":
+ optimizer = torch.optim.Adam(hw_model.parameters(), lr=FLAGS.learning_rate, weight_decay=FLAGS.weight_decay)
+ elif FLAGS.which_optimizer == "adadelta":
+ optimizer = torch.optim.Adadelta(hw_model.parameters(), lr=FLAGS.learning_rate, rho=0.95, eps=1e-8, weight_decay=FLAGS.weight_decay)
+ else:
+ print(f"unidentified option: {FLAGS.which_optimizer}")
+ sys.exit(0)
+ # use the CTC loss as the objective function for training
+ criterion = nn.CTCLoss(reduction="mean", zero_infinity=True)
+
+ # start training the model
+ print(f"training of handwriting recognition model {FLAGS.which_hw_model} started\n")
+ for epoch in range(1, FLAGS.num_epochs+1):
+ time_start = time.time()
+ train_loss = train(hw_model, optimizer, criterion, train_loader, device)
+ valid_loss, valid_cer, valid_wer = validate(hw_model, criterion, valid_loader, device)
+ time_end = time.time()
+ print(f"epoch: {epoch}/{FLAGS.num_epochs}, time: {time_end-time_start:.3f} sec.")
+ print(f"train loss: {train_loss:.6f}, validation loss: {valid_loss:.6f}, validation cer: {valid_cer:.4f}, validation wer: {valid_wer:.4f}\n")
+
+ csv_writer.write_row(
+ [
+ epoch,
+ round(train_loss, 6),
+ round(valid_loss, 6),
+ round(valid_cer, 4),
+ round(valid_wer, 4),
+ ]
+ )
+ torch.save(hw_model.state_dict(), os.path.join(dir_model, f"{FLAGS.which_hw_model}_H_{FLAGS.image_height}_W_{FLAGS.image_width}_E_{epoch}.pth"))
+ print(f"Training of handwriting recognition model {FLAGS.which_hw_model} complete!!!!")
+ # close the csv file
+ csv_writer.close()
+ return
+
+ def main():
+ learning_rate = 1
+ # 3e-4 for Adam, 1 for Adadelta
+ weight_decay = 0
+ # 3e-5 with Adam for both CRNN and STN-CRNN
+ # 0 with Adadelta for CRNN and STN-CRNN
+ batch_size = 64
+ num_epochs = 100
+ image_height = 32
+ image_width = 768
+ which_hw_model = "crnn"
+ which_optimizer = "adadelta"
+ dir_dataset = "/home/abhishek/Desktop/RUG/hw_recognition/IAM-data/"
+
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument("--learning_rate", default=learning_rate,
+ type=float, help="learning rate to use for training")
+ parser.add_argument("--weight_decay", default=weight_decay,
+ type=float, help="weight decay to use for training")
+ parser.add_argument("--batch_size", default=batch_size,
+ type=int, help="batch size to use for training")
+ parser.add_argument("--num_epochs", default=num_epochs,
+ type=int, help="num epochs to train the model")
+ parser.add_argument("--image_height", default=image_height,
+ type=int, help="image height to be used to train the model")
+ parser.add_argument("--image_width", default=image_width,
+ type=int, help="image width to be used to train the model")
+ parser.add_argument("--dir_dataset", default=dir_dataset,
+ type=str, help="full directory path to the dataset")
+ parser.add_argument("--which_optimizer", default=which_optimizer,
+ type=str, choices=["adadelta", "adam"], help="which optimizer to use to train")
+ parser.add_argument("--which_hw_model", default=which_hw_model,
+ type=str, choices=["crnn", "stn_crnn", "stn_pp_crnn"], help="which model to train")
+
+ FLAGS, unparsed = parser.parse_known_args()
+ train_hw_recognizer(FLAGS)
+ return
+
+ if __name__ == "__main__":
+ main()
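For context on the objective used above: nn.CTCLoss expects per-timestep log-probabilities of shape (T, N, C) together with explicit input and target lengths, and it reserves one class index (0 by default) for the CTC blank symbol, which is why num_classes is len(HWRecogIAMDataset.LABEL_2_CHAR) + 1. A minimal, self-contained sketch with hypothetical tensor sizes, not part of this commit:

import torch
import torch.nn as nn

# hypothetical sizes: T timesteps, N batch items, C classes (blank at index 0)
T, N, C = 192, 4, 80
log_probs = torch.randn(T, N, C).log_softmax(2)
# target labels start at 1 because index 0 is the CTC blank
targets = torch.randint(1, C, (N, 10), dtype=torch.long)
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), 10, dtype=torch.long)

criterion = nn.CTCLoss(reduction="mean", zero_infinity=True)
loss = criterion(log_probs, targets, input_lengths, target_lengths)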
iam_line_recognition/utils.py ADDED
@@ -0,0 +1,103 @@
+ import sys  # needed for sys.exit in ctc_decode
+ import torch
+ import fastwer
+ import numpy as np
+ from scipy.special import logsumexp
+
+
+ """
+ -------------
+ CTC decoder
+ -------------
+ """
+
+ NINF = -1 * float("inf")
+ DEFAULT_EMISSION_THRESHOLD = 0.01
+
+ def _reconstruct(labels, blank=0):
+ new_labels = []
+ # merge same labels
+ previous = None
+ for l in labels:
+ if l != previous:
+ new_labels.append(l)
+ previous = l
+ # delete blank
+ new_labels = [l for l in new_labels if l != blank]
+ return new_labels
+
+ def beam_search_decode(emission_log_prob, blank=0, **kwargs):
+ beam_size = kwargs["beam_size"]
+ emission_threshold = kwargs.get("emission_threshold", np.log(DEFAULT_EMISSION_THRESHOLD))
+
+ length, class_count = emission_log_prob.shape
+
+ beams = [([], 0)]  # (prefix, accumulated_log_prob)
+ for t in range(length):
+ new_beams = []
+ for prefix, accumulated_log_prob in beams:
+ for c in range(class_count):
+ log_prob = emission_log_prob[t, c]
+ if log_prob < emission_threshold:
+ continue
+ new_prefix = prefix + [c]
+ # log(p1 * p2) = log_p1 + log_p2
+ new_accu_log_prob = accumulated_log_prob + log_prob
+ new_beams.append((new_prefix, new_accu_log_prob))
+
+ # sorted by accumulated_log_prob
+ new_beams.sort(key=lambda x: x[1], reverse=True)
+ beams = new_beams[:beam_size]
+
+ # sum up beams to produce labels
+ total_accu_log_prob = {}
+ for prefix, accu_log_prob in beams:
+ labels = tuple(_reconstruct(prefix, blank))
+ # log(p1 + p2) = logsumexp([log_p1, log_p2])
+ total_accu_log_prob[labels] = \
+ logsumexp([accu_log_prob, total_accu_log_prob.get(labels, NINF)])
+
+ labels_beams = [(list(labels), accu_log_prob)
+ for labels, accu_log_prob in total_accu_log_prob.items()]
+ labels_beams.sort(key=lambda x: x[1], reverse=True)
+ labels = labels_beams[0][0]
+
+ return labels
+
+ def greedy_decode(emission_log_prob, blank=0):
+ labels = np.argmax(emission_log_prob, axis=-1)
+ labels = _reconstruct(labels, blank=blank)
+ return labels
+
+ def ctc_decode(log_probs, which_ctc_decoder="beam_search", label_2_char=None, blank=0, beam_size=25):
+ emission_log_probs = np.transpose(log_probs.cpu().numpy(), (1, 0, 2))
+ # size of emission_log_probs: (batch, length, class)
+
+ decoded_list = []
+ for emission_log_prob in emission_log_probs:
+ if which_ctc_decoder == "beam_search":
+ decoded = beam_search_decode(emission_log_prob, blank=blank, beam_size=beam_size)
+ elif which_ctc_decoder == "greedy":
+ decoded = greedy_decode(emission_log_prob, blank=blank)
+ else:
+ print(f"unidentified option for which_ctc_decoder : {which_ctc_decoder}")
+ sys.exit(0)
+
+ if label_2_char:
+ decoded = [label_2_char[l] for l in decoded]
+ decoded_list.append(decoded)
+ return decoded_list
+
+ """
+ --------------------
+ Evaluation Metrics
+ --------------------
+ """
+ def compute_wer_and_cer_for_batch(batch_preds, batch_gts):
+ cer_batch = fastwer.score(batch_preds, batch_gts, char_level=True)
+ wer_batch = fastwer.score(batch_preds, batch_gts)
+ return cer_batch, wer_batch
+
+ def compute_wer_and_cer_for_sample(str_pred, str_gt):
+ cer_sample = fastwer.score_sent(str_pred, str_gt, char_level=True)
+ wer_sample = fastwer.score_sent(str_pred, str_gt)
+ return cer_sample, wer_sample
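Greedy decoding takes the argmax class at every timestep and then collapses it with _reconstruct, which merges consecutive repeats and drops blanks (e.g. [1, 1, 0, 2, 2, 0, 2] becomes [1, 2, 2]); beam_search_decode instead keeps the beam_size most probable prefixes per timestep and merges prefixes that collapse to the same label sequence. A hedged usage sketch of ctc_decode together with the fastwer helpers, assuming this module is importable as utils and the label mapping comes from dataset.py; the ground-truth strings are made up for illustration:

import torch
from dataset import HWRecogIAMDataset
from utils import ctc_decode, compute_wer_and_cer_for_batch

# hypothetical recognizer output: (T, N, C) log-probabilities
T, N, C = 192, 2, len(HWRecogIAMDataset.LABEL_2_CHAR) + 1
log_probs = torch.randn(T, N, C).log_softmax(2)

# decode label indices and map them back to characters
decoded = ctc_decode(log_probs, which_ctc_decoder="greedy",
                     label_2_char=HWRecogIAMDataset.LABEL_2_CHAR)
predictions = ["".join(chars) for chars in decoded]

# fastwer reports CER/WER as percentages (0.0 means a perfect match)
ground_truths = ["a move to stop", "his Cabinet colleagues"]
cer, wer = compute_wer_and_cer_for_batch(predictions, ground_truths)
print(cer, wer)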
iam_line_recognition/utils_unique_chars.py ADDED
@@ -0,0 +1,43 @@
+ import argparse
+ import numpy as np
+
+ from dataset import read_IAM_label_txt_file
+
+ def list_unique_characters_in_IAM_dataset(FLAGS):
+ _, all_labels = read_IAM_label_txt_file(FLAGS.file_txt_labels)
+
+ num_labels = len(all_labels)
+ print(f"num labels : {num_labels}")
+ unique_chars = []
+
+ for label in all_labels:
+ unique_chars = unique_chars + list(np.unique(np.array(list(label))))
+
+ unique_chars = sorted(unique_chars)
+ unique_chars = np.array(unique_chars)
+ unique_chars = np.unique(unique_chars)
+ unique_chars = ''.join(unique_chars)
+
+ # prints all unique chars in the IAM dataset
+ print(unique_chars)
+
+ # prints the number of unique chars in the IAM dataset
+ print(f"Number of unique characters : {len(unique_chars)}")
+ return
+
+ def main():
+ file_txt_labels = "/home/abhishek/Desktop/RUG/hw_recognition/IAM-data/iam_lines_gt.txt"
+
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument("--file_txt_labels", default=file_txt_labels,
+ type=str, help="full path to label text file")
+
+ FLAGS, unparsed = parser.parse_known_args()
+ list_unique_characters_in_IAM_dataset(FLAGS)
+ return
+
+ if __name__ == "__main__":
+ main()
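The per-label numpy scan above works, but the same unique-character string can be collected in one pass with a plain Python set. A hedged equivalent sketch, not part of this commit, with a hypothetical label-file path; this is one way a character set such as HWRecogIAMDataset.CHAR_SET could be derived:

from dataset import read_IAM_label_txt_file

# hypothetical path to the IAM ground-truth label file
_, all_labels = read_IAM_label_txt_file("iam_lines_gt.txt")
# sorted() on a set of characters matches np.unique's lexicographic order
unique_chars = "".join(sorted(set("".join(all_labels))))
print(unique_chars)
print(f"Number of unique characters : {len(unique_chars)}")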