abhishekrs4 committed
Commit 44066b7
Parent(s): bd1dc81
code formatting
Browse files
- app.py +5 -7
- iam_line_recognition/dataset.py +94 -30
- iam_line_recognition/final_iam_line_recognizer.py +81 -28
- iam_line_recognition/logger_utils.py +3 -0
- iam_line_recognition/model_main.py +54 -9
- iam_line_recognition/model_visual_features.py +117 -43
- iam_line_recognition/test_internal.py +82 -27
- iam_line_recognition/train.py +125 -37
- iam_line_recognition/utils.py +25 -8
- iam_line_recognition/utils_unique_chars.py +14 -5
app.py
CHANGED
@@ -26,14 +26,10 @@ hw_recog_model = CRNN(num_classes, image_height)
 
 try:
     logging.info(f"loading model from {file_model_local}")
-    hw_recog_model.load_state_dict(
-        torch.load(file_model_local, map_location=device)
-    )
+    hw_recog_model.load_state_dict(torch.load(file_model_local, map_location=device))
 except:
     logging.info(f"loading model from {file_model_cont}")
-    hw_recog_model.load_state_dict(
-        torch.load(file_model_cont, map_location=device)
-    )
+    hw_recog_model.load_state_dict(torch.load(file_model_cont, map_location=device))
 hw_recog_model.to(device)
 hw_recog_model.eval()
 
@@ -51,6 +47,7 @@ def predict_hw(img_test: np.ndarray) -> str:
     str_pred = "".join(str_pred)
     return str_pred
 
+
 @app.route("/predict", methods=["POST"])
 def predict() -> Response:
     logging.info("IAM Handwriting recognition app")
@@ -62,7 +59,7 @@ def predict() -> Response:
         img_dec = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
         img_dec = cv2.cvtColor(img_dec, cv2.COLOR_BGR2RGB)
 
-        img_dec = cv2.resize(img_dec, (768, 32), interpolation
+        img_dec = cv2.resize(img_dec, (768, 32), interpolation=cv2.INTER_LINEAR)
 
         str_pred = predict_hw(img_dec)
 
@@ -77,5 +74,6 @@ def predict() -> Response:
         json_pred = jsonify({"error": str(e)})
         return json_pred
 
+
 if __name__ == "__main__":
     app.run(host="0.0.0.0", debug=True, port=7860)
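A note on the resize line above: OpenCV's cv2.resize takes its target size as (width, height), so (768, 32) produces a 32-pixel-tall, 768-pixel-wide line image, matching the model input used throughout this commit. A minimal sketch of the same preprocessing, with a hypothetical input path standing in for the bytes the endpoint receives:

import cv2

# hypothetical input path; the Flask endpoint decodes uploaded bytes instead
img = cv2.imread("line.png", cv2.IMREAD_COLOR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# cv2.resize takes (width, height): the result is a 32 x 768 RGB line image
img = cv2.resize(img, (768, 32), interpolation=cv2.INTER_LINEAR)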
iam_line_recognition/dataset.py
CHANGED
@@ -8,6 +8,7 @@ import torchvision.transforms as transforms
 from torch.utils.data import Dataset, DataLoader
 from sklearn.model_selection import train_test_split
 
+
 def read_IAM_label_txt_file(file_txt_labels):
     """
     ---------
@@ -42,15 +43,25 @@ def read_IAM_label_txt_file(file_txt_labels):
 
     return all_image_files, all_labels
 
+
 class HWRecogIAMDataset(Dataset):
     """
     Main dataset class to be used only for training, validation and internal testing
     """
-    CHAR_SET = " !\"#&'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+
+    CHAR_SET = " !\"#&'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
     CHAR_2_LABEL = {char: i + 1 for i, char in enumerate(CHAR_SET)}
     LABEL_2_CHAR = {label: char for char, label in CHAR_2_LABEL.items()}
 
-    def __init__(self, list_image_files, list_labels, dir_images, image_height=32, image_width=768, which_set="train"):
+    def __init__(
+        self,
+        list_image_files,
+        list_labels,
+        dir_images,
+        image_height=32,
+        image_width=768,
+        which_set="train",
+    ):
         """
         ---------
         Arguments
@@ -77,28 +88,41 @@ class HWRecogIAMDataset(Dataset):
 
         if self.which_set == "train":
             # apply data augmentation only for train set
-            self.transform = transforms.Compose(
+            self.transform = transforms.Compose(
+                [
+                    transforms.ToPILImage(),
+                    transforms.Resize(
+                        (self.image_height, self.image_width), Image.BILINEAR
+                    ),
+                    transforms.RandomAffine(
+                        degrees=[-0.75, 0.75],
+                        translate=[0, 0.05],
+                        scale=[0.75, 1],
+                        shear=[-35, 35],
+                        interpolation=transforms.InterpolationMode.BILINEAR,
+                        fill=255,
+                    ),
+                    transforms.ToTensor(),
+                    transforms.Normalize(
+                        mean=[0.485, 0.456, 0.406],
+                        std=[0.229, 0.224, 0.225],
+                    ),
+                ]
+            )
         else:
-            self.transform = transforms.Compose(
+            self.transform = transforms.Compose(
+                [
+                    transforms.ToPILImage(),
+                    transforms.Resize(
+                        (self.image_height, self.image_width), Image.BILINEAR
+                    ),
+                    transforms.ToTensor(),
+                    transforms.Normalize(
+                        mean=[0.485, 0.456, 0.406],
+                        std=[0.229, 0.224, 0.225],
+                    ),
+                ]
+            )
 
     def __len__(self):
         return len(self.list_image_files)
@@ -118,6 +142,7 @@ class HWRecogIAMDataset(Dataset):
 
         return image_3_channel, label_encoded, label_length
 
+
 def IAM_collate_fn(batch):
     """
     collate function
@@ -145,6 +170,7 @@ def IAM_collate_fn(batch):
     label_lengths = torch.cat(label_lengths, 0)
     return images, labels, label_lengths
 
+
 def split_dataset(file_txt_labels, for_train=True):
     """
     ---------
@@ -161,14 +187,28 @@ def split_dataset(file_txt_labels, for_train=True):
         a tuple of files depending for train or internal testing
     """
     all_image_files, all_labels = read_IAM_label_txt_file(file_txt_labels)
-    train_image_files, test_image_files, train_labels, test_labels = train_test_split(
+    train_image_files, test_image_files, train_labels, test_labels = train_test_split(
+        all_image_files, all_labels, test_size=0.1, random_state=4
+    )
+    train_image_files, valid_image_files, train_labels, valid_labels = train_test_split(
+        train_image_files, train_labels, test_size=0.1, random_state=4
+    )
     if for_train:
         return train_image_files, valid_image_files, train_labels, valid_labels
     else:
         return test_image_files, test_labels
 
-def get_dataloaders_for_training(train_x, train_y, valid_x, valid_y, dir_images, image_height=32, image_width=768, batch_size=8):
+
+def get_dataloaders_for_training(
+    train_x,
+    train_y,
+    valid_x,
+    valid_y,
+    dir_images,
+    image_height=32,
+    image_width=768,
+    batch_size=8,
+):
     """
     ---------
     Arguments
@@ -199,8 +239,22 @@ def get_dataloaders_for_training(train_x, train_y, valid_x, valid_y, dir_images,
     valid_loader : object
         object of validation set dataloader
     """
-    train_dataset = HWRecogIAMDataset(
+    train_dataset = HWRecogIAMDataset(
+        train_x,
+        train_y,
+        dir_images,
+        image_height=image_height,
+        image_width=image_width,
+        which_set="train",
+    )
+    valid_dataset = HWRecogIAMDataset(
+        valid_x,
+        valid_y,
+        dir_images,
+        image_height=image_height,
+        image_width=image_width,
+        which_set="valid",
+    )
 
     train_loader = DataLoader(
         train_dataset,
@@ -218,7 +272,10 @@ def get_dataloaders_for_training(train_x, train_y, valid_x, valid_y, dir_images,
     )
     return train_loader, valid_loader
 
-def get_dataloader_for_testing(test_x, test_y, dir_images, image_height=32, image_width=768, batch_size=1):
+
+def get_dataloader_for_testing(
+    test_x, test_y, dir_images, image_height=32, image_width=768, batch_size=1
+):
     """
     ---------
     Arguments
@@ -242,7 +299,14 @@ def get_dataloader_for_testing(test_x, test_y, dir_images, image_height=32, imag
     test_loader : object
         object of test set dataloader
     """
-    test_dataset = HWRecogIAMDataset(
+    test_dataset = HWRecogIAMDataset(
+        test_x,
+        test_y,
+        dir_images=dir_images,
+        image_height=image_height,
+        image_width=image_width,
+        which_set="test",
+    )
     test_loader = DataLoader(
         test_dataset,
         batch_size=batch_size,
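The CHAR_SET block moved into the class body above also makes the label convention easy to see: characters map to labels starting at 1 because label 0 is reserved for the CTC blank (which is why test.py and train.py set num_classes = len(LABEL_2_CHAR) + 1). A small runnable check of the round trip, with the string literal copied from the diff:

CHAR_SET = " !\"#&'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
CHAR_2_LABEL = {char: i + 1 for i, char in enumerate(CHAR_SET)}   # labels 1..79; 0 is the CTC blank
LABEL_2_CHAR = {label: char for char, label in CHAR_2_LABEL.items()}

encoded = [CHAR_2_LABEL[c] for c in "move to stop Mr."]
decoded = "".join(LABEL_2_CHAR[l] for l in encoded)
assert decoded == "move to stop Mr."   # encoding and decoding round-trip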
iam_line_recognition/final_iam_line_recognizer.py
CHANGED
@@ -21,6 +21,7 @@ class DatasetFinalEval(HWRecogIAMDataset):
     """
     Dataset class for final evaluation - inherits main dataset class
     """
+
     def __init__(self, dir_images, image_height=32, image_width=768):
         """
         ---------
@@ -34,18 +35,24 @@ class DatasetFinalEval(HWRecogIAMDataset):
             image width (default: 768)
         """
         self.dir_images = dir_images
-        self.image_files = [
+        self.image_files = [
+            f for f in os.listdir(self.dir_images) if f.endswith(".png")
+        ]
         self.image_width = image_width
         self.image_height = image_height
-        self.transform = transforms.Compose(
+        self.transform = transforms.Compose(
+            [
+                transforms.ToPILImage(),
+                transforms.Resize(
+                    (self.image_height, self.image_width), Image.BILINEAR
+                ),
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[0.485, 0.456, 0.406],
+                    std=[0.229, 0.224, 0.225],
+                ),
+            ]
+        )
 
     def __len__(self):
         return len(self.image_files)
@@ -57,7 +64,10 @@ class DatasetFinalEval(HWRecogIAMDataset):
         image_3_channel = self.transform(image_3_channel)
         return image_3_channel
 
-def get_dataloader_for_evaluation(dir_images, image_height=32, image_width=768, batch_size=1):
+
+def get_dataloader_for_evaluation(
+    dir_images, image_height=32, image_width=768, batch_size=1
+):
     """
     ---------
     Arguments
@@ -77,7 +87,9 @@ def get_dataloader_for_evaluation(dir_images, image_height=32, image_width=768,
     test_loader : object
         dataset loader object for final evaluation
     """
-    test_dataset = DatasetFinalEval(
+    test_dataset = DatasetFinalEval(
+        dir_images=dir_images, image_height=image_height, image_width=image_width
+    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
@@ -86,6 +98,7 @@ def get_dataloader_for_evaluation(dir_images, image_height=32, image_width=768, batch_size=1):
     )
     return test_loader
 
+
 def final_eval(hw_model, device, test_loader, dir_images, dir_results):
     """
     ---------
@@ -126,14 +139,22 @@ def final_eval(hw_model, device, test_loader, dir_images, dir_results):
         str_pred = [DatasetFinalEval.LABEL_2_CHAR[i] for i in pred_labels[0]]
         str_pred = "".join(str_pred)
 
-        with open(
+        with open(
+            os.path.join(dir_results, file_test + ".txt"),
+            "w",
+            encoding="utf-8",
+            newline="\n",
+        ) as fh_pred:
             fh_pred.write(str_pred)
 
-        print(
+        print(
+            f"progress: {count}/{num_test_samples}, test file: {list_test_files[count-1]}"
+        )
         print(f"{str_pred}\n")
     print(f"predictions saved in directory: ./{dir_results}\n")
     return
 
+
 def test_hw_recognizer(FLAGS):
     os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
 
@@ -163,15 +184,20 @@ def test_hw_recognizer(FLAGS):
 
     # get test set dataloader
     test_loader = get_dataloader_for_evaluation(
-        dir_images=FLAGS.dir_images,
+        dir_images=FLAGS.dir_images,
+        image_height=FLAGS.image_height,
+        image_width=FLAGS.image_width,
     )
 
     # start the evaluation on the final test set
-    print(
+    print(
+        f"final evaluation of handwriting recognition model {FLAGS.which_hw_model} started\n"
+    )
     final_eval(hw_model, device, test_loader, FLAGS.dir_images, dir_results)
     print(f"final evaluation of handwriting recognition model completed!!!!")
     return
 
+
 def main():
     image_height = 32
     image_width = 768
@@ -184,22 +210,49 @@ def main():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
 
+    parser.add_argument(
+        "--image_height",
+        default=image_height,
+        type=int,
+        help="image height to be used to predict with the model",
+    )
+    parser.add_argument(
+        "--image_width",
+        default=image_width,
+        type=int,
+        help="image width to be used to predict with the model",
+    )
+    parser.add_argument(
+        "--dir_images",
+        default=dir_images,
+        type=str,
+        help="full directory path to directory containing images",
+    )
+    parser.add_argument(
+        "--which_hw_model",
+        default=which_hw_model,
+        type=str,
+        choices=["crnn", "stn_crnn"],
+        help="which model to be used for prediction",
+    )
+    parser.add_argument(
+        "--file_model",
+        default=file_model,
+        type=str,
+        help="full path to trained model file (.pth)",
+    )
+    parser.add_argument(
+        "--save_predictions",
+        default=save_predictions,
+        type=int,
+        choices=[0, 1],
+        help="save or do not save the predictions (1 - save, 0 - do not save)",
+    )
 
     FLAGS, unparsed = parser.parse_known_args()
     test_hw_recognizer(FLAGS)
     return
 
+
 if __name__ == "__main__":
     main()
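final_eval turns decoded label indices back into text with LABEL_2_CHAR. The decoding itself comes from the repo's ctc_decode utility, which is not part of this diff; purely as an illustration, a standard greedy CTC collapse (skip blanks, merge consecutive repeats) looks like this:

import torch

def greedy_ctc_decode(log_probs: torch.Tensor, label_2_char: dict) -> str:
    # log_probs: (seq_len, num_classes) for a single sample; class 0 is the CTC blank
    best_path = log_probs.argmax(dim=1).tolist()
    chars, prev = [], 0
    for label in best_path:
        if label != 0 and label != prev:   # collapse repeats, skip blanks
            chars.append(label_2_char[label])
        prev = label
    return "".join(chars)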
iam_line_recognition/logger_utils.py
CHANGED
@@ -1,6 +1,7 @@
 import csv
 import json
 
+
 def write_json_file(file_json, dict_data):
     """
     ---------
@@ -15,10 +16,12 @@ def write_json_file(file_json, dict_data):
         fh.write(json.dumps(dict_data, indent=4))
     return
 
+
 class CSVWriter:
     """
     for writing tabular data to a csv file
     """
+
     def __init__(self, file_name, column_names):
         """
         ---------
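The changes here are pure PEP 8 blank-line formatting. For reference, this is how the CSVWriter interface is exercised elsewhere in this commit (the constructor keywords, write_row, and close calls are taken from train.py and test_internal.py; the import and the row values are illustrative):

from logger_utils import CSVWriter   # assuming the sibling-module import used in this repo

csv_writer = CSVWriter(
    file_name="train_metrics.csv",
    column_names=["epoch", "loss_train", "loss_valid", "cer_valid", "wer_valid"],
)
csv_writer.write_row([1, 2.3456, 2.1098, 0.2345, 0.4567])   # one row of sample numbers
csv_writer.close()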
iam_line_recognition/model_main.py
CHANGED
@@ -4,11 +4,20 @@ import torch.nn.functional as F
 
 from model_visual_features import ResNetFeatureExtractor, TPS_SpatialTransformerNetwork
 
+
 class HW_RNN_Seq2Seq(nn.Module):
     """
     Visual Seq2Seq model using BiLSTM
     """
-    def __init__(self, num_classes, image_height, cnn_output_channels=512, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256):
+
+    def __init__(
+        self,
+        num_classes,
+        image_height,
+        cnn_output_channels=512,
+        num_feats_mapped_seq_hidden=128,
+        num_feats_seq_hidden=256,
+    ):
         """
         ---------
         Arguments
@@ -28,10 +37,16 @@ class HW_RNN_Seq2Seq(nn.Module):
         self.output_height = image_height // 32
 
         self.dropout = nn.Dropout(p=0.25)
-        self.map_visual_to_seq = nn.Linear(
+        self.map_visual_to_seq = nn.Linear(
+            cnn_output_channels * self.output_height, num_feats_mapped_seq_hidden
+        )
 
-        self.b_lstm_1 = nn.LSTM(
+        self.b_lstm_1 = nn.LSTM(
+            num_feats_mapped_seq_hidden, num_feats_seq_hidden, bidirectional=True
+        )
+        self.b_lstm_2 = nn.LSTM(
+            2 * num_feats_seq_hidden, num_feats_seq_hidden, bidirectional=True
+        )
 
         self.final_dense = nn.Linear(2 * num_feats_seq_hidden, num_classes)
 
@@ -40,7 +55,9 @@ class HW_RNN_Seq2Seq(nn.Module):
         # WBCH
         # the sequence is along the width of the image as a sentence
 
-        visual_feats = visual_feats.contiguous().view(
+        visual_feats = visual_feats.contiguous().view(
+            visual_feats.shape[0], visual_feats.shape[1], -1
+        )
         # WBC
 
         seq = self.map_visual_to_seq(visual_feats)
@@ -63,7 +80,14 @@ class CRNN(nn.Module):
     CNN - Modified ResNet34 for visual features
     RNN - BiLSTM for seq2seq modeling
     """
-    def __init__(self, num_classes, image_height, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256):
+
+    def __init__(
+        self,
+        num_classes,
+        image_height,
+        num_feats_mapped_seq_hidden=128,
+        num_feats_seq_hidden=256,
+    ):
         """
         ---------
         Arguments
@@ -79,7 +103,13 @@ class CRNN(nn.Module):
         """
         super().__init__()
         self.visual_feature_extractor = ResNetFeatureExtractor()
-        self.rnn_seq2seq_module = HW_RNN_Seq2Seq(
+        self.rnn_seq2seq_module = HW_RNN_Seq2Seq(
+            num_classes,
+            image_height,
+            self.visual_feature_extractor.output_channels,
+            num_feats_mapped_seq_hidden,
+            num_feats_seq_hidden,
+        )
 
     def forward(self, x):
         visual_feats = self.visual_feature_extractor(x)
@@ -96,7 +126,15 @@ class STN_CRNN(nn.Module):
     CNN - Modified ResNet34 for visual features
     RNN - BiLSTM for seq2seq modeling
     """
-    def __init__(self, num_classes, image_height, image_width, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256):
+
+    def __init__(
+        self,
+        num_classes,
+        image_height,
+        image_width,
+        num_feats_mapped_seq_hidden=128,
+        num_feats_seq_hidden=256,
+    ):
         """
         ---------
         Arguments
@@ -120,7 +158,13 @@ class STN_CRNN(nn.Module):
             I_channel_num=3,
         )
         self.visual_feature_extractor = ResNetFeatureExtractor()
-        self.rnn_seq2seq_module = HW_RNN_Seq2Seq(
+        self.rnn_seq2seq_module = HW_RNN_Seq2Seq(
+            num_classes,
+            image_height,
+            self.visual_feature_extractor.output_channels,
+            num_feats_mapped_seq_hidden,
+            num_feats_seq_hidden,
+        )
 
     def forward(self, x):
         stn_output = self.stn(x)
@@ -128,6 +172,7 @@ class STN_CRNN(nn.Module):
         log_probs = self.rnn_seq2seq_module(visual_feats)
         return log_probs
 
+
 """
 class STN_PP_CRNN(nn.Module):
     def __init__(self, num_classes, image_height, image_width, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256):
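The reflowed __init__ above also exposes the shape bookkeeping: the CNN output is treated as a width-long sequence, with channels times the collapsed height as the per-step feature size (cnn_output_channels * self.output_height). A shape sketch under assumed sizes — a 32 x 768 input with the stride pattern in model_visual_features.py gives roughly height/32 and width/4:

import torch

# assumed sizes for illustration: batch 8, 512 channels, height 32 // 32 = 1,
# width 768 // 4 = 192 after the (2, 1)-strided stages in CustomResNet
W, B, C, H = 192, 8, 512, 1
visual_feats = torch.randn(W, B, C, H)          # WBCH: width is the sequence axis
flat = visual_feats.contiguous().view(
    visual_feats.shape[0], visual_feats.shape[1], -1
)                                               # WBC, feature size = C * H = 512
map_visual_to_seq = torch.nn.Linear(512 * 1, 128)
seq = map_visual_to_seq(flat)                   # (192, 8, 128), ready for the BiLSTM stack
print(seq.shape)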
iam_line_recognition/model_visual_features.py
CHANGED
@@ -4,10 +4,17 @@ import torch.nn as nn
 from typing import List
 from torch import Tensor
 import torch.nn.functional as F
-from torchvision.models.resnet import
+from torchvision.models.resnet import (
+    BasicBlock,
+    model_urls,
+    load_state_dict_from_url,
+    conv1x1,
+    conv3x3,
+)
 
 device = torch.device("cuda")
 
+
 class CustomResNet(nn.Module):
     def __init__(
         self,
@@ -43,14 +50,22 @@ class CustomResNet(nn.Module):
         self.groups = groups
         self.base_width = width_per_group
 
-        self.conv1 = nn.Conv2d(
+        self.conv1 = nn.Conv2d(
+            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
+        )
         self.bn1 = self._norm_layer(self.inplanes)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=(2, 1), padding=1)
         self.layer1 = self._make_layer(block, 64, layers[0])
-        self.layer2 = self._make_layer(
+        self.layer2 = self._make_layer(
+            block, 128, layers[1], stride=(2, 1), dilate=replace_stride_with_dilation[0]
+        )
+        self.layer3 = self._make_layer(
+            block, 256, layers[2], stride=(2, 2), dilate=replace_stride_with_dilation[1]
+        )
+        self.layer4 = self._make_layer(
+            block, 512, layers[3], stride=(2, 1), dilate=replace_stride_with_dilation[2]
+        )
         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
         self.fc = nn.Linear(512 * block.expansion, num_classes)
 
@@ -92,7 +107,14 @@ class CustomResNet(nn.Module):
         layers = []
         layers.append(
             block(
-                self.inplanes,
+                self.inplanes,
+                planes,
+                stride,
+                downsample,
+                self.groups,
+                self.base_width,
+                previous_dilation,
+                norm_layer,
             )
         )
         self.inplanes = planes * block.expansion
@@ -126,6 +148,7 @@ class CustomResNet(nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         return self._forward_impl(x)
 
+
 def _resnet(layers: List[int], pretrained=True) -> CustomResNet:
     model = CustomResNet(layers)
 
@@ -134,6 +157,7 @@ def _resnet(layers: List[int], pretrained=True) -> CustomResNet:
 
     return model
 
+
 def resnet34(*, pretrained=True) -> CustomResNet:
     """ResNet-34 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
     Args:
@@ -159,6 +183,7 @@ class ResNetFeatureExtractor(nn.Module):
     """
     Defines Base ResNet-34 feature extractor
     """
+
     def __init__(self, pretrained=True):
         """
         ---------
@@ -174,7 +199,7 @@ class ResNetFeatureExtractor(nn.Module):
     def forward(self, x):
         block1 = self.resnet34.conv1(x)
         block1 = self.resnet34.bn1(block1)
-        block1 = self.resnet34.relu(block1)
+        block1 = self.resnet34.relu(block1)  # [64, H/2, W/2]
 
         block2 = self.resnet34.maxpool(block1)
         block2 = self.resnet34.layer1(block2)  # [64, H/4, W/4]
@@ -190,10 +215,10 @@ class ResNetFeatureExtractor(nn.Module):
 ### STN - Spatial Transformer Network ###
 #########################################
 class TPS_SpatialTransformerNetwork(nn.Module):
-    """
+    """Rectification Network of RARE, namely TPS based STN"""
 
     def __init__(self, num_fiducial_points, I_size, I_r_size, I_channel_num=1):
-        """
+        """Based on RARE TPS
         input:
             batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width]
             I_size : (height, width) of the input image I
@@ -207,39 +232,66 @@ class TPS_SpatialTransformerNetwork(nn.Module):
         self.I_size = I_size
         self.I_r_size = I_r_size  # = (I_r_height, I_r_width)
         self.I_channel_num = I_channel_num
-        self.LocalizationNetwork = LocalizationNetwork(
+        self.LocalizationNetwork = LocalizationNetwork(
+            self.num_fiducial_points, self.I_channel_num
+        )
         self.GridGenerator = GridGenerator(self.num_fiducial_points, self.I_r_size)
 
     def forward(self, batch_I):
         batch_C_prime = self.LocalizationNetwork(batch_I)  # batch_size x K x 2
-        build_P_prime = self.GridGenerator.build_P_prime(
+        build_P_prime = self.GridGenerator.build_P_prime(
+            batch_C_prime
+        )  # batch_size x n (= I_r_width x I_r_height) x 2
+        build_P_prime_reshape = build_P_prime.reshape(
+            [build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]
+        )
 
         if torch.__version__ > "1.2.0":
-            batch_I_r = F.grid_sample(
+            batch_I_r = F.grid_sample(
+                batch_I,
+                build_P_prime_reshape,
+                padding_mode="border",
+                align_corners=True,
+            )
         else:
-            batch_I_r = F.grid_sample(
+            batch_I_r = F.grid_sample(
+                batch_I, build_P_prime_reshape, padding_mode="border"
+            )
 
         return batch_I_r
 
 
 class LocalizationNetwork(nn.Module):
-    """
+    """Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height)"""
 
     def __init__(self, num_fiducial_points, I_channel_num):
         super(LocalizationNetwork, self).__init__()
         self.num_fiducial_points = num_fiducial_points
         self.I_channel_num = I_channel_num
         self.conv = nn.Sequential(
-            nn.Conv2d(
+            nn.Conv2d(
+                in_channels=self.I_channel_num,
+                out_channels=64,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                bias=False,
+            ),
+            nn.BatchNorm2d(64),
+            nn.ReLU(True),
             nn.MaxPool2d(2, 2),  # batch_size x 64 x I_height/2 x I_width/2
-            nn.Conv2d(64, 128, 3, 1, 1, bias=False),
+            nn.Conv2d(64, 128, 3, 1, 1, bias=False),
+            nn.BatchNorm2d(128),
+            nn.ReLU(True),
             nn.MaxPool2d(2, 2),  # batch_size x 128 x I_height/4 x I_width/4
-            nn.Conv2d(128, 256, 3, 1, 1, bias=False),
+            nn.Conv2d(128, 256, 3, 1, 1, bias=False),
+            nn.BatchNorm2d(256),
+            nn.ReLU(True),
             nn.MaxPool2d(2, 2),  # batch_size x 256 x I_height/8 x I_width/8
-            nn.Conv2d(256, 512, 3, 1, 1, bias=False),
+            nn.Conv2d(256, 512, 3, 1, 1, bias=False),
+            nn.BatchNorm2d(512),
+            nn.ReLU(True),
+            nn.AdaptiveAvgPool2d(1),  # batch_size x 512
         )
 
         self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True))
@@ -254,7 +306,9 @@ class LocalizationNetwork(nn.Module):
         ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
         ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
         initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
-        self.localization_fc2.bias.data =
+        self.localization_fc2.bias.data = (
+            torch.from_numpy(initial_bias).float().view(-1)
+        )
 
     def forward(self, batch_I):
         """
@@ -263,15 +317,17 @@ class LocalizationNetwork(nn.Module):
         """
         batch_size = batch_I.size(0)
         features = self.conv(batch_I).view(batch_size, -1)
-        batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(
+        batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(
+            batch_size, self.num_fiducial_points, 2
+        )
         return batch_C_prime
 
 
 class GridGenerator(nn.Module):
-    """
+    """Grid Generator of RARE, which produces P_prime by multipling T with P"""
 
     def __init__(self, num_fiducial_points, I_r_size):
-        """
+        """Generate P_hat and inv_delta_C for later"""
         super(GridGenerator, self).__init__()
         self.eps = 1e-6
         self.I_r_height, self.I_r_width = I_r_size
@@ -279,14 +335,24 @@ class GridGenerator(nn.Module):
         self.C = self._build_C(self.num_fiducial_points)  # F x 2
         self.P = self._build_P(self.I_r_width, self.I_r_height)
         ## for multi-gpu, you need register buffer
-        self.register_buffer(
+        self.register_buffer(
+            "inv_delta_C",
+            torch.tensor(
+                self._build_inv_delta_C(self.num_fiducial_points, self.C)
+            ).float(),
+        )  # F+3 x F+3
+        self.register_buffer(
+            "P_hat",
+            torch.tensor(
+                self._build_P_hat(self.num_fiducial_points, self.C, self.P)
+            ).float(),
+        )  # n x F+3
         ## for fine-tuning with different image width, you may use below instead of self.register_buffer
-        #self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.num_fiducial_points, self.C)).float().cuda() # F+3 x F+3
-        #self.P_hat = torch.tensor(self._build_P_hat(self.num_fiducial_points, self.C, self.P)).float().cuda() # n x F+3
+        # self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.num_fiducial_points, self.C)).float().cuda() # F+3 x F+3
+        # self.P_hat = torch.tensor(self._build_P_hat(self.num_fiducial_points, self.C, self.P)).float().cuda() # n x F+3
 
     def _build_C(self, F):
-        """
+        """Return coordinates of fiducial points in I_r; C"""
         ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
         ctrl_pts_y_top = -1 * np.ones(int(F / 2))
         ctrl_pts_y_bottom = np.ones(int(F / 2))
@@ -296,7 +362,7 @@ class GridGenerator(nn.Module):
         return C  # F x 2
 
     def _build_inv_delta_C(self, F, C):
-        """
+        """Return inv_delta_C which is needed to calculate T"""
         hat_C = np.zeros((F, F), dtype=float)  # F x F
         for i in range(0, F):
             for j in range(i, F):
@@ -304,31 +370,36 @@ class GridGenerator(nn.Module):
                 hat_C[i, j] = r
                 hat_C[j, i] = r
         np.fill_diagonal(hat_C, 1)
-        hat_C = (hat_C
+        hat_C = (hat_C**2) * np.log(hat_C)
         # print(C.shape, hat_C.shape)
         delta_C = np.concatenate(  # F+3 x F+3
             [
                 np.concatenate([np.ones((F, 1)), C, hat_C], axis=1),  # F x F+3
                 np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1),  # 2 x F+3
-                np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1)  # 1 x F+3
+                np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1),  # 1 x F+3
             ],
-            axis=0
+            axis=0,
         )
         inv_delta_C = np.linalg.inv(delta_C)
         return inv_delta_C  # F+3 x F+3
 
     def _build_P(self, I_r_width, I_r_height):
-        I_r_grid_x = (
+        I_r_grid_x = (
+            np.arange(-I_r_width, I_r_width, 2) + 1.0
+        ) / I_r_width  # self.I_r_width
+        I_r_grid_y = (
+            np.arange(-I_r_height, I_r_height, 2) + 1.0
+        ) / I_r_height  # self.I_r_height
         P = np.stack(  # self.I_r_width x self.I_r_height x 2
-            np.meshgrid(I_r_grid_x, I_r_grid_y),
-            axis=2
+            np.meshgrid(I_r_grid_x, I_r_grid_y), axis=2
        )
        return P.reshape([-1, 2])  # n (= self.I_r_width x self.I_r_height) x 2
 
     def _build_P_hat(self, F, C, P):
         n = P.shape[0]  # n (= self.I_r_width x self.I_r_height)
-        P_tile = np.tile(
+        P_tile = np.tile(
+            np.expand_dims(P, axis=1), (1, F, 1)
+        )  # n x 2 -> n x 1 x 2 -> n x F x 2
         C_tile = np.expand_dims(C, axis=0)  # 1 x F x 2
         P_diff = P_tile - C_tile  # n x F x 2
         rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False)  # n x F
@@ -337,13 +408,16 @@ class GridGenerator(nn.Module):
         return P_hat  # n x F+3
 
     def build_P_prime(self, batch_C_prime):
-        """
+        """Generate Grid from batch_C_prime [batch_size x F x 2]"""
         batch_size = batch_C_prime.size(0)
         batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1)
         batch_P_hat = self.P_hat.repeat(batch_size, 1, 1)
-        batch_C_prime_with_zeros = torch.cat(
-            batch_size, 3, 2).float().to(device)), dim=1
+        batch_C_prime_with_zeros = torch.cat(
+            (batch_C_prime, torch.zeros(batch_size, 3, 2).float().to(device)), dim=1
+        )  # batch_size x F+3 x 2
+        batch_T = torch.bmm(
+            batch_inv_delta_C, batch_C_prime_with_zeros
+        )  # batch_size x F+3 x 2
         batch_P_prime = torch.bmm(batch_P_hat, batch_T)  # batch_size x n x 2
         return batch_P_prime  # batch_size x n x 2
 
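The TPS pieces above follow the RARE rectification design: _build_C lays out F fiducial points along the top and bottom edges of the rectified image, and _build_inv_delta_C uses the thin-plate-spline kernel r^2 log r (the hat_C line). A small runnable sketch of the fiducial layout for F = 6, mirroring the _build_C lines in this diff:

import numpy as np

F = 6
ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))   # [-1., 0., 1.]
ctrl_pts_y_top = -1 * np.ones(int(F / 2))         # top edge (y = -1)
ctrl_pts_y_bottom = np.ones(int(F / 2))           # bottom edge (y = +1)
ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
print(C)   # 6 x 2: three points along each horizontal edge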
iam_line_recognition/test_internal.py
CHANGED
@@ -14,7 +14,14 @@ from utils import ctc_decode, compute_wer_and_cer_for_sample
 from dataset import HWRecogIAMDataset, split_dataset, get_dataloader_for_testing
 
 
-def test(hw_model, test_loader, device, list_test_files, which_ctc_decoder="beam_search", save_prediction_stats=False):
+def test(
+    hw_model,
+    test_loader,
+    device,
+    list_test_files,
+    which_ctc_decoder="beam_search",
+    save_prediction_stats=False,
+):
     """
     ---------
     Arguments
@@ -42,7 +49,7 @@ def test(hw_model, test_loader, device, list_test_files, which_ctc_decoder="beam_search", save_prediction_stats=False):
     if save_prediction_stats:
         csv_writer = CSVWriter(
             file_name="pred_stats.csv",
-            column_names=["file_name", "num_chars", "num_words", "cer", "wer"]
+            column_names=["file_name", "num_chars", "num_words", "cer", "wer"],
         )
 
     with torch.no_grad():
@@ -62,19 +69,23 @@ def test(hw_model, test_loader, device, list_test_files, which_ctc_decoder="beam_search", save_prediction_stats=False):
             list_test_cers.append(cer_sample)
             list_test_wers.append(wer_sample)
 
-            print(
+            print(
+                f"progress: {count}/{num_test_samples}, test file: {list_test_files[count-1]}"
+            )
             print(f"{str_label} - label")
             print(f"{str_pred} - prediction")
             print(f"cer: {cer_sample:.3f}, wer: {wer_sample:.3f}\n")
 
             if save_prediction_stats:
-                csv_writer.write_row(
+                csv_writer.write_row(
+                    [
+                        list_test_files[count - 1],
+                        len(str_label),
+                        len(str_label.split(" ")),
+                        cer_sample,
+                        wer_sample,
+                    ]
+                )
     list_test_cers = np.array(list_test_cers)
     list_test_wers = np.array(list_test_wers)
     mean_test_cer = np.mean(list_test_cers)
@@ -85,6 +96,7 @@ def test(hw_model, test_loader, device, list_test_files, which_ctc_decoder="beam_search", save_prediction_stats=False):
         csv_writer.close()
     return
 
+
 def test_hw_recognizer(FLAGS):
     file_txt_labels = os.path.join(FLAGS.dir_dataset, "iam_lines_gt.txt")
     dir_images = os.path.join(FLAGS.dir_dataset, "img")
@@ -101,8 +113,11 @@ def test_hw_recognizer(FLAGS):
     num_test_samples = len(test_x)
     # get the internal test set dataloader
     test_loader = get_dataloader_for_testing(
-        test_x,
+        test_x,
+        test_y,
+        dir_images=dir_images,
+        image_height=FLAGS.image_height,
+        image_width=FLAGS.image_width,
     )
 
     num_classes = len(HWRecogIAMDataset.LABEL_2_CHAR) + 1
@@ -124,10 +139,18 @@ def test_hw_recognizer(FLAGS):
 
     # start testing of the model on the internal set
     print(f"testing of handwriting recognition model {FLAGS.which_hw_model} started\n")
-    test(
+    test(
+        hw_model,
+        test_loader,
+        device,
+        test_x,
+        FLAGS.which_ctc_decoder,
+        bool(FLAGS.save_prediction_stats),
+    )
     print(f"testing handwriting recognition model completed!!!!")
     return
 
+
 def main():
     image_height = 32
     image_width = 768
@@ -141,24 +164,56 @@ def main():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
 
+    parser.add_argument(
+        "--image_height",
+        default=image_height,
+        type=int,
+        help="image height to be used to predict with the model",
+    )
+    parser.add_argument(
+        "--image_width",
+        default=image_width,
+        type=int,
+        help="image width to be used to predict with the model",
+    )
+    parser.add_argument(
+        "--dir_dataset",
+        default=dir_dataset,
+        type=str,
+        help="full directory path to the dataset",
+    )
+    parser.add_argument(
+        "--which_hw_model",
+        default=which_hw_model,
+        type=str,
+        choices=["crnn", "stn_crnn"],
+        help="which model to be used for prediction",
+    )
+    parser.add_argument(
+        "--which_ctc_decoder",
+        default=which_ctc_decoder,
+        type=str,
+        choices=["beam_search", "greedy"],
+        help="which ctc decoder to use",
+    )
+    parser.add_argument(
+        "--file_model",
+        default=file_model,
+        type=str,
+        help="full path to trained model file (.pth)",
+    )
+    parser.add_argument(
+        "--save_prediction_stats",
+        default=save_prediction_stats,
+        type=int,
+        choices=[0, 1],
+        help="save prediction stats (1 - yes, 0 - no)",
+    )
 
     FLAGS, unparsed = parser.parse_known_args()
     test_hw_recognizer(FLAGS)
     return
 
+
 if __name__ == "__main__":
     main()
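test() reports per-sample CER and WER through compute_wer_and_cer_for_sample, which lives in utils.py and is not shown in this diff. As an illustration of the standard definitions only — edit distance over characters and over words; whether the repo scales to percent is an assumption here:

def _levenshtein(a, b):
    # single-row dynamic-programming edit distance; works on strings or lists
    dp = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        prev, dp[0] = dp[0], i
        for j, cb in enumerate(b, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (ca != cb))
    return dp[-1]

def cer(pred: str, label: str) -> float:
    return 100.0 * _levenshtein(pred, label) / max(len(label), 1)

def wer(pred: str, label: str) -> float:
    return 100.0 * _levenshtein(pred.split(), label.split()) / max(len(label.split()), 1)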
iam_line_recognition/train.py
CHANGED
@@ -55,12 +55,15 @@ def train(hw_model, optimizer, criterion, train_loader, device):
         loss = criterion(log_probs, labels, lengths_preds, lengths_labels)
         train_running_loss += loss.item()
         loss.backward()
-        torch.nn.utils.clip_grad_norm_(
+        torch.nn.utils.clip_grad_norm_(
+            hw_model.parameters(), 5
+        )  # gradient clipping with 5
         optimizer.step()
 
     train_loss = train_running_loss / num_train_batches
     return train_loss
 
+
 def validate(hw_model, criterion, valid_loader, device):
     """
     ---------
@@ -114,19 +117,28 @@ def validate(hw_model, criterion, valid_loader, device):
         final_labels_for_eval = []
         length_label_counter = 0
         for pred_label, length_label in zip(pred_labels, lengths_labels_for_eval):
-            label = labels_for_eval[
+            label = labels_for_eval[
+                length_label_counter : length_label_counter + length_label
+            ]
             length_label_counter += length_label
 
             final_labels_for_eval.append(label)
 
         for i in range(len(final_labels_for_eval)):
             if len(pred_labels[i]) != 0:
-                str_label = [
+                str_label = [
+                    HWRecogIAMDataset.LABEL_2_CHAR[i]
+                    for i in final_labels_for_eval[i]
+                ]
                 str_label = "".join(str_label)
-                str_pred = [
+                str_pred = [
+                    HWRecogIAMDataset.LABEL_2_CHAR[i] for i in pred_labels[i]
+                ]
                 str_pred = "".join(str_pred)
 
-                cer_sample, wer_sample = compute_wer_and_cer_for_sample(
+                cer_sample, wer_sample = compute_wer_and_cer_for_sample(
+                    str_pred, str_label
+                )
             else:
                 cer_sample, wer_sample = 100, 100
 
@@ -138,6 +150,7 @@ def validate(hw_model, criterion, valid_loader, device):
     valid_wer = valid_running_wer / num_valid_samples
     return valid_loss, valid_cer, valid_wer
 
+
 def train_hw_recognizer(FLAGS):
     file_txt_labels = os.path.join(FLAGS.dir_dataset, "iam_lines_gt.txt")
     dir_images = os.path.join(FLAGS.dir_dataset, "img")
@@ -156,8 +169,13 @@ def train_hw_recognizer(FLAGS):
     num_valid_samples = len(valid_x)
     # get dataloaders for train and validation sets
     train_loader, valid_loader = get_dataloaders_for_training(
-        train_x,
+        train_x,
+        train_y,
+        valid_x,
+        valid_y,
+        dir_images=dir_images,
+        image_height=FLAGS.image_height,
+        image_width=FLAGS.image_width,
         batch_size=FLAGS.batch_size,
     )
 
@@ -171,7 +189,7 @@ def train_hw_recognizer(FLAGS):
     file_logger_train = os.path.join(dir_model, "train_metrics.csv")
     csv_writer = CSVWriter(
         file_name=file_logger_train,
-        column_names=["epoch", "loss_train", "loss_valid", "cer_valid", "wer_valid"]
+        column_names=["epoch", "loss_train", "loss_valid", "cer_valid", "wer_valid"],
     )
 
     file_params = os.path.join(dir_model, "params.json")
@@ -180,9 +198,15 @@ def train_hw_recognizer(FLAGS):
     num_classes = len(HWRecogIAMDataset.LABEL_2_CHAR) + 1
     print(f"task - handwriting recognition")
     print(f"model: {FLAGS.which_hw_model}")
-    print(
+    print(
+        f"optimizer: {FLAGS.which_optimizer}, learning rate: {FLAGS.learning_rate:.6f}, weight decay: {FLAGS.weight_decay:.8f}"
+    )
+    print(
+        f"batch size: {FLAGS.batch_size}, image height: {FLAGS.image_height}, image width: {FLAGS.image_width}"
+    )
+    print(
+        f"num train samples: {num_train_samples}, num validation samples: {num_valid_samples}\n"
+    )
 
     # load the right model
     if FLAGS.which_hw_model == "crnn":
@@ -196,9 +220,19 @@ def train_hw_recognizer(FLAGS):
 
     # load the right optimizer based on user option
     if FLAGS.which_optimizer == "adam":
-        optimizer = torch.optim.Adam(
+        optimizer = torch.optim.Adam(
+            hw_model.parameters(),
+            lr=FLAGS.learning_rate,
+            weight_decay=FLAGS.weight_decay,
+        )
     elif FLAGS.which_optimizer == "adadelta":
-        optimizer = torch.optim.Adadelta(
+        optimizer = torch.optim.Adadelta(
+            hw_model.parameters(),
+            lr=FLAGS.learning_rate,
+            rho=0.95,
+            eps=1e-8,
+            weight_decay=FLAGS.weight_decay,
+        )
     else:
         print(f"unidentified option: {FLAGS.which_optimizer}")
         sys.exit(0)
@@ -207,13 +241,19 @@ def train_hw_recognizer(FLAGS):
 
     # start training the model
     print(f"training of handwriting recognition model {FLAGS.which_hw_model} started\n")
-    for epoch in range(1, FLAGS.num_epochs+1):
+    for epoch in range(1, FLAGS.num_epochs + 1):
         time_start = time.time()
         train_loss = train(hw_model, optimizer, criterion, train_loader, device)
-        valid_loss, valid_cer, valid_wer = validate(
+        valid_loss, valid_cer, valid_wer = validate(
+            hw_model, criterion, valid_loader, device
+        )
         time_end = time.time()
-        print(
+        print(
+            f"epoch: {epoch}/{FLAGS.num_epochs}, time: {time_end-time_start:.3f} sec."
+        )
+        print(
 
         csv_writer.write_row(
             [
@@ -224,12 +264,21 @@ def train_hw_recognizer(FLAGS):
                 round(valid_wer, 4),
             ]
         )
-        torch.save(
     # close the csv file
     csv_writer.close()
     return
 
+
 def main():
     learning_rate = 1
     # 3e-4 for Adam, 1 for Adadelta
@@ -248,28 +297,67 @@ def main():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
 
-    parser.add_argument(
-        type=float,
-    parser.add_argument(
-        type=
-    parser.add_argument(
-        type=
 
     FLAGS, unparsed = parser.parse_known_args()
     train_hw_recognizer(FLAGS)
     return
 
 if __name__ == "__main__":
     main()
|
255 |
+
f"train loss: {train_loss:.6f}, validation loss: {valid_loss:.6f}, validation cer: {valid_cer:.4f}, validation wer: {valid_wer:.4f}\n"
|
256 |
+
)
|
257 |
|
258 |
csv_writer.write_row(
|
259 |
[
|
|
|
264 |
round(valid_wer, 4),
|
265 |
]
|
266 |
)
|
267 |
+
torch.save(
|
268 |
+
hw_model.state_dict(),
|
269 |
+
os.path.join(
|
270 |
+
dir_model,
|
271 |
+
f"{FLAGS.which_hw_model}_H_{FLAGS.image_height}_W_{FLAGS.image_width}_E_{epoch}.pth",
|
272 |
+
),
|
273 |
+
)
|
274 |
+
print(
|
275 |
+
f"Training of handwriting recognition model {FLAGS.which_hw_model} complete!!!!"
|
276 |
+
)
|
277 |
# close the csv file
|
278 |
csv_writer.close()
|
279 |
return
|
280 |
|
281 |
+
|
282 |
def main():
|
283 |
learning_rate = 1
|
284 |
# 3e-4 for Adam, 1 for Adadelta
|
|
|
297 |
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
298 |
)
|
299 |
|
300 |
+
parser.add_argument(
|
301 |
+
"--learning_rate",
|
302 |
+
default=learning_rate,
|
303 |
+
type=float,
|
304 |
+
help="learning rate to use for training",
|
305 |
+
)
|
306 |
+
parser.add_argument(
|
307 |
+
"--weight_decay",
|
308 |
+
default=weight_decay,
|
309 |
+
type=float,
|
310 |
+
help="weight decay to use for training",
|
311 |
+
)
|
312 |
+
parser.add_argument(
|
313 |
+
"--batch_size",
|
314 |
+
default=batch_size,
|
315 |
+
type=int,
|
316 |
+
help="batch size to use for training",
|
317 |
+
)
|
318 |
+
parser.add_argument(
|
319 |
+
"--num_epochs",
|
320 |
+
default=num_epochs,
|
321 |
+
type=int,
|
322 |
+
help="num epochs to train the model",
|
323 |
+
)
|
324 |
+
parser.add_argument(
|
325 |
+
"--image_height",
|
326 |
+
default=image_height,
|
327 |
+
type=int,
|
328 |
+
help="image height to be used to train the model",
|
329 |
+
)
|
330 |
+
parser.add_argument(
|
331 |
+
"--image_width",
|
332 |
+
default=image_width,
|
333 |
+
type=int,
|
334 |
+
help="image width to be used to train the model",
|
335 |
+
)
|
336 |
+
parser.add_argument(
|
337 |
+
"--dir_dataset",
|
338 |
+
default=dir_dataset,
|
339 |
+
type=str,
|
340 |
+
help="full directory path to the dataset",
|
341 |
+
)
|
342 |
+
parser.add_argument(
|
343 |
+
"--which_optimizer",
|
344 |
+
default=which_optimizer,
|
345 |
+
type=str,
|
346 |
+
choices=["adadelta", "adam"],
|
347 |
+
help="which optimizer to use to train",
|
348 |
+
)
|
349 |
+
parser.add_argument(
|
350 |
+
"--which_hw_model",
|
351 |
+
default=which_hw_model,
|
352 |
+
type=str,
|
353 |
+
choices=["crnn", "stn_crnn", "stn_pp_crnn"],
|
354 |
+
help="which model to train",
|
355 |
+
)
|
356 |
|
357 |
FLAGS, unparsed = parser.parse_known_args()
|
358 |
train_hw_recognizer(FLAGS)
|
359 |
return
|
360 |
|
361 |
+
|
362 |
if __name__ == "__main__":
|
363 |
main()
|
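Note on the reformatted training step above: train() computes the CTC loss, backpropagates, clips the global gradient norm to 5, and only then calls optimizer.step(). A minimal self-contained sketch of that pattern, assuming a dummy stand-in model and batch (not the repository's CRNN or dataloaders):

import torch

# hypothetical stand-ins for the CRNN and one batch of line images
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(32 * 768, 80))
criterion = torch.nn.CTCLoss(blank=0)  # blank index 0, as in the CTC setup
optimizer = torch.optim.Adadelta(model.parameters(), lr=1, rho=0.95, eps=1e-8)

images = torch.randn(4, 1, 32, 768)        # (batch, channels, height, width)
targets = torch.randint(1, 80, (4, 20))    # dummy integer label sequences
lengths_preds = torch.full((4,), 50, dtype=torch.long)
lengths_labels = torch.full((4,), 20, dtype=torch.long)

optimizer.zero_grad()
logits = model(images)                     # (batch, num_classes)
# fake a (T, N, C) sequence of log-probs so CTCLoss accepts it
log_probs = torch.nn.functional.log_softmax(
    logits.unsqueeze(0).expand(50, -1, -1), dim=2
)
loss = criterion(log_probs, targets, lengths_preds, lengths_labels)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 5)  # gradient clipping with 5
optimizer.step()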
iam_line_recognition/utils.py
CHANGED
@@ -13,6 +13,7 @@ from scipy.special import logsumexp
|
|
13 |
NINF = -1 * float("inf")
|
14 |
DEFAULT_EMISSION_THRESHOLD = 0.01
|
15 |
|
|
|
16 |
def _reconstruct(labels, blank=0):
|
17 |
new_labels = []
|
18 |
# merge same labels
|
@@ -25,9 +26,12 @@ def _reconstruct(labels, blank=0):
|
|
25 |
new_labels = [l for l in new_labels if l != blank]
|
26 |
return new_labels
|
27 |
|
|
|
28 |
def beam_search_decode(emission_log_prob, blank=0, **kwargs):
|
29 |
beam_size = kwargs["beam_size"]
|
30 |
-
emission_threshold = kwargs.get("emission_threshold", np.log(DEFAULT_EMISSION_THRESHOLD))
|
31 |
|
32 |
length, class_count = emission_log_prob.shape
|
33 |
|
@@ -53,29 +57,38 @@ def beam_search_decode(emission_log_prob, blank=0, **kwargs):
|
|
53 |
for prefix, accu_log_prob in beams:
|
54 |
labels = tuple(_reconstruct(prefix, blank))
|
55 |
# log(p1 + p2) = logsumexp([log_p1, log_p2])
|
56 |
-
total_accu_log_prob[labels] = logsumexp(
|
57 |
-
[accu_log_prob, total_accu_log_prob.get(labels, NINF)])
|
58 |
-
|
59 |
-
labels_beams = [(list(labels), accu_log_prob)
|
60 |
-
for labels, accu_log_prob in total_accu_log_prob.items()]
|
61 |
labels_beams.sort(key=lambda x: x[1], reverse=True)
|
62 |
labels = labels_beams[0][0]
|
63 |
|
64 |
return labels
|
65 |
|
|
|
66 |
def greedy_decode(emission_log_prob, blank=0):
|
67 |
labels = np.argmax(emission_log_prob, axis=-1)
|
68 |
labels = _reconstruct(labels, blank=blank)
|
69 |
return labels
|
70 |
|
71 |
-
def ctc_decode(log_probs, which_ctc_decoder="beam_search", label_2_char=None, blank=0, beam_size=25):
|
72 |
emission_log_probs = np.transpose(log_probs.cpu().numpy(), (1, 0, 2))
|
73 |
# size of emission_log_probs: (batch, length, class)
|
74 |
|
75 |
decoded_list = []
|
76 |
for emission_log_prob in emission_log_probs:
|
77 |
if which_ctc_decoder == "beam_search":
|
78 |
-
decoded = beam_search_decode(emission_log_prob, blank=blank, beam_size=beam_size)
|
79 |
elif which_ctc_decoder == "greedy":
|
80 |
decoded = greedy_decode(emission_log_prob, blank=blank)
|
81 |
else:
|
@@ -87,16 +100,20 @@ def ctc_decode(log_probs, which_ctc_decoder="beam_search", label_2_char=None, bl
|
|
87 |
decoded_list.append(decoded)
|
88 |
return decoded_list
|
89 |
|
|
|
90 |
"""
|
91 |
--------------------
|
92 |
Evaluation Metrics
|
93 |
--------------------
|
94 |
"""
|
|
|
|
|
95 |
def compute_wer_and_cer_for_batch(batch_preds, batch_gts):
|
96 |
cer_batch = fastwer.score(batch_preds, batch_gts, char_level=True)
|
97 |
wer_batch = fastwer.score(batch_preds, batch_gts)
|
98 |
return cer_batch, wer_batch
|
99 |
|
|
|
100 |
def compute_wer_and_cer_for_sample(str_pred, str_gt):
|
101 |
cer_sample = fastwer.score_sent(str_pred, str_gt, char_level=True)
|
102 |
wer_sample = fastwer.score_sent(str_pred, str_gt)
|
|
|
13 |
NINF = -1 * float("inf")
|
14 |
DEFAULT_EMISSION_THRESHOLD = 0.01
|
15 |
|
16 |
+
|
17 |
def _reconstruct(labels, blank=0):
|
18 |
new_labels = []
|
19 |
# merge same labels
|
|
|
26 |
new_labels = [l for l in new_labels if l != blank]
|
27 |
return new_labels
|
28 |
|
29 |
+
|
30 |
def beam_search_decode(emission_log_prob, blank=0, **kwargs):
|
31 |
beam_size = kwargs["beam_size"]
|
32 |
+
emission_threshold = kwargs.get(
|
33 |
+
"emission_threshold", np.log(DEFAULT_EMISSION_THRESHOLD)
|
34 |
+
)
|
35 |
|
36 |
length, class_count = emission_log_prob.shape
|
37 |
|
|
|
57 |
for prefix, accu_log_prob in beams:
|
58 |
labels = tuple(_reconstruct(prefix, blank))
|
59 |
# log(p1 + p2) = logsumexp([log_p1, log_p2])
|
60 |
+
total_accu_log_prob[labels] = logsumexp(
|
61 |
+
[accu_log_prob, total_accu_log_prob.get(labels, NINF)]
|
62 |
+
)
|
63 |
+
|
64 |
+
labels_beams = [
|
65 |
+
(list(labels), accu_log_prob)
|
66 |
+
for labels, accu_log_prob in total_accu_log_prob.items()
|
67 |
+
]
|
68 |
labels_beams.sort(key=lambda x: x[1], reverse=True)
|
69 |
labels = labels_beams[0][0]
|
70 |
|
71 |
return labels
|
72 |
|
73 |
+
|
74 |
def greedy_decode(emission_log_prob, blank=0):
|
75 |
labels = np.argmax(emission_log_prob, axis=-1)
|
76 |
labels = _reconstruct(labels, blank=blank)
|
77 |
return labels
|
78 |
|
79 |
+
|
80 |
+
def ctc_decode(
|
81 |
+
log_probs, which_ctc_decoder="beam_search", label_2_char=None, blank=0, beam_size=25
|
82 |
+
):
|
83 |
emission_log_probs = np.transpose(log_probs.cpu().numpy(), (1, 0, 2))
|
84 |
# size of emission_log_probs: (batch, length, class)
|
85 |
|
86 |
decoded_list = []
|
87 |
for emission_log_prob in emission_log_probs:
|
88 |
if which_ctc_decoder == "beam_search":
|
89 |
+
decoded = beam_search_decode(
|
90 |
+
emission_log_prob, blank=blank, beam_size=beam_size
|
91 |
+
)
|
92 |
elif which_ctc_decoder == "greedy":
|
93 |
decoded = greedy_decode(emission_log_prob, blank=blank)
|
94 |
else:
|
|
|
100 |
decoded_list.append(decoded)
|
101 |
return decoded_list
|
102 |
|
103 |
+
|
104 |
"""
|
105 |
--------------------
|
106 |
Evaluation Metrics
|
107 |
--------------------
|
108 |
"""
|
109 |
+
|
110 |
+
|
111 |
def compute_wer_and_cer_for_batch(batch_preds, batch_gts):
|
112 |
cer_batch = fastwer.score(batch_preds, batch_gts, char_level=True)
|
113 |
wer_batch = fastwer.score(batch_preds, batch_gts)
|
114 |
return cer_batch, wer_batch
|
115 |
|
116 |
+
|
117 |
def compute_wer_and_cer_for_sample(str_pred, str_gt):
|
118 |
cer_sample = fastwer.score_sent(str_pred, str_gt, char_level=True)
|
119 |
wer_sample = fastwer.score_sent(str_pred, str_gt)
|
iam_line_recognition/utils_unique_chars.py
CHANGED
@@ -3,6 +3,7 @@ import numpy as np
|
|
3 |
|
4 |
from dataset import read_IAM_label_txt_file
|
5 |
|
|
|
6 |
def list_unique_characters_in_IAM_dataset(FLAGS):
|
7 |
_, all_labels = read_IAM_label_txt_file(FLAGS.file_txt_labels)
|
8 |
|
@@ -16,8 +17,8 @@ def list_unique_characters_in_IAM_dataset(FLAGS):
|
|
16 |
unique_chars = sorted(unique_chars)
|
17 |
unique_chars = np.array(unique_chars)
|
18 |
unique_chars = np.unique(unique_chars)
|
19 |
-
unique_chars = "".join(
|
20 |
-
unique_chars)
|
21 |
# prints all unique chars in the IAM dataset
|
22 |
print(unique_chars)
|
23 |
|
@@ -25,19 +26,27 @@ def list_unique_characters_in_IAM_dataset(FLAGS):
|
|
25 |
print(f"Number of unique characters : {len(unique_chars)}")
|
26 |
return
|
27 |
|
|
|
28 |
def main():
|
29 |
-
file_txt_labels = "/home/abhishek/Desktop/RUG/hw_recognition/IAM-data/iam_lines_gt.txt"
|
30 |
|
31 |
parser = argparse.ArgumentParser(
|
32 |
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
33 |
)
|
34 |
|
35 |
-
parser.add_argument("--file_txt_labels", default=file_txt_labels,
|
36 |
-
type=str, help="full path to label text file")
|
37 |
|
38 |
FLAGS, unparsed = parser.parse_known_args()
|
39 |
list_unique_characters_in_IAM_dataset(FLAGS)
|
40 |
return
|
41 |
|
|
|
42 |
if __name__ == "__main__":
|
43 |
main()
|
|
|
3 |
|
4 |
from dataset import read_IAM_label_txt_file
|
5 |
|
6 |
+
|
7 |
def list_unique_characters_in_IAM_dataset(FLAGS):
|
8 |
_, all_labels = read_IAM_label_txt_file(FLAGS.file_txt_labels)
|
9 |
|
|
|
17 |
unique_chars = sorted(unique_chars)
|
18 |
unique_chars = np.array(unique_chars)
|
19 |
unique_chars = np.unique(unique_chars)
|
20 |
+
unique_chars = "".join(unique_chars)
|
21 |
+
|
22 |
# prints all unique chars in the IAM dataset
|
23 |
print(unique_chars)
|
24 |
|
|
|
26 |
print(f"Number of unique characters : {len(unique_chars)}")
|
27 |
return
|
28 |
|
29 |
+
|
30 |
def main():
|
31 |
+
file_txt_labels = (
|
32 |
+
"/home/abhishek/Desktop/RUG/hw_recognition/IAM-data/iam_lines_gt.txt"
|
33 |
+
)
|
34 |
|
35 |
parser = argparse.ArgumentParser(
|
36 |
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
37 |
)
|
38 |
|
39 |
+
parser.add_argument(
|
40 |
+
"--file_txt_labels",
|
41 |
+
default=file_txt_labels,
|
42 |
+
type=str,
|
43 |
+
help="full path to label text file",
|
44 |
+
)
|
45 |
|
46 |
FLAGS, unparsed = parser.parse_known_args()
|
47 |
list_unique_characters_in_IAM_dataset(FLAGS)
|
48 |
return
|
49 |
|
50 |
+
|
51 |
if __name__ == "__main__":
|
52 |
main()
|
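Note on utils_unique_chars.py: the script derives the set of unique characters present in the ground-truth labels and prints the inventory plus its size. The same result can be computed without numpy; a short sketch on made-up labels (not the real IAM ground truth):

# toy stand-in for the labels returned by read_IAM_label_txt_file
all_labels = ["A MOVE to stop", "Mr. Gaitskell from"]

unique_chars = "".join(sorted(set("".join(all_labels))))
print(unique_chars)  # sorted string of the unique characters
print(f"Number of unique characters : {len(unique_chars)}")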