Commit bd421ea by abhishekrs4
Parent(s): 2cabce6

added iam_line_recognition module
Files changed:
- iam_line_recognition/__init__.py                   +3   -0
- iam_line_recognition/dataset.py                    +253 -0
- iam_line_recognition/final_iam_line_recognizer.py  +205 -0
- iam_line_recognition/logger_utils.py               +60  -0
- iam_line_recognition/model_main.py                 +151 -0
- iam_line_recognition/model_visual_features.py      +401 -0
- iam_line_recognition/test_internal.py              +164 -0
- iam_line_recognition/train.py                      +275 -0
- iam_line_recognition/utils.py                      +103 -0
- iam_line_recognition/utils_unique_chars.py         +43  -0
iam_line_recognition/__init__.py
ADDED
@@ -0,0 +1,3 @@
import os, sys

sys.path.append(os.path.dirname(os.path.realpath(__file__)))
iam_line_recognition/dataset.py
ADDED
@@ -0,0 +1,253 @@
import os
import torch
import torch.nn
import numpy as np
from PIL import Image
from skimage.io import imread
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

def read_IAM_label_txt_file(file_txt_labels):
    """
    ---------
    Arguments
    ---------
    file_txt_labels : str
        full path to the text file containing labels

    -------
    Returns
    -------
    a tuple of
    all_image_files : list
        a list of all image file names
    all_labels : list
        a list of all labels
    """
    label_file_handler = open(file_txt_labels, mode="r")
    all_lines = label_file_handler.readlines()
    num_lines = len(all_lines)

    all_image_files = []
    all_labels = []

    for cur_line_num in range(num_lines):
        if cur_line_num % 3 == 0:
            all_image_files.append(all_lines[cur_line_num].strip())
        elif cur_line_num % 3 == 1:
            all_labels.append(all_lines[cur_line_num].strip())
        else:
            continue

    return all_image_files, all_labels

class HWRecogIAMDataset(Dataset):
    """
    Main dataset class to be used only for training, validation and internal testing
    """
    CHAR_SET = ' !"#&\'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
    CHAR_2_LABEL = {char: i + 1 for i, char in enumerate(CHAR_SET)}
    LABEL_2_CHAR = {label: char for char, label in CHAR_2_LABEL.items()}

    def __init__(self, list_image_files, list_labels, dir_images, image_height=32, image_width=768, which_set="train"):
        """
        ---------
        Arguments
        ---------
        list_image_files : list
            list of image files
        list_labels : list
            list of labels
        dir_images : str
            full path to directory containing images
        image_height : int
            image height (default: 32)
        image_width : int
            image width (default: 768)
        which_set : str
            a string indicating which set is being used (default: train)
        """
        self.list_labels = list_labels
        self.dir_images = dir_images
        self.list_image_files = list_image_files
        self.image_width = image_width
        self.image_height = image_height
        self.which_set = which_set

        if self.which_set == "train":
            # apply data augmentation only for train set
            self.transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((self.image_height, self.image_width), Image.BILINEAR),
                transforms.RandomAffine(
                    degrees=[-0.75, 0.75], translate=[0, 0.05], scale=[0.75, 1],
                    shear=[-35, 35], interpolation=transforms.InterpolationMode.BILINEAR, fill=255,
                ),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ])
        else:
            self.transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((self.image_height, self.image_width), Image.BILINEAR),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ])

    def __len__(self):
        return len(self.list_image_files)

    def __getitem__(self, idx):
        image_file_name = self.list_image_files[idx]
        image_gray = imread(os.path.join(self.dir_images, image_file_name))
        image_3_channel = np.repeat(np.expand_dims(image_gray, -1), 3, -1)
        image_3_channel = self.transform(image_3_channel)

        label_string = self.list_labels[idx]
        label_encoded = [self.CHAR_2_LABEL[c] for c in label_string]
        label_length = [len(label_encoded)]

        label_encoded = torch.LongTensor(label_encoded)
        label_length = torch.LongTensor(label_length)

        return image_3_channel, label_encoded, label_length

def IAM_collate_fn(batch):
    """
    collate function

    ---------
    Arguments
    ---------
    batch : tuple
        a batch of input data as a tuple

    -------
    Returns
    -------
    a collated tuple of
    images : tensor
        tensor of batch images
    labels : tensor
        tensor of batch labels
    label_lengths : tensor
        tensor of batch label lengths
    """
    images, labels, label_lengths = zip(*batch)
    images = torch.stack(images, 0)
    labels = torch.cat(labels, 0)
    label_lengths = torch.cat(label_lengths, 0)
    return images, labels, label_lengths

def split_dataset(file_txt_labels, for_train=True):
    """
    ---------
    Arguments
    ---------
    file_txt_labels : str
        full path to the text file containing labels
    for_train : bool
        indicating whether split is for training or internal testing

    -------
    Returns
    -------
    a tuple of image file lists and label lists, either for training
    (train and validation splits) or for internal testing (test split)
    """
    all_image_files, all_labels = read_IAM_label_txt_file(file_txt_labels)
    train_image_files, test_image_files, train_labels, test_labels = train_test_split(all_image_files, all_labels, test_size=0.1, random_state=4)
    train_image_files, valid_image_files, train_labels, valid_labels = train_test_split(train_image_files, train_labels, test_size=0.1, random_state=4)
    if for_train:
        return train_image_files, valid_image_files, train_labels, valid_labels
    else:
        return test_image_files, test_labels

def get_dataloaders_for_training(train_x, train_y, valid_x, valid_y, dir_images, image_height=32, image_width=768, batch_size=8):
    """
    ---------
    Arguments
    ---------
    train_x : list
        list of train file names
    train_y : list
        list of train labels
    valid_x : list
        list of validation file names
    valid_y : list
        list of validation labels
    dir_images : str
        full directory path containing the images
    image_height : int
        image height (default: 32)
    image_width : int
        image width (default: 768)
    batch_size : int
        batch size (default: 8)

    -------
    Returns
    -------
    a tuple of dataloader objects
    train_loader : object
        object of train set dataloader
    valid_loader : object
        object of validation set dataloader
    """
    train_dataset = HWRecogIAMDataset(train_x, train_y, dir_images, image_height=image_height, image_width=image_width, which_set="train")
    valid_dataset = HWRecogIAMDataset(valid_x, valid_y, dir_images, image_height=image_height, image_width=image_width, which_set="valid")

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        collate_fn=IAM_collate_fn,
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        collate_fn=IAM_collate_fn,
    )
    return train_loader, valid_loader

def get_dataloader_for_testing(test_x, test_y, dir_images, image_height=32, image_width=768, batch_size=1):
    """
    ---------
    Arguments
    ---------
    test_x : list
        list of test file names
    test_y : list
        list of test labels
    dir_images : str
        full directory path containing the images
    image_height : int
        image height (default: 32)
    image_width : int
        image width (default: 768)
    batch_size : int
        batch size (default: 1)

    -------
    Returns
    -------
    test_loader : object
        object of test set dataloader
    """
    test_dataset = HWRecogIAMDataset(test_x, test_y, dir_images=dir_images, image_height=image_height, image_width=image_width, which_set="test")
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        collate_fn=IAM_collate_fn,
    )
    return test_loader
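A minimal usage sketch (not part of the commit) showing how these pieces compose; the dataset root below is a placeholder, and the label-file name matches the one used in test_internal.py later in this commit:

from dataset import split_dataset, get_dataloaders_for_training

file_txt_labels = "/path/to/IAM-data/iam_lines_gt.txt"  # hypothetical path
dir_images = "/path/to/IAM-data/img"                    # hypothetical path

# split_dataset returns (train files, valid files, train labels, valid labels)
train_x, valid_x, train_y, valid_y = split_dataset(file_txt_labels, for_train=True)
train_loader, valid_loader = get_dataloaders_for_training(
    train_x, train_y, valid_x, valid_y, dir_images,
    image_height=32, image_width=768, batch_size=8,
)

# IAM_collate_fn stacks images to [B, 3, 32, 768] and concatenates the
# variable-length labels into one 1-D tensor plus a [B] tensor of lengths
images, labels, label_lengths = next(iter(train_loader))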
iam_line_recognition/final_iam_line_recognizer.py
ADDED
@@ -0,0 +1,205 @@
import os
import sys
import time
import torch
import argparse
import torchvision
import numpy as np
import torch.nn as nn
from PIL import Image
from skimage.io import imread
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from dataset import HWRecogIAMDataset
from model_main import CRNN, STN_CRNN
from utils import ctc_decode, compute_wer_and_cer_for_sample


class DatasetFinalEval(HWRecogIAMDataset):
    """
    Dataset class for final evaluation - inherits main dataset class
    """
    def __init__(self, dir_images, image_height=32, image_width=768):
        """
        ---------
        Arguments
        ---------
        dir_images : str
            full path to directory containing images
        image_height : int
            image height (default: 32)
        image_width : int
            image width (default: 768)
        """
        self.dir_images = dir_images
        self.image_files = [f for f in os.listdir(self.dir_images) if f.endswith(".png")]
        self.image_width = image_width
        self.image_height = image_height
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((self.image_height, self.image_width), Image.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_file_name = self.image_files[idx]
        image_gray = imread(os.path.join(self.dir_images, image_file_name))
        image_3_channel = np.repeat(np.expand_dims(image_gray, -1), 3, -1)
        image_3_channel = self.transform(image_3_channel)
        return image_3_channel

def get_dataloader_for_evaluation(dir_images, image_height=32, image_width=768, batch_size=1):
    """
    ---------
    Arguments
    ---------
    dir_images : str
        full path to directory containing images
    image_height : int
        image height (default: 32)
    image_width : int
        image width (default: 768)
    batch_size : int
        batch size to use for final evaluation (default: 1)

    -------
    Returns
    -------
    test_loader : object
        dataset loader object for final evaluation
    """
    test_dataset = DatasetFinalEval(dir_images=dir_images, image_height=image_height, image_width=image_width)
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
    )
    return test_loader

def final_eval(hw_model, device, test_loader, dir_images, dir_results):
    """
    ---------
    Arguments
    ---------
    hw_model : object
        handwriting recognition model object
    device : str
        device to be used for running the evaluation
    test_loader : object
        dataset loader object
    dir_images : str
        full path to directory containing test images
    dir_results : str
        relative path to directory to save the predictions as txt files
    """
    hw_model.eval()
    count = 0
    num_test_samples = len(test_loader.dataset)
    list_test_files = os.listdir(dir_images)

    if not os.path.isdir(dir_results):
        print(f"creating directory: {dir_results}")
        os.makedirs(dir_results)

    with torch.no_grad():
        for image_test in test_loader:
            file_test = list_test_files[count]
            count += 1
            """
            if count == 11:
                break
            """
            image_test = image_test.to(device, dtype=torch.float)

            log_probs = hw_model(image_test)
            pred_labels = ctc_decode(log_probs)
            str_pred = [DatasetFinalEval.LABEL_2_CHAR[i] for i in pred_labels[0]]
            str_pred = "".join(str_pred)

            with open(os.path.join(dir_results, file_test+".txt"), "w", encoding="utf-8", newline="\n") as fh_pred:
                fh_pred.write(str_pred)

            print(f"progress: {count}/{num_test_samples}, test file: {list_test_files[count-1]}")
            print(f"{str_pred}\n")
    print(f"predictions saved in directory: ./{dir_results}\n")
    return

def test_hw_recognizer(FLAGS):
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

    num_classes = len(DatasetFinalEval.LABEL_2_CHAR) + 1
    print(f"task - handwriting recognition")
    print(f"model: {FLAGS.which_hw_model}")
    print(f"image height: {FLAGS.image_height}, image width: {FLAGS.image_width}")

    # load the right model
    if FLAGS.which_hw_model == "crnn":
        hw_model = CRNN(num_classes, FLAGS.image_height)
    elif FLAGS.which_hw_model == "stn_crnn":
        hw_model = STN_CRNN(num_classes, FLAGS.image_height, FLAGS.image_width)
    else:
        print(f"unidentified option : {FLAGS.which_hw_model}")
        sys.exit(0)
    dir_results = f"results_{FLAGS.which_hw_model}"

    # choose a device for evaluation
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    hw_model.to(device)
    hw_model.load_state_dict(torch.load(FLAGS.file_model))

    # get test set dataloader
    test_loader = get_dataloader_for_evaluation(
        dir_images=FLAGS.dir_images, image_height=FLAGS.image_height, image_width=FLAGS.image_width,
    )

    # start the evaluation on the final test set
    print(f"final evaluation of handwriting recognition model {FLAGS.which_hw_model} started\n")
    final_eval(hw_model, device, test_loader, FLAGS.dir_images, dir_results)
    print(f"final evaluation of handwriting recognition model completed!!!!")
    return

def main():
    image_height = 32
    image_width = 768
    which_hw_model = "crnn"
    dir_images = "/home/abhishek/Desktop/RUG/hw_recognition/IAM-data/img/"
    file_model = "model_crnn/crnn_H_32_W_768_E_177.pth"
    save_predictions = 1

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument("--image_height", default=image_height,
        type=int, help="image height to be used to predict with the model")
    parser.add_argument("--image_width", default=image_width,
        type=int, help="image width to be used to predict with the model")
    parser.add_argument("--dir_images", default=dir_images,
        type=str, help="full directory path to directory containing images")
    parser.add_argument("--which_hw_model", default=which_hw_model,
        type=str, choices=["crnn", "stn_crnn"], help="which model to be used for prediction")
    parser.add_argument("--file_model", default=file_model,
        type=str, help="full path to trained model file (.pth)")
    parser.add_argument("--save_predictions", default=save_predictions,
        type=int, choices=[0, 1], help="save or do not save the predictions (1 - save, 0 - do not save)")

    FLAGS, unparsed = parser.parse_known_args()
    test_hw_recognizer(FLAGS)
    return

if __name__ == "__main__":
    main()
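ctc_decode is imported from utils, whose diff is listed in this commit but not shown above. As a hedged illustration only (not the repository's implementation), a greedy CTC decoder over the [seq_len, B, num_classes] log-probabilities returned by the models could look like this:

import torch

def greedy_ctc_decode(log_probs, blank=0):
    # log_probs: [seq_len, batch, num_classes]; take the best class per step
    best = log_probs.argmax(dim=2).permute(1, 0)  # [batch, seq_len]
    decoded = []
    for seq in best:
        labels, prev = [], blank
        for idx in seq.tolist():
            if idx != prev and idx != blank:  # collapse repeats, then drop blanks
                labels.append(idx)
            prev = idx
        decoded.append(labels)
    return decoded

With CHAR_2_LABEL starting at 1 (see dataset.py), label 0 is free to act as the CTC blank, which is what this sketch assumes.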
iam_line_recognition/logger_utils.py
ADDED
@@ -0,0 +1,60 @@
import csv
import json

def write_json_file(file_json, dict_data):
    """
    ---------
    Arguments
    ---------
    file_json : str
        full path of json file to be saved
    dict_data : dict
        dictionary of params to be saved in the json file
    """
    with open(file_json, "w", encoding="utf-8") as fh:
        fh.write(json.dumps(dict_data, indent=4))
    return

class CSVWriter:
    """
    for writing tabular data to a csv file
    """
    def __init__(self, file_name, column_names):
        """
        ---------
        Arguments
        ---------
        file_name : str
            full path of csv file
        column_names : list
            a list of column names to be used to create the csv file
        """
        self.file_name = file_name
        self.column_names = column_names

        self.file_handle = open(self.file_name, "w", encoding="utf-8", newline="\n")
        self.writer = csv.writer(self.file_handle)

        self.write_header()
        print(f"{self.file_name} created successfully with header row")

    def write_header(self):
        """
        writes header into csv file
        """
        self.write_row(self.column_names)
        return

    def write_row(self, row):
        """
        writes a row into csv file
        """
        self.writer.writerow(row)
        return

    def close(self):
        """
        close the file
        """
        self.file_handle.close()
        return
iam_line_recognition/model_main.py
ADDED
@@ -0,0 +1,151 @@
import torchvision
import torch.nn as nn
import torch.nn.functional as F

from model_visual_features import ResNetFeatureExtractor, TPS_SpatialTransformerNetwork

class HW_RNN_Seq2Seq(nn.Module):
    """
    Visual Seq2Seq model using BiLSTM
    """
    def __init__(self, num_classes, image_height, cnn_output_channels=512, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256):
        """
        ---------
        Arguments
        ---------
        num_classes : int
            num of distinct characters (classes) in the dataset
        image_height : int
            image height
        cnn_output_channels : int
            number of channels output from the CNN visual feature extractor (default: 512)
        num_feats_mapped_seq_hidden : int
            number of features to be used in the mapped visual features as sequences (default: 128)
        num_feats_seq_hidden : int
            number of features to be used in the LSTM for sequence modeling (default: 256)
        """
        super().__init__()
        self.output_height = image_height // 32

        self.dropout = nn.Dropout(p=0.25)
        self.map_visual_to_seq = nn.Linear(cnn_output_channels * self.output_height, num_feats_mapped_seq_hidden)

        self.b_lstm_1 = nn.LSTM(num_feats_mapped_seq_hidden, num_feats_seq_hidden, bidirectional=True)
        self.b_lstm_2 = nn.LSTM(2 * num_feats_seq_hidden, num_feats_seq_hidden, bidirectional=True)

        self.final_dense = nn.Linear(2 * num_feats_seq_hidden, num_classes)

    def forward(self, visual_feats):
        visual_feats = visual_feats.permute(3, 0, 1, 2)
        # WBCH
        # the sequence is along the width of the image as a sentence

        visual_feats = visual_feats.contiguous().view(visual_feats.shape[0], visual_feats.shape[1], -1)
        # WBC

        seq = self.map_visual_to_seq(visual_feats)
        seq = self.dropout(seq)
        lstm_1, _ = self.b_lstm_1(seq)
        lstm_2, _ = self.b_lstm_2(lstm_1)
        lstm_2 = self.dropout(lstm_2)

        dense_output = self.final_dense(lstm_2)
        # [seq_len, B, num_classes]

        log_probs = F.log_softmax(dense_output, dim=2)

        return log_probs


class CRNN(nn.Module):
    """
    Hybrid CNN - RNN model
    CNN - Modified ResNet34 for visual features
    RNN - BiLSTM for seq2seq modeling
    """
    def __init__(self, num_classes, image_height, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256):
        """
        ---------
        Arguments
        ---------
        num_classes : int
            num of distinct characters (classes) in the dataset
        image_height : int
            image height
        num_feats_mapped_seq_hidden : int
            number of features to be used in the mapped visual features as sequences (default: 128)
        num_feats_seq_hidden : int
            number of features to be used in the LSTM for sequence modeling (default: 256)
        """
        super().__init__()
        self.visual_feature_extractor = ResNetFeatureExtractor()
        self.rnn_seq2seq_module = HW_RNN_Seq2Seq(num_classes, image_height, self.visual_feature_extractor.output_channels, num_feats_mapped_seq_hidden, num_feats_seq_hidden)

    def forward(self, x):
        visual_feats = self.visual_feature_extractor(x)
        # [B, 512, H/32, W/4] (width is only downsampled 4x by the modified strides)

        log_probs = self.rnn_seq2seq_module(visual_feats)
        return log_probs


class STN_CRNN(nn.Module):
    """
    STN + CNN + RNN model
    STN - Spatial Transformer Network for learning variable handwriting
    CNN - Modified ResNet34 for visual features
    RNN - BiLSTM for seq2seq modeling
    """
    def __init__(self, num_classes, image_height, image_width, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256):
        """
        ---------
        Arguments
        ---------
        num_classes : int
            num of distinct characters (classes) in the dataset
        image_height : int
            image height
        image_width : int
            image width
        num_feats_mapped_seq_hidden : int
            number of features to be used in the mapped visual features as sequences (default: 128)
        num_feats_seq_hidden : int
            number of features to be used in the LSTM for sequence modeling (default: 256)
        """
        super().__init__()
        self.stn = TPS_SpatialTransformerNetwork(
            80,
            (image_height, image_width),
            (image_height, image_width),
            I_channel_num=3,
        )
        self.visual_feature_extractor = ResNetFeatureExtractor()
        self.rnn_seq2seq_module = HW_RNN_Seq2Seq(num_classes, image_height, self.visual_feature_extractor.output_channels, num_feats_mapped_seq_hidden, num_feats_seq_hidden)

    def forward(self, x):
        stn_output = self.stn(x)
        visual_feats = self.visual_feature_extractor(stn_output)
        log_probs = self.rnn_seq2seq_module(visual_feats)
        return log_probs

"""
class STN_PP_CRNN(nn.Module):
    def __init__(self, num_classes, image_height, image_width, num_feats_mapped_seq_hidden=128, num_feats_seq_hidden=256):
        super().__init__()
        self.stn = TPS_SpatialTransformerNetwork(
            20,
            (image_height, image_width),
            (image_height, image_width),
            I_channel_num=3,
        )
        self.visual_feature_extractor = ResNetFeatureExtractor()
        self.pp_block = PyramidPoolBlock(num_channels=self.visual_feature_extractor.output_channels)
        self.rnn_seq2seq_module = HW_RNN_Seq2Seq(num_classes, image_height, self.visual_feature_extractor.output_channels, num_feats_mapped_seq_hidden, num_feats_seq_hidden)

    def forward(self, x):
        stn_output = self.stn(x)
        visual_feats = self.visual_feature_extractor(stn_output)
        pp_feats = self.pp_block(visual_feats)
        log_probs = self.rnn_seq2seq_module(pp_feats)
        return log_probs
"""
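A hedged shape check for the CRNN above (runs on CPU; pretrained=True downloads the torchvision ResNet-34 weights on first use). With the stride pattern in model_visual_features.py the CNN downsamples width by 4, so a 768-pixel-wide line yields 192 time steps:

import torch
from model_main import CRNN

num_classes = 80  # len(HWRecogIAMDataset.CHAR_SET) + 1 for the CTC blank
hw_model = CRNN(num_classes, image_height=32)
hw_model.eval()

x = torch.randn(2, 3, 32, 768)  # [B, 3, H, W], as produced by the dataset transforms
log_probs = hw_model(x)
print(log_probs.shape)  # [seq_len, B, num_classes] -> torch.Size([192, 2, 80])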
iam_line_recognition/model_visual_features.py
ADDED
@@ -0,0 +1,401 @@
import torch
import numpy as np
import torch.nn as nn
from typing import List
from torch import Tensor
import torch.nn.functional as F
from torchvision.models.resnet import BasicBlock, model_urls, load_state_dict_from_url, conv1x1, conv3x3

device = torch.device("cuda")

class CustomResNet(nn.Module):
    def __init__(
        self,
        layers: List[int],
        block=BasicBlock,
        zero_init_residual=False,
        groups=1,
        num_classes=1000,
        width_per_group=64,
        replace_stride_with_dilation=None,
        norm_layer=None,
    ):

        super().__init__()

        if norm_layer is None:
            self._norm_layer = nn.BatchNorm2d

        self.inplanes = 64
        self.dilation = 1

        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]

        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )

        self.groups = groups
        self.base_width = width_per_group

        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = self._norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=(2, 1), padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=(2, 1), dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=(2, 2), dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=(2, 1), dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block,
        planes,
        blocks,
        stride=1,
        dilate=False,
    ) -> nn.Sequential:
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)

def _resnet(layers: List[int], pretrained=True) -> CustomResNet:
    model = CustomResNet(layers)

    if pretrained:
        model.load_state_dict(load_state_dict_from_url(model_urls["resnet34"]))

    return model

def resnet34(*, pretrained=True) -> CustomResNet:
    """ResNet-34 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__,
    with strides modified for wide line images.

    Args:
        pretrained (bool, optional): if True, load the torchvision ImageNet
            weights for resnet34 into this custom variant. Default is True.
    """

    return _resnet([3, 4, 6, 3], pretrained=pretrained)


class ResNetFeatureExtractor(nn.Module):
    """
    Defines Base ResNet-34 feature extractor
    """
    def __init__(self, pretrained=True):
        """
        ---------
        Arguments
        ---------
        pretrained : bool (default=True)
            boolean to indicate whether to use a pretrained resnet model or not
        """
        super().__init__()
        self.output_channels = 512
        self.resnet34 = resnet34(pretrained=pretrained)

    def forward(self, x):
        block1 = self.resnet34.conv1(x)
        block1 = self.resnet34.bn1(block1)
        block1 = self.resnet34.relu(block1)  # [64, H/2, W/2]

        block2 = self.resnet34.maxpool(block1)
        block2 = self.resnet34.layer1(block2)  # [64, H/4, W/2]
        block3 = self.resnet34.layer2(block2)  # [128, H/8, W/2]
        block4 = self.resnet34.layer3(block3)  # [256, H/16, W/4]
        resnet_features = self.resnet34.layer4(block4)  # [512, H/32, W/4]

        # [B, 512, H/32, W/4]
        return resnet_features


#########################################
### STN - Spatial Transformer Network ###
#########################################
class TPS_SpatialTransformerNetwork(nn.Module):
    """ Rectification Network of RARE, namely TPS based STN """

    def __init__(self, num_fiducial_points, I_size, I_r_size, I_channel_num=1):
        """ Based on RARE TPS
        input:
            batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width]
            I_size : (height, width) of the input image I
            I_r_size : (height, width) of the rectified image I_r
            I_channel_num : the number of channels of the input image I
        output:
            batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width]
        """
        super(TPS_SpatialTransformerNetwork, self).__init__()
        self.num_fiducial_points = num_fiducial_points
        self.I_size = I_size
        self.I_r_size = I_r_size  # = (I_r_height, I_r_width)
        self.I_channel_num = I_channel_num
        self.LocalizationNetwork = LocalizationNetwork(self.num_fiducial_points, self.I_channel_num)
        self.GridGenerator = GridGenerator(self.num_fiducial_points, self.I_r_size)

    def forward(self, batch_I):
        batch_C_prime = self.LocalizationNetwork(batch_I)  # batch_size x K x 2
        build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime)  # batch_size x n (= I_r_width x I_r_height) x 2
        build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2])

        if torch.__version__ > "1.2.0":
            batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True)
        else:
            batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border')

        return batch_I_r


class LocalizationNetwork(nn.Module):
    """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """

    def __init__(self, num_fiducial_points, I_channel_num):
        super(LocalizationNetwork, self).__init__()
        self.num_fiducial_points = num_fiducial_points
        self.I_channel_num = I_channel_num
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1,
                bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
            nn.MaxPool2d(2, 2),  # batch_size x 64 x I_height/2 x I_width/2
            nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True),
            nn.MaxPool2d(2, 2),  # batch_size x 128 x I_height/4 x I_width/4
            nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True),
            nn.MaxPool2d(2, 2),  # batch_size x 256 x I_height/8 x I_width/8
            nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True),
            nn.AdaptiveAvgPool2d(1)  # batch_size x 512
        )

        self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True))
        self.localization_fc2 = nn.Linear(256, self.num_fiducial_points * 2)

        # Init fc2 in LocalizationNetwork
        self.localization_fc2.weight.data.fill_(0)
        """ see RARE paper Fig. 6 (a) """
        ctrl_pts_x = np.linspace(-1.0, 1.0, int(num_fiducial_points / 2))
        ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(num_fiducial_points / 2))
        ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(num_fiducial_points / 2))
        ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
        ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
        initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
        self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1)

    def forward(self, batch_I):
        """
        input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width]
        output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2]
        """
        batch_size = batch_I.size(0)
        features = self.conv(batch_I).view(batch_size, -1)
        batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.num_fiducial_points, 2)
        return batch_C_prime


class GridGenerator(nn.Module):
    """ Grid Generator of RARE, which produces P_prime by multiplying T with P """

    def __init__(self, num_fiducial_points, I_r_size):
        """ Generate P_hat and inv_delta_C for later """
        super(GridGenerator, self).__init__()
        self.eps = 1e-6
        self.I_r_height, self.I_r_width = I_r_size
        self.num_fiducial_points = num_fiducial_points
        self.C = self._build_C(self.num_fiducial_points)  # F x 2
        self.P = self._build_P(self.I_r_width, self.I_r_height)
        ## for multi-gpu, you need register buffer
        self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.num_fiducial_points, self.C)).float())  # F+3 x F+3
        self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.num_fiducial_points, self.C, self.P)).float())  # n x F+3
        ## for fine-tuning with different image width, you may use below instead of self.register_buffer
        #self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.num_fiducial_points, self.C)).float().cuda()  # F+3 x F+3
        #self.P_hat = torch.tensor(self._build_P_hat(self.num_fiducial_points, self.C, self.P)).float().cuda()  # n x F+3

    def _build_C(self, F):
        """ Return coordinates of fiducial points in I_r; C """
        ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
        ctrl_pts_y_top = -1 * np.ones(int(F / 2))
        ctrl_pts_y_bottom = np.ones(int(F / 2))
        ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
        ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
        C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
        return C  # F x 2

    def _build_inv_delta_C(self, F, C):
        """ Return inv_delta_C which is needed to calculate T """
        hat_C = np.zeros((F, F), dtype=float)  # F x F
        for i in range(0, F):
            for j in range(i, F):
                r = np.linalg.norm(C[i] - C[j])
                hat_C[i, j] = r
                hat_C[j, i] = r
        np.fill_diagonal(hat_C, 1)
        hat_C = (hat_C ** 2) * np.log(hat_C)
        # print(C.shape, hat_C.shape)
        delta_C = np.concatenate(  # F+3 x F+3
            [
                np.concatenate([np.ones((F, 1)), C, hat_C], axis=1),  # F x F+3
                np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1),  # 2 x F+3
                np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1)  # 1 x F+3
            ],
            axis=0
        )
        inv_delta_C = np.linalg.inv(delta_C)
        return inv_delta_C  # F+3 x F+3

    def _build_P(self, I_r_width, I_r_height):
        I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width  # self.I_r_width
        I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height  # self.I_r_height
        P = np.stack(  # self.I_r_width x self.I_r_height x 2
            np.meshgrid(I_r_grid_x, I_r_grid_y),
            axis=2
        )
        return P.reshape([-1, 2])  # n (= self.I_r_width x self.I_r_height) x 2

    def _build_P_hat(self, F, C, P):
        n = P.shape[0]  # n (= self.I_r_width x self.I_r_height)
        P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1))  # n x 2 -> n x 1 x 2 -> n x F x 2
        C_tile = np.expand_dims(C, axis=0)  # 1 x F x 2
        P_diff = P_tile - C_tile  # n x F x 2
        rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False)  # n x F
        rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps))  # n x F
        P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1)
        return P_hat  # n x F+3

    def build_P_prime(self, batch_C_prime):
        """ Generate Grid from batch_C_prime [batch_size x F x 2] """
        batch_size = batch_C_prime.size(0)
        batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1)
        batch_P_hat = self.P_hat.repeat(batch_size, 1, 1)
        batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros(
            batch_size, 3, 2).float().to(device)), dim=1)  # batch_size x F+3 x 2
        batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros)  # batch_size x F+3 x 2
        batch_P_prime = torch.bmm(batch_P_hat, batch_T)  # batch_size x n x 2
        return batch_P_prime  # batch_size x n x 2


"""
########################################
######## Pyramid Pooling Block #########
########################################
class PyramidPool(nn.Module):
    def __init__(self, pool_kernel_size, in_channels, out_channels):
        super().__init__()
        self.pool_kernel_size = pool_kernel_size
        self.avg_pool_block = nn.Sequential(
            nn.AvgPool2d((1, self.pool_kernel_size), stride=(1, self.pool_kernel_size)),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding="same", bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ELU(inplace=True),
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        _, _, in_height, in_width = x.size()
        x = self.avg_pool_block(x)
        x = F.interpolate(x, size=(in_height, in_width), mode="bilinear")
        return x


class PyramidPoolBlock(nn.Module):
    def __init__(self, pyramid_pool_kernel_sizes=[4, 8, 16, 32], num_channels=512):
        super().__init__()
        pp_out_channels = 256
        self.pyramid_pool_layers = nn.ModuleList([PyramidPool(pool_kernel_size=k, in_channels=num_channels, out_channels=pp_out_channels) for k in pyramid_pool_kernel_sizes])
        self.final_layer = nn.Sequential(
            nn.Conv2d((num_channels + (pp_out_channels * len(self.pyramid_pool_layers))), num_channels, (1, 5), stride=1, padding="same"),
            nn.BatchNorm2d(num_channels),
            nn.ELU(inplace=True),
            nn.Dropout(p=0.1),
        )

    def forward(self, input):
        pp_outputs = []
        for pp_layer in self.pyramid_pool_layers:
            pp_output = pp_layer(input)
            pp_outputs.append(pp_output)
        pp_outputs.append(input)
        x = torch.cat(pp_outputs, dim=1)
        x = self.final_layer(x)
        return x
"""
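A hedged sketch exercising the TPS-STN on its own. GridGenerator.build_P_prime moves tensors to the module-level device = torch.device("cuda"), so this snippet assumes a CUDA-capable machine:

import torch
from model_visual_features import TPS_SpatialTransformerNetwork

# same configuration that STN_CRNN uses: 80 fiducial points, 3-channel input
stn = TPS_SpatialTransformerNetwork(
    80,
    (32, 768),   # I_size
    (32, 768),   # I_r_size
    I_channel_num=3,
).cuda()

batch_I = torch.randn(2, 3, 32, 768).cuda()
batch_I_r = stn(batch_I)
print(batch_I_r.shape)  # rectified image, same shape: torch.Size([2, 3, 32, 768])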
iam_line_recognition/test_internal.py
ADDED
@@ -0,0 +1,164 @@
import os
import sys
import time
import torch
import argparse
import torchvision
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from logger_utils import CSVWriter
from model_main import CRNN, STN_CRNN
from utils import ctc_decode, compute_wer_and_cer_for_sample
from dataset import HWRecogIAMDataset, split_dataset, get_dataloader_for_testing


def test(hw_model, test_loader, device, list_test_files, which_ctc_decoder="beam_search", save_prediction_stats=False):
    """
    ---------
    Arguments
    ---------
    hw_model : object
        handwriting recognition model object
    test_loader : object
        dataset loader object
    device : str
        device to be used for running the evaluation
    list_test_files : list
        list of all the test files
    which_ctc_decoder : str
        string indicating which ctc decoder to use
    save_prediction_stats : bool
        whether to save prediction stats
    """
    hw_model.eval()
    num_test_samples = len(test_loader.dataset)
    num_test_batches = len(test_loader)

    count = 0
    list_test_cers, list_test_wers = [], []

    if save_prediction_stats:
        csv_writer = CSVWriter(
            file_name="pred_stats.csv",
            column_names=["file_name", "num_chars", "num_words", "cer", "wer"]
        )

    with torch.no_grad():
        for images, labels, length_labels in test_loader:
            count += 1
            images = images.to(device, dtype=torch.float)
            log_probs = hw_model(images)
            pred_labels = ctc_decode(log_probs, which_ctc_decoder=which_ctc_decoder)
            labels = labels.cpu().numpy().tolist()

            str_label = [HWRecogIAMDataset.LABEL_2_CHAR[i] for i in labels]
            str_label = "".join(str_label)
            str_pred = [HWRecogIAMDataset.LABEL_2_CHAR[i] for i in pred_labels[0]]
            str_pred = "".join(str_pred)

            cer_sample, wer_sample = compute_wer_and_cer_for_sample(str_pred, str_label)
            list_test_cers.append(cer_sample)
            list_test_wers.append(wer_sample)

            print(f"progress: {count}/{num_test_samples}, test file: {list_test_files[count-1]}")
            print(f"{str_label} - label")
            print(f"{str_pred} - prediction")
            print(f"cer: {cer_sample:.3f}, wer: {wer_sample:.3f}\n")

            if save_prediction_stats:
                csv_writer.write_row([
                    list_test_files[count-1],
                    len(str_label),
                    len(str_label.split(" ")),
                    cer_sample,
                    wer_sample,
                ])
    list_test_cers = np.array(list_test_cers)
    list_test_wers = np.array(list_test_wers)
    mean_test_cer = np.mean(list_test_cers)
    mean_test_wer = np.mean(list_test_wers)
    print(f"test set - mean cer: {mean_test_cer:.3f}, mean wer: {mean_test_wer:.3f}\n")

    if save_prediction_stats:
        csv_writer.close()
    return

def test_hw_recognizer(FLAGS):
    file_txt_labels = os.path.join(FLAGS.dir_dataset, "iam_lines_gt.txt")
    dir_images = os.path.join(FLAGS.dir_dataset, "img")
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

    # choose a device for testing
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # get the internal test set files
    test_x, test_y = split_dataset(file_txt_labels, for_train=False)
    num_test_samples = len(test_x)
    # get the internal test set dataloader
    test_loader = get_dataloader_for_testing(
        test_x, test_y,
        dir_images=dir_images, image_height=FLAGS.image_height, image_width=FLAGS.image_width,
    )

    num_classes = len(HWRecogIAMDataset.LABEL_2_CHAR) + 1
    print(f"task - handwriting recognition")
    print(f"model: {FLAGS.which_hw_model}, ctc decoder: {FLAGS.which_ctc_decoder}")
    print(f"image height: {FLAGS.image_height}, image width: {FLAGS.image_width}")
    print(f"num test samples: {num_test_samples}")

    # load the right model
    if FLAGS.which_hw_model == "crnn":
        hw_model = CRNN(num_classes, FLAGS.image_height)
    elif FLAGS.which_hw_model == "stn_crnn":
        hw_model = STN_CRNN(num_classes, FLAGS.image_height, FLAGS.image_width)
    else:
        print(f"unidentified option : {FLAGS.which_hw_model}")
        sys.exit(0)
    hw_model.to(device)
    hw_model.load_state_dict(torch.load(FLAGS.file_model))

    # start testing of the model on the internal set
    print(f"testing of handwriting recognition model {FLAGS.which_hw_model} started\n")
    test(hw_model, test_loader, device, test_x, FLAGS.which_ctc_decoder, bool(FLAGS.save_prediction_stats))
    print(f"testing handwriting recognition model completed!!!!")
    return

def main():
    image_height = 32
    image_width = 768
    which_hw_model = "crnn"
    dir_dataset = "/home/abhishek/Desktop/RUG/hw_recognition/IAM-data/"
    file_model = "model_crnn/crnn_H_32_W_768_E_177.pth"
    which_ctc_decoder = "beam_search"
    save_prediction_stats = 0

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument("--image_height", default=image_height,
        type=int, help="image height to be used to predict with the model")
    parser.add_argument("--image_width", default=image_width,
        type=int, help="image width to be used to predict with the model")
    parser.add_argument("--dir_dataset", default=dir_dataset,
        type=str, help="full directory path to the dataset")
    parser.add_argument("--which_hw_model", default=which_hw_model,
        type=str, choices=["crnn", "stn_crnn"], help="which model to be used for prediction")
    parser.add_argument("--which_ctc_decoder", default=which_ctc_decoder,
        type=str, choices=["beam_search", "greedy"], help="which ctc decoder to use")
    parser.add_argument("--file_model", default=file_model,
        type=str, help="full path to trained model file (.pth)")
    parser.add_argument("--save_prediction_stats", default=save_prediction_stats,
        type=int, choices=[0, 1], help="save prediction stats (1 - yes, 0 - no)")

    FLAGS, unparsed = parser.parse_known_args()
    test_hw_recognizer(FLAGS)
    return

if __name__ == "__main__":
    main()
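compute_wer_and_cer_for_sample also lives in the utils.py added by this commit (diff not shown above). A common definition, assumed here, is Levenshtein distance normalized by reference length, over characters for CER and over whitespace-split tokens for WER; the actual normalization or scaling in utils.py may differ:

def _levenshtein(a, b):
    # classic dynamic-programming edit distance; works on strings or token lists
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (ca != cb)))
        prev = cur
    return prev[-1]

def compute_wer_and_cer_for_sample(str_pred, str_label):
    # returns fractions; guards against an empty reference string
    cer = _levenshtein(str_pred, str_label) / max(len(str_label), 1)
    wer = _levenshtein(str_pred.split(), str_label.split()) / max(len(str_label.split()), 1)
    return cer, wer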
iam_line_recognition/train.py
ADDED
@@ -0,0 +1,275 @@
import os
import sys
import time
import torch
import argparse
import torchvision
import torch.nn as nn
import torch.nn.functional as F

from model_main import CRNN, STN_CRNN
from logger_utils import CSVWriter, write_json_file
from utils import compute_wer_and_cer_for_sample, ctc_decode
from dataset import HWRecogIAMDataset, split_dataset, get_dataloaders_for_training


def train(hw_model, optimizer, criterion, train_loader, device):
    """
    ---------
    Arguments
    ---------
    hw_model : object
        handwriting recognition model object
    optimizer : object
        optimizer object to be used for optimization
    criterion : object
        criterion or loss object to be used as the objective function for optimization
    train_loader : object
        train set dataloader object
    device : str
        device to be used for running the training

    -------
    Returns
    -------
    train_loss : float
        mean training loss for an epoch
    """
    hw_model.train()
    train_running_loss = 0.0
    num_train_samples = len(train_loader.dataset)
    num_train_batches = len(train_loader)

    for images, labels, lengths_labels in train_loader:
        images = images.to(device, dtype=torch.float)
        labels = labels.to(device, dtype=torch.long)
        lengths_labels = lengths_labels.to(device, torch.long)

        batch_size = images.size(0)
        optimizer.zero_grad()
        log_probs = hw_model(images)

        lengths_preds = torch.LongTensor([log_probs.size(0)] * batch_size)
        lengths_labels = torch.flatten(lengths_labels)

        loss = criterion(log_probs, labels, lengths_preds, lengths_labels)
        train_running_loss += loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(hw_model.parameters(), 5)  # gradient clipping with max norm 5
        optimizer.step()

    train_loss = train_running_loss / num_train_batches
    return train_loss
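The shape convention that train() relies on — time-major log_probs and all target labels concatenated into a single 1-D tensor — is easy to get wrong, so here is a self-contained sketch of the nn.CTCLoss input layout (all names, shapes and values below are illustrative, not taken from this repository):

import torch
import torch.nn as nn

T, N, C = 192, 4, 80                                    # time steps, batch size, classes (blank = 0)
log_probs = torch.randn(T, N, C).log_softmax(2)         # time-major model output
targets = torch.randint(1, C, (23,))                    # all labels of the batch, concatenated 1-D
input_lengths = torch.full((N,), T, dtype=torch.long)   # each sample emits all T frames
target_lengths = torch.tensor([5, 6, 4, 8])             # per-sample label lengths, sum = 23
criterion = nn.CTCLoss(reduction="mean", zero_infinity=True)
loss = criterion(log_probs, targets, input_lengths, target_lengths)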

def validate(hw_model, criterion, valid_loader, device):
    """
    ---------
    Arguments
    ---------
    hw_model : object
        handwriting recognition model object
    criterion : object
        criterion or loss object to be used as the objective function for optimization
    valid_loader : object
        validation set dataloader object
    device : str
        device to be used for running the evaluation

    -------
    Returns
    -------
    a 3 tuple of
    valid_loss : float
        mean validation loss for an epoch
    valid_cer : float
        mean character error rate (CER) for the validation set
    valid_wer : float
        mean word error rate (WER) for the validation set
    """
    hw_model.eval()
    valid_running_loss = 0.0
    valid_running_cer = 0.0
    valid_running_wer = 0.0
    num_valid_samples = len(valid_loader.dataset)
    num_valid_batches = len(valid_loader)

    with torch.no_grad():
        for images, labels, lengths_labels in valid_loader:
            images = images.to(device, dtype=torch.float)
            labels = labels.to(device, dtype=torch.long)
            lengths_labels = lengths_labels.to(device, torch.long)

            batch_size = images.size(0)
            log_probs = hw_model(images)
            lengths_preds = torch.LongTensor([log_probs.size(0)] * batch_size)
            # flatten the label lengths, mirroring train()
            lengths_labels = torch.flatten(lengths_labels)

            loss = criterion(log_probs, labels, lengths_preds, lengths_labels)
            valid_running_loss += loss.item()

            pred_labels = ctc_decode(log_probs)
            labels_for_eval = labels.cpu().numpy().tolist()
            lengths_labels_for_eval = lengths_labels.cpu().numpy().tolist()

            # split the concatenated label tensor back into per-sample label lists
            final_labels_for_eval = []
            length_label_counter = 0
            for length_label in lengths_labels_for_eval:
                label = labels_for_eval[length_label_counter:length_label_counter+length_label]
                length_label_counter += length_label
                final_labels_for_eval.append(label)

            for i in range(len(final_labels_for_eval)):
                if len(pred_labels[i]) != 0:
                    str_label = "".join([HWRecogIAMDataset.LABEL_2_CHAR[c] for c in final_labels_for_eval[i]])
                    str_pred = "".join([HWRecogIAMDataset.LABEL_2_CHAR[c] for c in pred_labels[i]])
                    cer_sample, wer_sample = compute_wer_and_cer_for_sample(str_pred, str_label)
                else:
                    # an empty prediction counts as a complete miss
                    cer_sample, wer_sample = 100, 100

                valid_running_cer += cer_sample
                valid_running_wer += wer_sample

    valid_loss = valid_running_loss / num_valid_batches
    valid_cer = valid_running_cer / num_valid_samples
    valid_wer = valid_running_wer / num_valid_samples
    return valid_loss, valid_cer, valid_wer

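To make the unpacking step in validate() concrete, here is the same slicing logic on made-up data (label values and lengths below are purely illustrative):

labels_flat = [8, 5, 12, 12, 15, 23, 15, 18, 12, 4]  # two samples, concatenated
lengths = [5, 5]
offset, per_sample = 0, []
for n in lengths:
    per_sample.append(labels_flat[offset:offset + n])
    offset += n
# per_sample == [[8, 5, 12, 12, 15], [23, 15, 18, 12, 4]]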
def train_hw_recognizer(FLAGS):
    file_txt_labels = os.path.join(FLAGS.dir_dataset, "iam_lines_gt.txt")
    dir_images = os.path.join(FLAGS.dir_dataset, "img")
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

    # train only on a CUDA device (GPU)
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        print("CUDA device not found, so exiting....")
        sys.exit(0)

    # split dataset into train and validation sets
    train_x, valid_x, train_y, valid_y = split_dataset(file_txt_labels, for_train=True)
    num_train_samples = len(train_x)
    num_valid_samples = len(valid_x)
    # get dataloaders for train and validation sets
    train_loader, valid_loader = get_dataloaders_for_training(
        train_x, train_y, valid_x, valid_y,
        dir_images=dir_images, image_height=FLAGS.image_height, image_width=FLAGS.image_width,
        batch_size=FLAGS.batch_size,
    )

    # create a directory for saving the model
    dir_model = f"model_{FLAGS.which_hw_model}"
    if not os.path.isdir(dir_model):
        print(f"creating directory: {dir_model}")
        os.makedirs(dir_model)

    # save train and validation metrics in a csv file
    file_logger_train = os.path.join(dir_model, "train_metrics.csv")
    csv_writer = CSVWriter(
        file_name=file_logger_train,
        column_names=["epoch", "loss_train", "loss_valid", "cer_valid", "wer_valid"]
    )

    file_params = os.path.join(dir_model, "params.json")
    write_json_file(file_params, vars(FLAGS))

    num_classes = len(HWRecogIAMDataset.LABEL_2_CHAR) + 1
    print("task - handwriting recognition")
    print(f"model: {FLAGS.which_hw_model}")
    print(f"optimizer: {FLAGS.which_optimizer}, learning rate: {FLAGS.learning_rate:.6f}, weight decay: {FLAGS.weight_decay:.8f}")
    print(f"batch size: {FLAGS.batch_size}, image height: {FLAGS.image_height}, image width: {FLAGS.image_width}")
    print(f"num train samples: {num_train_samples}, num validation samples: {num_valid_samples}\n")

    # load the right model
    if FLAGS.which_hw_model == "crnn":
        hw_model = CRNN(num_classes, FLAGS.image_height)
    elif FLAGS.which_hw_model == "stn_crnn":
        hw_model = STN_CRNN(num_classes, FLAGS.image_height, FLAGS.image_width)
    else:
        print(f"unidentified option: {FLAGS.which_hw_model}")
        sys.exit(0)
    hw_model.to(device)

    # load the right optimizer based on user option
    if FLAGS.which_optimizer == "adam":
        optimizer = torch.optim.Adam(hw_model.parameters(), lr=FLAGS.learning_rate, weight_decay=FLAGS.weight_decay)
    elif FLAGS.which_optimizer == "adadelta":
        optimizer = torch.optim.Adadelta(hw_model.parameters(), lr=FLAGS.learning_rate, rho=0.95, eps=1e-8, weight_decay=FLAGS.weight_decay)
    else:
        print(f"unidentified option: {FLAGS.which_optimizer}")
        sys.exit(0)
    # use the CTC loss as the objective function for training
    criterion = nn.CTCLoss(reduction="mean", zero_infinity=True)

    # start training the model
    print(f"training of handwriting recognition model {FLAGS.which_hw_model} started\n")
    for epoch in range(1, FLAGS.num_epochs+1):
        time_start = time.time()
        train_loss = train(hw_model, optimizer, criterion, train_loader, device)
        valid_loss, valid_cer, valid_wer = validate(hw_model, criterion, valid_loader, device)
        time_end = time.time()
        print(f"epoch: {epoch}/{FLAGS.num_epochs}, time: {time_end-time_start:.3f} sec.")
        print(f"train loss: {train_loss:.6f}, validation loss: {valid_loss:.6f}, validation cer: {valid_cer:.4f}, validation wer: {valid_wer:.4f}\n")

        csv_writer.write_row(
            [
                epoch,
                round(train_loss, 6),
                round(valid_loss, 6),
                round(valid_cer, 4),
                round(valid_wer, 4),
            ]
        )
        torch.save(hw_model.state_dict(), os.path.join(dir_model, f"{FLAGS.which_hw_model}_H_{FLAGS.image_height}_W_{FLAGS.image_width}_E_{epoch}.pth"))
    print(f"Training of handwriting recognition model {FLAGS.which_hw_model} complete!!!!")
    # close the csv file
    csv_writer.close()
    return

def main():
    learning_rate = 1  # 3e-4 for Adam, 1 for Adadelta
    weight_decay = 0   # 3e-5 with Adam for both CRNN and STN-CRNN, 0 with Adadelta for both
    batch_size = 64
    num_epochs = 100
    image_height = 32
    image_width = 768
    which_hw_model = "crnn"
    which_optimizer = "adadelta"
    dir_dataset = "/home/abhishek/Desktop/RUG/hw_recognition/IAM-data/"

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument("--learning_rate", default=learning_rate,
        type=float, help="learning rate to use for training")
    parser.add_argument("--weight_decay", default=weight_decay,
        type=float, help="weight decay to use for training")
    parser.add_argument("--batch_size", default=batch_size,
        type=int, help="batch size to use for training")
    parser.add_argument("--num_epochs", default=num_epochs,
        type=int, help="num epochs to train the model")
    parser.add_argument("--image_height", default=image_height,
        type=int, help="image height to be used to train the model")
    parser.add_argument("--image_width", default=image_width,
        type=int, help="image width to be used to train the model")
    parser.add_argument("--dir_dataset", default=dir_dataset,
        type=str, help="full directory path to the dataset")
    parser.add_argument("--which_optimizer", default=which_optimizer,
        type=str, choices=["adadelta", "adam"], help="which optimizer to use to train")
    # only the two models handled in train_hw_recognizer are offered as choices
    parser.add_argument("--which_hw_model", default=which_hw_model,
        type=str, choices=["crnn", "stn_crnn"], help="which model to train")

    FLAGS, unparsed = parser.parse_known_args()
    train_hw_recognizer(FLAGS)
    return

if __name__ == "__main__":
    main()
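A typical training run with the defaults above would then be, e.g., `python train.py --which_hw_model crnn --which_optimizer adadelta --batch_size 64` — an illustrative invocation; `--dir_dataset` will normally need to be overridden to point at a local copy of the IAM data.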
iam_line_recognition/utils.py
ADDED
@@ -0,0 +1,103 @@
import sys
import torch
import fastwer
import numpy as np
from scipy.special import logsumexp


"""
-------------
CTC decoder
-------------
"""

NINF = -1 * float("inf")
DEFAULT_EMISSION_THRESHOLD = 0.01

def _reconstruct(labels, blank=0):
    new_labels = []
    # merge repeats of the same label
    previous = None
    for l in labels:
        if l != previous:
            new_labels.append(l)
            previous = l
    # delete blanks
    new_labels = [l for l in new_labels if l != blank]
    return new_labels
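For example, with the default blank label 0, repeats collapse first and blanks are stripped afterwards:

assert _reconstruct([1, 1, 0, 0, 2, 2, 0, 2]) == [1, 2, 2]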

def beam_search_decode(emission_log_prob, blank=0, **kwargs):
    beam_size = kwargs["beam_size"]
    emission_threshold = kwargs.get("emission_threshold", np.log(DEFAULT_EMISSION_THRESHOLD))

    length, class_count = emission_log_prob.shape

    beams = [([], 0)]  # (prefix, accumulated_log_prob)
    for t in range(length):
        new_beams = []
        for prefix, accumulated_log_prob in beams:
            for c in range(class_count):
                log_prob = emission_log_prob[t, c]
                # prune classes whose emission probability falls below the threshold
                if log_prob < emission_threshold:
                    continue
                new_prefix = prefix + [c]
                # log(p1 * p2) = log_p1 + log_p2
                new_accu_log_prob = accumulated_log_prob + log_prob
                new_beams.append((new_prefix, new_accu_log_prob))

        # keep only the beam_size best prefixes, sorted by accumulated_log_prob
        new_beams.sort(key=lambda x: x[1], reverse=True)
        beams = new_beams[:beam_size]

    # sum up beams that reconstruct to the same label sequence
    total_accu_log_prob = {}
    for prefix, accu_log_prob in beams:
        labels = tuple(_reconstruct(prefix, blank))
        # log(p1 + p2) = logsumexp([log_p1, log_p2])
        total_accu_log_prob[labels] = \
            logsumexp([accu_log_prob, total_accu_log_prob.get(labels, NINF)])

    labels_beams = [(list(labels), accu_log_prob)
                    for labels, accu_log_prob in total_accu_log_prob.items()]
    labels_beams.sort(key=lambda x: x[1], reverse=True)
    labels = labels_beams[0][0]

    return labels

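A toy call, to show the expected input layout (rows are time steps, columns are per-class probabilities, class 0 is the blank; the numbers are made up):

import numpy as np

toy_emissions = np.log(np.array([
    [0.6, 0.3, 0.1],
    [0.5, 0.4, 0.1],
    [0.2, 0.1, 0.7],
]))
print(beam_search_decode(toy_emissions, blank=0, beam_size=5))  # -> [1, 2]

Note how the summing step matters here: the single most probable path reconstructs to [2], but the paths that collapse to [1, 2] have a larger combined probability.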
def greedy_decode(emission_log_prob, blank=0):
    labels = np.argmax(emission_log_prob, axis=-1)
    labels = _reconstruct(labels, blank=blank)
    return labels

def ctc_decode(log_probs, which_ctc_decoder="beam_search", label_2_char=None, blank=0, beam_size=25):
    emission_log_probs = np.transpose(log_probs.cpu().numpy(), (1, 0, 2))
    # size of emission_log_probs: (batch, length, class)

    decoded_list = []
    for emission_log_prob in emission_log_probs:
        if which_ctc_decoder == "beam_search":
            decoded = beam_search_decode(emission_log_prob, blank=blank, beam_size=beam_size)
        elif which_ctc_decoder == "greedy":
            decoded = greedy_decode(emission_log_prob, blank=blank)
        else:
            print(f"unidentified option for which_ctc_decoder: {which_ctc_decoder}")
            sys.exit(0)

        if label_2_char:
            decoded = [label_2_char[l] for l in decoded]
        decoded_list.append(decoded)
    return decoded_list

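A usage sketch, assuming `HWRecogIAMDataset` is imported from `dataset` and `log_probs` is a time-major `(T, N, C)` log-probability tensor produced by the model:

from dataset import HWRecogIAMDataset

# decode a batch of network outputs into character lists, then join into strings
char_lists = ctc_decode(log_probs, which_ctc_decoder="greedy",
                        label_2_char=HWRecogIAMDataset.LABEL_2_CHAR)
texts = ["".join(chars) for chars in char_lists]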
"""
--------------------
Evaluation Metrics
--------------------
"""
def compute_wer_and_cer_for_batch(batch_preds, batch_gts):
    cer_batch = fastwer.score(batch_preds, batch_gts, char_level=True)
    wer_batch = fastwer.score(batch_preds, batch_gts)
    return cer_batch, wer_batch

def compute_wer_and_cer_for_sample(str_pred, str_gt):
    cer_sample = fastwer.score_sent(str_pred, str_gt, char_level=True)
    wer_sample = fastwer.score_sent(str_pred, str_gt)
    return cer_sample, wer_sample
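fastwer reports both metrics as percentages, so a quick sanity check of the sample-level helper looks like this (expected values worked out by hand):

cer, wer = compute_wer_and_cer_for_sample("helo", "hello")
# cer == 20.0  (one character edit over five reference characters)
# wer == 100.0 (the single reference word is wrong)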
iam_line_recognition/utils_unique_chars.py
ADDED
@@ -0,0 +1,43 @@
import argparse
import numpy as np

from dataset import read_IAM_label_txt_file

def list_unique_characters_in_IAM_dataset(FLAGS):
    _, all_labels = read_IAM_label_txt_file(FLAGS.file_txt_labels)

    num_labels = len(all_labels)
    print(f"num labels: {num_labels}")
    unique_chars = []

    for label in all_labels:
        unique_chars = unique_chars + list(np.unique(np.array(list(label))))

    unique_chars = sorted(unique_chars)
    unique_chars = np.array(unique_chars)
    unique_chars = np.unique(unique_chars)
    unique_chars = ''.join(unique_chars)

    # print all unique chars in the IAM dataset
    print(unique_chars)

    # print the number of unique chars in the IAM dataset
    print(f"number of unique characters: {len(unique_chars)}")
    return

def main():
    file_txt_labels = "/home/abhishek/Desktop/RUG/hw_recognition/IAM-data/iam_lines_gt.txt"

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument("--file_txt_labels", default=file_txt_labels,
        type=str, help="full path to label text file")

    FLAGS, unparsed = parser.parse_known_args()
    list_unique_characters_in_IAM_dataset(FLAGS)
    return

if __name__ == "__main__":
    main()
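The same character scan can be written without numpy; a compact equivalent (illustrative, operating on the same `all_labels` list):

unique_chars = "".join(sorted(set("".join(all_labels))))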