diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..6dbdba028ae6c6919f653a4aba1db6d1306d21bb
--- /dev/null
+++ b/app.py
@@ -0,0 +1,50 @@
+import os
+
+import streamlit as st
+from PIL import Image
+
+from inference import get_predictions
+
+
+st.title('Handwriting Recognition Demo')
+
+SAMPLE_PATH = './data/sample_images'
+# Sort for a deterministic order; os.listdir makes no ordering guarantee.
+sample_files = sorted(os.listdir(SAMPLE_PATH))
+num_samples = len(sample_files)
+
+if 'image_index' not in st.session_state:
+    st.session_state['image_index'] = 0
+
+st.write('**Select one of the available sample images:**')
+
+current_index = st.session_state['image_index']
+
+prev_col, next_col = st.columns(2)
+with prev_col:
+    prev_clicked = st.button('Previous image')
+with next_col:
+    next_clicked = st.button('Next image')
+
+# Wrap around at both ends of the sample list.
+if prev_clicked:
+    current_index = (current_index - 1) % num_samples
+if next_clicked:
+    current_index = (current_index + 1) % num_samples
+st.session_state['image_index'] = current_index
+
+sample_image = Image.open(os.path.join(SAMPLE_PATH, sample_files[current_index]))
+st.image(sample_image, caption='Chosen image')
+
+if st.button('Get prediction'):
+    prediction = get_predictions(sample_image)
+    st.markdown('**The model prediction is:**')
+    st.write(prediction)
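+
+# Possible extension (sketch, not wired up): the same prediction call works for
+# user-uploaded files, e.g.
+#
+#     uploaded = st.file_uploader('Or upload an image', type=['jpg', 'png'])
+#     if uploaded is not None:
+#         st.write(get_predictions(Image.open(uploaded)))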
diff --git a/data/sample_images/sample10_LEA b/data/sample_images/sample10_LEA
new file mode 100644
index 0000000000000000000000000000000000000000..dbcea24c58468ce3c7f8905e946fe8fe15800374
Binary files /dev/null and b/data/sample_images/sample10_LEA differ
diff --git a/data/sample_images/sample11_EVAN b/data/sample_images/sample11_EVAN
new file mode 100644
index 0000000000000000000000000000000000000000..cd719a46883945d35f9f408b751ca4fe6f944c3f
Binary files /dev/null and b/data/sample_images/sample11_EVAN differ
diff --git a/data/sample_images/sample12_HELOISE b/data/sample_images/sample12_HELOISE
new file mode 100644
index 0000000000000000000000000000000000000000..986693d6438e10e11d03df125ebe802d1ebd5277
Binary files /dev/null and b/data/sample_images/sample12_HELOISE differ
diff --git a/data/sample_images/sample13_ROCHETTE b/data/sample_images/sample13_ROCHETTE
new file mode 100644
index 0000000000000000000000000000000000000000..3aae1c77bfe7c376ef175b0c813e47c458a96dd3
Binary files /dev/null and b/data/sample_images/sample13_ROCHETTE differ
diff --git a/data/sample_images/sample14_JUSTINE b/data/sample_images/sample14_JUSTINE
new file mode 100644
index 0000000000000000000000000000000000000000..d0329a453e17e109b9f520a263dc3f40d7bf87cc
Binary files /dev/null and b/data/sample_images/sample14_JUSTINE differ
diff --git a/data/sample_images/sample15_PAQUET b/data/sample_images/sample15_PAQUET
new file mode 100644
index 0000000000000000000000000000000000000000..b328f812d6d01bac554edc6cb3bbb5e07776fd03
Binary files /dev/null and b/data/sample_images/sample15_PAQUET differ
diff --git a/data/sample_images/sample16_RADIA b/data/sample_images/sample16_RADIA
new file mode 100644
index 0000000000000000000000000000000000000000..e2348fca22b32dfd46aa8d51c97fb293d036782f
Binary files /dev/null and b/data/sample_images/sample16_RADIA differ
diff --git a/data/sample_images/sample17_LOUIS b/data/sample_images/sample17_LOUIS
new file mode 100644
index 0000000000000000000000000000000000000000..9184e506a4823beec821087dab12e6891cce94f5
Binary files /dev/null and b/data/sample_images/sample17_LOUIS differ
diff --git a/data/sample_images/sample18_WICO b/data/sample_images/sample18_WICO
new file mode 100644
index 0000000000000000000000000000000000000000..6c729ad240c7e201f1e94807af826fd36edb6159
Binary files /dev/null and b/data/sample_images/sample18_WICO differ
diff --git a/data/sample_images/sample19_AZZANO b/data/sample_images/sample19_AZZANO
new file mode 100644
index 0000000000000000000000000000000000000000..99a16170a3332d92f82615fbf822075051b9e92a
Binary files /dev/null and b/data/sample_images/sample19_AZZANO differ
diff --git a/data/sample_images/sample1_MOUSSAID b/data/sample_images/sample1_MOUSSAID
new file mode 100644
index 0000000000000000000000000000000000000000..1fe2425cb3778dc0a3cc9b357439480219975274
Binary files /dev/null and b/data/sample_images/sample1_MOUSSAID differ
diff --git a/data/sample_images/sample20_HEZZAT b/data/sample_images/sample20_HEZZAT
new file mode 100644
index 0000000000000000000000000000000000000000..2b08d4c54abd2296ffd8153d394160d539ac8249
Binary files /dev/null and b/data/sample_images/sample20_HEZZAT differ
diff --git a/data/sample_images/sample21_RAPHAEL b/data/sample_images/sample21_RAPHAEL
new file mode 100644
index 0000000000000000000000000000000000000000..12479e0342b8b698caba0f9c852675eb1aee9f44
Binary files /dev/null and b/data/sample_images/sample21_RAPHAEL differ
diff --git a/data/sample_images/sample22_SARAH b/data/sample_images/sample22_SARAH
new file mode 100644
index 0000000000000000000000000000000000000000..fe31d312d2501e12fe6df41f21a79ced7971dc36
Binary files /dev/null and b/data/sample_images/sample22_SARAH differ
diff --git a/data/sample_images/sample23_KLEIN b/data/sample_images/sample23_KLEIN
new file mode 100644
index 0000000000000000000000000000000000000000..801b956fb365c2fd5bea4c9ff3a697e999cfe510
Binary files /dev/null and b/data/sample_images/sample23_KLEIN differ
diff --git a/data/sample_images/sample24_CLEMENCE b/data/sample_images/sample24_CLEMENCE
new file mode 100644
index 0000000000000000000000000000000000000000..1be5fa59ed27aaf2986b087c0cf2d302184ed3cd
Binary files /dev/null and b/data/sample_images/sample24_CLEMENCE differ
diff --git a/data/sample_images/sample25_HADDAD b/data/sample_images/sample25_HADDAD
new file mode 100644
index 0000000000000000000000000000000000000000..00d3e56926a82befa9dce463672f7c968ddc47c3
Binary files /dev/null and b/data/sample_images/sample25_HADDAD differ
diff --git a/data/sample_images/sample26_LOUIS b/data/sample_images/sample26_LOUIS
new file mode 100644
index 0000000000000000000000000000000000000000..59ecd529a0740db65b50bdcbc285da3435b5f392
Binary files /dev/null and b/data/sample_images/sample26_LOUIS differ
diff --git a/data/sample_images/sample27_COUTO b/data/sample_images/sample27_COUTO
new file mode 100644
index 0000000000000000000000000000000000000000..8f5c66312cc4187877460135e7073d2eef74c554
Binary files /dev/null and b/data/sample_images/sample27_COUTO differ
diff --git a/data/sample_images/sample28_BARED b/data/sample_images/sample28_BARED
new file mode 100644
index 0000000000000000000000000000000000000000..b7514895580e42350a64b3d2ffb9eee70d21403d
Binary files /dev/null and b/data/sample_images/sample28_BARED differ
diff --git a/data/sample_images/sample29_LONGEIM b/data/sample_images/sample29_LONGEIM
new file mode 100644
index 0000000000000000000000000000000000000000..adb4ef3d97025f1e52f86b431af0b6561e9acf6c
Binary files /dev/null and b/data/sample_images/sample29_LONGEIM differ
diff --git a/data/sample_images/sample2_TOM b/data/sample_images/sample2_TOM
new file mode 100644
index 0000000000000000000000000000000000000000..ac29bd1b4448f0382772a3c7f74796559eea555c
Binary files /dev/null and b/data/sample_images/sample2_TOM differ
diff --git a/data/sample_images/sample30_FIDELE b/data/sample_images/sample30_FIDELE
new file mode 100644
index 0000000000000000000000000000000000000000..c8e431060a014347277b3e71ab49f2071b411163
Binary files /dev/null and b/data/sample_images/sample30_FIDELE differ
diff --git a/data/sample_images/sample31_SUISSA b/data/sample_images/sample31_SUISSA
new file mode 100644
index 0000000000000000000000000000000000000000..967decdfdb3d26b25f00976fb1b61c72889947b6
Binary files /dev/null and b/data/sample_images/sample31_SUISSA differ
diff --git a/data/sample_images/sample32_TRISTAN b/data/sample_images/sample32_TRISTAN
new file mode 100644
index 0000000000000000000000000000000000000000..33d8a11e76f67e222ab713d7edddbaf71f152090
Binary files /dev/null and b/data/sample_images/sample32_TRISTAN differ
diff --git a/data/sample_images/sample33_DEBERNARDI b/data/sample_images/sample33_DEBERNARDI
new file mode 100644
index 0000000000000000000000000000000000000000..3d5f0bc098521efa8c7f5dab2e80a2b72e592d56
Binary files /dev/null and b/data/sample_images/sample33_DEBERNARDI differ
diff --git a/data/sample_images/sample34_LOLA b/data/sample_images/sample34_LOLA
new file mode 100644
index 0000000000000000000000000000000000000000..be71560d68713af82481552112d9da89901dda35
Binary files /dev/null and b/data/sample_images/sample34_LOLA differ
diff --git a/data/sample_images/sample35_JUSTIN b/data/sample_images/sample35_JUSTIN
new file mode 100644
index 0000000000000000000000000000000000000000..b3746186b158d5a347521cde6c047e667a51d674
Binary files /dev/null and b/data/sample_images/sample35_JUSTIN differ
diff --git a/data/sample_images/sample36_ANA b/data/sample_images/sample36_ANA
new file mode 100644
index 0000000000000000000000000000000000000000..49a19a8f0225c87516f59b1df670fcaeb7424384
Binary files /dev/null and b/data/sample_images/sample36_ANA differ
diff --git a/data/sample_images/sample37_BAUDRILLART b/data/sample_images/sample37_BAUDRILLART
new file mode 100644
index 0000000000000000000000000000000000000000..d1b87dd599e88cdf5e23f6fc351b3c50f049df7f
Binary files /dev/null and b/data/sample_images/sample37_BAUDRILLART differ
diff --git a/data/sample_images/sample38_JEREMY b/data/sample_images/sample38_JEREMY
new file mode 100644
index 0000000000000000000000000000000000000000..d8da78a30cd60bf633127bfad1061a4484370240
Binary files /dev/null and b/data/sample_images/sample38_JEREMY differ
diff --git a/data/sample_images/sample39_MATMATI b/data/sample_images/sample39_MATMATI
new file mode 100644
index 0000000000000000000000000000000000000000..7bc2cd5984d68883021a124b877813d4d7b154d6
Binary files /dev/null and b/data/sample_images/sample39_MATMATI differ
diff --git a/data/sample_images/sample3_GARCIA b/data/sample_images/sample3_GARCIA
new file mode 100644
index 0000000000000000000000000000000000000000..50148767b44e4a83595e3e0bf79f22472ff45c54
Binary files /dev/null and b/data/sample_images/sample3_GARCIA differ
diff --git a/data/sample_images/sample40_SASHA b/data/sample_images/sample40_SASHA
new file mode 100644
index 0000000000000000000000000000000000000000..2e83c500795f4870cd5b1086ef7b81aaf846d010
Binary files /dev/null and b/data/sample_images/sample40_SASHA differ
diff --git a/data/sample_images/sample41_THIBAULT b/data/sample_images/sample41_THIBAULT
new file mode 100644
index 0000000000000000000000000000000000000000..a86b491ff1ccecde1bc703743416bdfab5d3f939
Binary files /dev/null and b/data/sample_images/sample41_THIBAULT differ
diff --git a/data/sample_images/sample42_SOUNI b/data/sample_images/sample42_SOUNI
new file mode 100644
index 0000000000000000000000000000000000000000..e98b342f56cf1696d64e5afba171d9c54016801f
Binary files /dev/null and b/data/sample_images/sample42_SOUNI differ
diff --git a/data/sample_images/sample43_JOUIDI b/data/sample_images/sample43_JOUIDI
new file mode 100644
index 0000000000000000000000000000000000000000..f89660e9768b29eccfea349ef86b9903558ae1b6
Binary files /dev/null and b/data/sample_images/sample43_JOUIDI differ
diff --git a/data/sample_images/sample44_GAUTIER b/data/sample_images/sample44_GAUTIER
new file mode 100644
index 0000000000000000000000000000000000000000..a9e8e50c15247d7b8e8c26088f969919b29d4ff0
Binary files /dev/null and b/data/sample_images/sample44_GAUTIER differ
diff --git a/data/sample_images/sample45_MAREZ b/data/sample_images/sample45_MAREZ
new file mode 100644
index 0000000000000000000000000000000000000000..b85f33633da94ac93c2e71cece7f05dd225378ce
Binary files /dev/null and b/data/sample_images/sample45_MAREZ differ
diff --git a/data/sample_images/sample46_BRESCIANI b/data/sample_images/sample46_BRESCIANI
new file mode 100644
index 0000000000000000000000000000000000000000..4c7e792ee913ad76042537bfa2eb4fd1a84272f4
Binary files /dev/null and b/data/sample_images/sample46_BRESCIANI differ
diff --git a/data/sample_images/sample47_CLEMENT b/data/sample_images/sample47_CLEMENT
new file mode 100644
index 0000000000000000000000000000000000000000..de24fec59f01b03b0724a9678d5bc7f3d67106a1
Binary files /dev/null and b/data/sample_images/sample47_CLEMENT differ
diff --git a/data/sample_images/sample48_DUHAMEL b/data/sample_images/sample48_DUHAMEL
new file mode 100644
index 0000000000000000000000000000000000000000..89f8d6594466f8f303799bfaaabcbe045fcee77a
Binary files /dev/null and b/data/sample_images/sample48_DUHAMEL differ
diff --git a/data/sample_images/sample49_THOMAS b/data/sample_images/sample49_THOMAS
new file mode 100644
index 0000000000000000000000000000000000000000..9091364856596af7398871de30ecd5a8748c1aa5
Binary files /dev/null and b/data/sample_images/sample49_THOMAS differ
diff --git a/data/sample_images/sample4_LAURENT b/data/sample_images/sample4_LAURENT
new file mode 100644
index 0000000000000000000000000000000000000000..594f9736bee783c04282b818a394f758d584b5a4
Binary files /dev/null and b/data/sample_images/sample4_LAURENT differ
diff --git a/data/sample_images/sample50_ABDO b/data/sample_images/sample50_ABDO
new file mode 100644
index 0000000000000000000000000000000000000000..da5a4a875648f85450935a624f853c638b033e8f
Binary files /dev/null and b/data/sample_images/sample50_ABDO differ
diff --git a/data/sample_images/sample5_PHILIPPE b/data/sample_images/sample5_PHILIPPE
new file mode 100644
index 0000000000000000000000000000000000000000..1a791672bd41d39705512763e09f23eaf5b6e884
Binary files /dev/null and b/data/sample_images/sample5_PHILIPPE differ
diff --git a/data/sample_images/sample6_ANTOINE b/data/sample_images/sample6_ANTOINE
new file mode 100644
index 0000000000000000000000000000000000000000..716a619edde870c76d3f3d19905e96da95f7edfb
Binary files /dev/null and b/data/sample_images/sample6_ANTOINE differ
diff --git a/data/sample_images/sample7_MARCHAND b/data/sample_images/sample7_MARCHAND
new file mode 100644
index 0000000000000000000000000000000000000000..e70fc75671fe99203a527ff2a155415908b405f5
Binary files /dev/null and b/data/sample_images/sample7_MARCHAND differ
diff --git a/data/sample_images/sample8_SOUNDOUS b/data/sample_images/sample8_SOUNDOUS
new file mode 100644
index 0000000000000000000000000000000000000000..d06b882134101d418eb5d73b5492f13b4034dffd
Binary files /dev/null and b/data/sample_images/sample8_SOUNDOUS differ
diff --git a/data/sample_images/sample9_NOHAMED-AMAR b/data/sample_images/sample9_NOHAMED-AMAR
new file mode 100644
index 0000000000000000000000000000000000000000..858a04c2bd7354794aaea9eb1063b58b6bcb3a87
Binary files /dev/null and b/data/sample_images/sample9_NOHAMED-AMAR differ
diff --git a/datasets.py b/datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c60d2adaeee3295b40bcb0fe5130398a11e9b56
--- /dev/null
+++ b/datasets.py
@@ -0,0 +1,127 @@
+import os
+from typing import Optional
+
+import pandas as pd
+import pytorch_lightning as pl
+import torch
+from PIL import Image
+from torch.utils.data import Dataset, DataLoader
+from torch.nn.utils.rnn import pad_sequence
+from torchvision.transforms import Compose, Resize, ToTensor, Grayscale, RandomRotation, RandomApply, \
+    GaussianBlur, CenterCrop
+
+
+class KaggleHandwrittenNames(Dataset):
+    def __init__(self, data, transforms, label_to_index, img_path):
+        self.data = data
+        self.transforms = transforms
+        self.img_path = img_path
+        self.label_to_index = label_to_index
+
+    def __len__(self):
+        return self.data.shape[0]
+
+    def __getitem__(self, index):
+        row = self.data.iloc[index]
+        file_name = row['FILENAME']
+        image_label = row['IDENTITY']
+        the_image = Image.open(os.path.join(self.img_path, file_name))
+        transformed_image = self.transforms(the_image)
+        target_len = len(image_label)
+        label_chars = list(image_label)
+        image_label = torch.tensor([self.label_to_index[char] for char in label_chars])
+        return {
+            'transformed_image': transformed_image,
+            'label': image_label,
+            'target_len': target_len
+        }
+
+
+class KaggleHandwritingDataModule(pl.LightningDataModule):
+    def __init__(self, train_data, val_data, hparams, label_to_index):
+        super().__init__()
+        self.train_data = train_data
+        self.val_data = val_data
+        self.train_batch_size = hparams['train_batch_size']
+        self.val_batch_size = hparams['val_batch_size']
+        self.transforms = Compose([Resize((hparams['input_height'], hparams['input_width'])), Grayscale(),
+                                   ToTensor()])
+        # Augmentations apply to the training split only: a random
+        # rotation/blur pass and a random near-full-size centre crop.
+        applier1 = RandomApply(transforms=[RandomRotation(10), GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))], p=0.5)
+        applier2 = RandomApply(transforms=[CenterCrop((hparams['input_height'] - 1, hparams['input_width'] - 2))], p=0.5)
+        self.train_transforms = Compose([applier2, Resize((hparams['input_height'], hparams['input_width'])), Grayscale(),
+                                         applier1, ToTensor()])
+        self.train_img_path = hparams['train_img_path']
+        self.val_img_path = hparams['val_img_path']
+        self.label_to_index = label_to_index
+
+    def setup(self, stage: Optional[str] = None):
+        if stage == 'fit' or stage is None:
+            self.train = KaggleHandwrittenNames(self.train_data, self.train_transforms, self.label_to_index,
+                                                self.train_img_path)
+            self.val = KaggleHandwrittenNames(self.val_data, self.transforms, self.label_to_index, self.val_img_path)
+
+    @staticmethod
+    def custom_collate(data):
+        '''
+        Collate samples whose label sequences have different lengths into one
+        padded batch.
+        '''
+        transformed_images = []
+        labels = []
+        target_lens = []
+        for d in data:
+            transformed_images.append(d['transformed_image'])
+            labels.append(d['label'])
+            target_lens.append(d['target_len'])
+        # -1 is a safe padding value: CTCLoss only reads the first
+        # target_len entries of each label row.
+        batch_labels = pad_sequence(labels, batch_first=True, padding_value=-1)
+        transformed_images = torch.stack(transformed_images)
+        target_lens = torch.tensor(target_lens)
+        return {
+            'transformed_images': transformed_images,
+            'labels': batch_labels,
+            'target_lens': target_lens
+        }
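+
+    # Illustration (not executed): for labels tensor([3, 5]) and tensor([2]),
+    # pad_sequence(..., batch_first=True, padding_value=-1) returns
+    # tensor([[3, 5], [2, -1]]), while target_lens stays (2, 1), so the -1
+    # padding is never decoded or scored.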
+
+    def train_dataloader(self):
+        # num_workers is machine-dependent; tune to the available CPU cores.
+        return DataLoader(self.train, batch_size=self.train_batch_size, shuffle=True, pin_memory=True,
+                          num_workers=8, collate_fn=KaggleHandwritingDataModule.custom_collate)
+
+    def val_dataloader(self):
+        return DataLoader(self.val, batch_size=self.val_batch_size, shuffle=False, pin_memory=True,
+                          num_workers=8, collate_fn=KaggleHandwritingDataModule.custom_collate)
+
+
+def test_kaggle_handwriting():
+    pl.seed_everything(267)
+    hparams = {
+        'train_img_path': './data/kaggle-handwriting-recognition/train_v2/train/',
+        'lr': 1e-3, 'val_img_path': './data/kaggle-handwriting-recognition/validation_v2/validation/',
+        'test_img_path': './data/kaggle-handwriting-recognition/test_v2/test/',
+        'data_path': './data/kaggle-handwriting-recognition', 'gru_input_size': 256,
+        'train_batch_size': 64, 'val_batch_size': 256, 'input_height': 36, 'input_width': 324, 'gru_hidden_size': 128,
+        'gru_num_layers': 1, 'num_classes': 28
+    }
+    label_to_index = {' ': 0, '-': 1, 'A': 2, 'B': 3, 'C': 4, 'D': 5, 'E': 6, 'F': 7, 'G': 8, 'H': 9, 'I': 10, 'J': 11,
+                      'K': 12, 'L': 13, 'M': 14, 'N': 15, 'O': 16, 'P': 17, 'Q': 18, 'R': 19, 'S': 20, 'T': 21, 'U': 22,
+                      'V': 23, 'W': 24, 'X': 25, 'Y': 26, 'Z': 27}
+
+    train_df = pd.read_csv(os.path.join(hparams['data_path'], 'train_new.csv'))
+    train_df = train_df[train_df.word_type == 'normal_word']
+    train_df = train_df.sample(frac=1).reset_index(drop=True)
+    val_df = pd.read_csv(os.path.join(hparams['data_path'], 'val_new.csv'))
+    val_df = val_df[val_df.word_type == 'normal_word']
+    val_df = val_df.sample(frac=1).reset_index(drop=True)
+    sample_module = KaggleHandwritingDataModule(train_df, val_df, hparams, label_to_index)
+    sample_module.setup()
+    sample_train_module = sample_module.train_dataloader()
+    sample_val_module = sample_module.val_dataloader()
+    sample_train_batch = next(iter(sample_train_module))
+    sample_val_batch = next(iter(sample_val_module))
+    print(sample_train_batch['transformed_images'].shape)
+    print(sample_val_batch['transformed_images'].shape)
+    print(sample_train_batch['labels'].shape)
+    print(sample_val_batch['labels'].shape)
+    print(sample_train_batch['target_lens'].shape)
+    print(sample_val_batch['target_lens'].shape)
+
+
+if __name__ == '__main__':
+    test_kaggle_handwriting()
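+
+# Expected smoke-test output (sketch; follows from the hparams above):
+# transformed_images are (64, 1, 36, 324) for train and (256, 1, 36, 324) for
+# val, labels are (batch, longest_label_in_batch), target_lens are (batch,).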
diff --git a/final-models/epoch=47-val-loss=0.190-val-exact-match=83.1511001586914-val-char-error-rate=0.042957037687301636.ckpt b/final-models/epoch=47-val-loss=0.190-val-exact-match=83.1511001586914-val-char-error-rate=0.042957037687301636.ckpt
new file mode 100644
index 0000000000000000000000000000000000000000..25873d72e3dca3e67d63b50d368cc99460a0f992
--- /dev/null
+++ b/final-models/epoch=47-val-loss=0.190-val-exact-match=83.1511001586914-val-char-error-rate=0.042957037687301636.ckpt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f4ef0fcf8cb55d80e6f10e14709164dfba2259d52cc52c2c82aedec66b3ce726
+size 21353207
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..12f2c0ddb86e1c8ee2d1c0c3654760b913a47d8f
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,52 @@
+import os
+
+import streamlit as st
+import torch
+from ctc_decoder import beam_search
+from torchvision.transforms import Compose, Resize, Grayscale, ToTensor
+
+from training_modules import HandwritingRecogTrainModule
+
+
+@st.experimental_memo
+def get_model_details():
+    path = './final-models/'
+    model_weights = 'epoch=47-val-loss=0.190-val-exact-match=83.1511001586914-val-char-error-rate=0.042957037687301636.ckpt'
+    model_path = os.path.join(path, model_weights)
+    hparams = {
+        'train_img_path': './data/kaggle-handwriting-recognition/train_v2/train/',
+        'lr': 1e-4, 'val_img_path': './data/kaggle-handwriting-recognition/validation_v2/validation/',
+        'test_img_path': './data/kaggle-handwriting-recognition/test_v2/test/',
+        'data_path': './data/kaggle-handwriting-recognition', 'gru_input_size': 256,
+        'train_batch_size': 64, 'val_batch_size': 1024, 'input_height': 36, 'input_width': 324, 'gru_hidden_size': 128,
+        'gru_num_layers': 2, 'num_classes': 28
+    }
+    label_to_index = {' ': 0, '-': 1, 'A': 2, 'B': 3, 'C': 4, 'D': 5, 'E': 6, 'F': 7, 'G': 8, 'H': 9, 'I': 10, 'J': 11,
+                      'K': 12, 'L': 13, 'M': 14, 'N': 15, 'O': 16, 'P': 17, 'Q': 18, 'R': 19, 'S': 20, 'T': 21, 'U': 22,
+                      'V': 23, 'W': 24, 'X': 25, 'Y': 26, 'Z': 27}
+    index_to_labels = {0: ' ', 1: '-', 2: 'A', 3: 'B', 4: 'C', 5: 'D', 6: 'E', 7: 'F', 8: 'G', 9: 'H', 10: 'I',
+                       11: 'J', 12: 'K', 13: 'L', 14: 'M', 15: 'N', 16: 'O', 17: 'P', 18: 'Q', 19: 'R', 20: 'S',
+                       21: 'T', 22: 'U', 23: 'V', 24: 'W', 25: 'X', 26: 'Y', 27: 'Z'}
+    transforms = Compose([Resize((hparams['input_height'], hparams['input_width'])), Grayscale(), ToTensor()])
+    return model_path, hparams, label_to_index, index_to_labels, transforms
+
+
+@st.experimental_memo
+def load_trained_model(model_path):
+    model = HandwritingRecogTrainModule.load_from_checkpoint(model_path)
+    return model
+
+
+def get_predictions(image):
+    model_path, hparams, label_to_index, index_to_labels, transforms = get_model_details()
+    transformed_image = transforms(image)
+    # The model expects a batched tensor: (1, 1, H, W).
+    transformed_image = torch.unsqueeze(transformed_image, 0)
+    model = load_trained_model(model_path)
+    model.eval()
+    with torch.no_grad():
+        out = model(transformed_image)
+    out = out.cpu().detach().numpy()
+    prediction = out[0]
+    predicted_string = beam_search(prediction, model.chars, beam_width=2)
+    return predicted_string
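+
+# Sketch (assumption, not used by the app): ctc_decoder also ships a greedy
+# decoder that is cheaper than beam search when speed matters more than
+# accuracy:
+#
+#     from ctc_decoder import best_path
+#     predicted_string = best_path(prediction, model.chars)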
diff --git a/modelling.py b/modelling.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f83adf15632bcb1c2aa62c99aee4143c523d3c5
--- /dev/null
+++ b/modelling.py
@@ -0,0 +1,105 @@
+import os
+
+import pandas as pd
+import torch.nn as nn
+import torch.nn.functional as F
+import pytorch_lightning as pl
+
+from datasets import KaggleHandwritingDataModule
+
+
+class PrintLayer(nn.Module):
+    '''Debugging helper: prints the shape of whatever passes through it.'''
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        print(x.shape)
+        return x
+
+
+class HandwritingRecognitionCNN(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.image_feature_extractor = nn.Sequential(
+            nn.Conv2d(1, 32, stride=(1, 2), kernel_size=3, bias=False),
+            nn.BatchNorm2d(32),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(32, 64, stride=2, kernel_size=3, bias=False),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(64, 128, stride=2, kernel_size=3, bias=False),
+            nn.BatchNorm2d(128),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(128, 256, stride=(1, 2), kernel_size=3, bias=False),
+            nn.BatchNorm2d(256),
+            nn.ReLU(inplace=True),
+        )
+
+    def forward(self, x):
+        return self.image_feature_extractor(x)
+
+
+class HandwritingRecognitionGRU(nn.Module):
+    def __init__(self, input_dim, hidden_size, num_layers, num_classes):
+        super().__init__()
+        self.gru_layer = nn.GRU(input_dim, hidden_size, num_layers, batch_first=True, bidirectional=True, dropout=0.3)
+        self.output = nn.Linear(hidden_size * 2, num_classes)
+
+    def forward(self, x):
+        recurrent_output, _ = self.gru_layer(x)
+        out = self.output(recurrent_output)
+        out = F.log_softmax(out, dim=2)
+        return out
+
+
+class HandwritingRecognition(nn.Module):
+    def __init__(self, gru_input_size, gru_hidden, gru_layers, num_classes):
+        super().__init__()
+        self.cnn_feature_extractor = HandwritingRecognitionCNN()
+        # num_classes + 1 reserves an extra output for the CTC blank token.
+        self.gru = HandwritingRecognitionGRU(gru_input_size, gru_hidden, gru_layers, num_classes + 1)
+        self.linear1 = nn.Linear(1280, 512)
+        self.activation = nn.ReLU(inplace=True)
+        # Note: this dropout is defined but never applied in forward(); it is
+        # kept as-is to match the code the shipped checkpoint was trained with.
+        self.dropout = nn.Dropout(p=0.4)
+        self.linear2 = nn.Linear(512, 256)
+
+    def forward(self, x):
+        out = self.cnn_feature_extractor(x)
+        # The CNN output is (batch, channels, height, width); fold channels and
+        # height together so width becomes the CTC time dimension.
+        batch_size, channels, height, width = out.size()
+        out = out.view(batch_size, -1, width)
+        out = out.permute(0, 2, 1)
+        out = self.linear1(out)
+        # As trained: no nonlinearity between linear1 and linear2.
+        out = self.activation(self.linear2(out))
+        out = self.gru(out)
+        out = out.permute(1, 0, 2)
+        return out
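+
+
+# Shape walkthrough (sketch; assumes the 36x324 grayscale inputs used in this
+# repo): (B, 1, 36, 324) -> CNN -> (B, 256, 5, 19) -> view/permute ->
+# (B, 19, 1280) -> linear1, linear2 -> (B, 19, 256) -> GRU -> (B, 19, 29) ->
+# final permute -> (19, B, 29), i.e. the (T, N, C) layout nn.CTCLoss expects.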
+
+
+def test_modelling():
+    pl.seed_everything(6579)
+    hparams = {
+        'train_img_path': './data/kaggle-handwriting-recognition/train_v2/train/',
+        'lr': 1e-3, 'val_img_path': './data/kaggle-handwriting-recognition/validation_v2/validation/',
+        'test_img_path': './data/kaggle-handwriting-recognition/test_v2/test/',
+        'data_path': './data/kaggle-handwriting-recognition', 'gru_input_size': 256,
+        'train_batch_size': 64, 'val_batch_size': 256, 'input_height': 36, 'input_width': 324, 'gru_hidden_size': 128,
+        'gru_num_layers': 1, 'num_classes': 28
+    }
+    label_to_index = {' ': 0, '-': 1, 'A': 2, 'B': 3, 'C': 4, 'D': 5, 'E': 6, 'F': 7, 'G': 8, 'H': 9, 'I': 10, 'J': 11,
+                      'K': 12, 'L': 13, 'M': 14, 'N': 15, 'O': 16, 'P': 17, 'Q': 18, 'R': 19, 'S': 20, 'T': 21, 'U': 22,
+                      'V': 23, 'W': 24, 'X': 25, 'Y': 26, 'Z': 27}
+
+    train_df = pd.read_csv(os.path.join(hparams['data_path'], 'train_new.csv'))
+    train_df = train_df[train_df.word_type == 'normal_word']
+    train_df = train_df.sample(frac=1).reset_index(drop=True)
+    val_df = pd.read_csv(os.path.join(hparams['data_path'], 'val_new.csv'))
+    val_df = val_df[val_df.word_type == 'normal_word']
+    val_df = val_df.sample(frac=1).reset_index(drop=True)
+    sample_module = KaggleHandwritingDataModule(train_df, val_df, hparams, label_to_index)
+    sample_module.setup()
+    sample_train_module = sample_module.train_dataloader()
+    sample_train_batch = next(iter(sample_train_module))
+    model = HandwritingRecognition(hparams['gru_input_size'], hparams['gru_hidden_size'],
+                                   hparams['gru_num_layers'], hparams['num_classes'])
+    output = model(sample_train_batch['transformed_images'])
+    print("the output shape:", output.shape)
+
+
+if __name__ == '__main__':
+    test_modelling()
diff --git a/training.py b/training.py
new file mode 100644
index 0000000000000000000000000000000000000000..c53430b4651e91332db254ea43a66fc6cd093f75
--- /dev/null
+++ b/training.py
@@ -0,0 +1,64 @@
+import os
+
+import pandas as pd
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, ModelSummary
+from pytorch_lightning.loggers import WandbLogger
+
+from datasets import KaggleHandwritingDataModule
+from training_modules import HandwritingRecogTrainModule
+
+
+def get_data(path):
+    train_df = pd.read_csv(os.path.join(path, 'train_new.csv'))
+    val_df = pd.read_csv(os.path.join(path, 'val_new.csv'))
+    # Drop rows whose transcription is a placeholder rather than a name.
+    train_df = train_df[train_df.IDENTITY != 'UNREADABLE']
+    val_df = val_df[val_df.IDENTITY != 'UNREADABLE']
+    train_df = train_df[train_df.IDENTITY != 'EMPTY']
+    val_df = val_df[val_df.IDENTITY != 'EMPTY']
+    train_df = train_df.sample(frac=1).reset_index(drop=True)
+    return train_df, val_df
+
+
+def train_model(train_module, data_module):
+    checkpoint_callback = ModelCheckpoint(filename='{epoch}-{val-loss:.3f}-{val-exact-match}-{val-char-error-rate}',
+                                          save_top_k=1, monitor='val-char-error-rate', mode='min', save_last=True)
+    wandb_logger = WandbLogger(project="handwriting_recognition_kaggle", save_dir='./lightning_logs',
+                               name='CNNR_run_new_version')
+    early_stopping = EarlyStopping(monitor="val-char-error-rate", patience=10, verbose=False, mode="min")
+    model_summary = ModelSummary(max_depth=-1)
+    trainer = pl.Trainer(accelerator='gpu', fast_dev_run=False, max_epochs=100,
+                         callbacks=[checkpoint_callback, early_stopping, model_summary], logger=wandb_logger,
+                         precision=16)
+    trainer.fit(train_module, data_module)
+
+
+def run_training():
+    pl.seed_everything(15798)
+    hparams = {
+        'train_img_path': './data/kaggle-handwriting-recognition/train_v2/train/',
+        'lr': 1e-4, 'val_img_path': './data/kaggle-handwriting-recognition/validation_v2/validation/',
+        'test_img_path': './data/kaggle-handwriting-recognition/test_v2/test/',
+        'data_path': './data/kaggle-handwriting-recognition', 'gru_input_size': 256,
+        'train_batch_size': 64, 'val_batch_size': 1024, 'input_height': 36, 'input_width': 324, 'gru_hidden_size': 128,
+        'gru_num_layers': 2, 'num_classes': 28
+    }
+    label_to_index = {' ': 0, '-': 1, 'A': 2, 'B': 3, 'C': 4, 'D': 5, 'E': 6, 'F': 7, 'G': 8, 'H': 9, 'I': 10, 'J': 11,
+                      'K': 12, 'L': 13, 'M': 14, 'N': 15, 'O': 16, 'P': 17, 'Q': 18, 'R': 19, 'S': 20, 'T': 21, 'U': 22,
+                      'V': 23, 'W': 24, 'X': 25, 'Y': 26, 'Z': 27}
+    index_to_labels = {0: ' ', 1: '-', 2: 'A', 3: 'B', 4: 'C', 5: 'D', 6: 'E', 7: 'F', 8: 'G', 9: 'H', 10: 'I',
+                       11: 'J', 12: 'K', 13: 'L', 14: 'M', 15: 'N', 16: 'O', 17: 'P', 18: 'Q', 19: 'R', 20: 'S',
+                       21: 'T', 22: 'U', 23: 'V', 24: 'W', 25: 'X', 26: 'Y', 27: 'Z'}
+    train_df, val_df = get_data(hparams['data_path'])
+    data_module = KaggleHandwritingDataModule(train_df, val_df, hparams, label_to_index)
+    train_module = HandwritingRecogTrainModule(hparams, index_to_labels=index_to_labels, label_to_index=label_to_index)
+    train_model(train_module, data_module)
+
+
+if __name__ == '__main__':
+    run_training()
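+
+# To resume an interrupted run (sketch; the checkpoint path is illustrative),
+# Trainer.fit accepts a checkpoint path:
+#
+#     trainer.fit(train_module, data_module, ckpt_path='./lightning_logs/last.ckpt')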
diff --git a/training_modules.py b/training_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..b329656d5ed3c3e352986260ca6e98388a41f106
--- /dev/null
+++ b/training_modules.py
@@ -0,0 +1,187 @@
+import os
+
+import torch
+import torch.nn as nn
+import pytorch_lightning as pl
+from PIL import Image
+from ctc_decoder import best_path, beam_search
+from torchmetrics import CharErrorRate
+from torchvision.transforms import Compose, Resize, Grayscale, ToTensor
+from torchvision.utils import make_grid
+
+from modelling import HandwritingRecognition
+
+
+class HandwritingRecogTrainModule(pl.LightningModule):
+    def __init__(self, hparams, index_to_labels, label_to_index):
+        super().__init__()
+        # save_hyperparameters stores every argument of this signature, so the
+        # dictionaries are recoverable from a checkpoint via self.hparams.
+        self.save_hyperparameters()
+        self.chars = ' -ABCDEFGHIJKLMNOPQRSTUVWXYZ'
+        self.model = HandwritingRecognition(self.hparams['hparams']['gru_input_size'],
+                                            self.hparams['hparams']['gru_hidden_size'],
+                                            self.hparams['hparams']['gru_num_layers'],
+                                            self.hparams['hparams']['num_classes'])
+        # Index 28, one past the 28 characters above, is the CTC blank.
+        self.criterion = nn.CTCLoss(blank=28, zero_infinity=True, reduction='mean')
+        self.transforms = Compose([Resize((self.hparams['hparams']['input_height'],
+                                           self.hparams['hparams']['input_width'])), Grayscale(), ToTensor()])
+        self.char_metric = CharErrorRate()
+
+    def forward(self, the_image):
+        out = self.model(the_image)
+        out = out.permute(1, 0, 2)
+        # Convert log-probabilities to probabilities for the CTC decoders.
+        out = torch.exp(out)
+        return out
+
+    def intermediate_operation(self, batch):
+        transformed_images = batch['transformed_images']
+        labels = batch['labels']
+        target_lens = batch['target_lens']
+
+        output = self.model(transformed_images)
+
+        N = output.size(1)
+        input_length = output.size(0)
+        input_lengths = torch.full(size=(N,), fill_value=input_length, dtype=torch.int32)
+
+        loss = self.criterion(output, labels, input_lengths, target_lens)
+        return loss, output
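+
+    # CTC bookkeeping (sketch, assuming the default 36x324 inputs): `output`
+    # is (T=19, N, C=29) log-probabilities; input_lengths repeats T for every
+    # sample because the CNN maps all images to the same number of timesteps,
+    # while target_lens carries the true length of each label sequence.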
+
+    def _decode_batch(self, preds, batch):
+        '''
+        Greedy-decode a batch of per-timestep probabilities and compare the
+        strings against the ground truth. Shared by the train and val steps.
+        '''
+        ground_truth = batch['labels'].cpu().detach().numpy()
+        target_lens = batch['target_lens'].cpu().detach().numpy()
+        preds = preds.cpu().detach().numpy()
+        actual_predictions = [best_path(pred, self.chars) for pred in preds]
+        actual_ground_truths = []
+        exact_matches = 0
+        for i, predicted_string in enumerate(actual_predictions):
+            ground_truth_sample = ground_truth[i][0:target_lens[i]]
+            ground_truth_string = ''.join(self.hparams.index_to_labels[index] for index in ground_truth_sample)
+            actual_ground_truths.append(ground_truth_string)
+            if predicted_string == ground_truth_string:
+                exact_matches += 1
+        exact_match_percentage = (exact_matches / len(actual_predictions)) * 100
+        return actual_predictions, actual_ground_truths, exact_match_percentage
+
+    def training_step(self, batch, batch_idx):
+        loss, preds = self.intermediate_operation(batch)
+        with torch.inference_mode():
+            preds = torch.exp(preds.permute(1, 0, 2))
+            actual_predictions, actual_ground_truths, exact_match_percentage = self._decode_batch(preds, batch)
+            char_error_rate = self.char_metric(actual_predictions, actual_ground_truths)
+        self.log_dict({'train-loss': loss, 'train-exact-match': exact_match_percentage,
+                       'train-char_error_rate': char_error_rate}, prog_bar=True, on_epoch=True, on_step=False)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        loss, preds = self.intermediate_operation(batch)
+        preds = torch.exp(preds.permute(1, 0, 2))
+        actual_predictions, actual_ground_truths, exact_match_percentage = self._decode_batch(preds, batch)
+        char_error_rate = self.char_metric(actual_predictions, actual_ground_truths)
+        if batch_idx == 0:
+            # Log a small grid of images with their decoded strings once per
+            # validation epoch.
+            small_batch = batch['transformed_images'][0:16]
+            captions = actual_predictions[0:16]
+            sampled_img_grid = make_grid(small_batch)
+            self.logger.log_image('Sample_Images', [sampled_img_grid], caption=[str(captions)])
+
+        self.log_dict({'val-loss': loss, 'val-exact-match': exact_match_percentage,
+                       'val-char-error-rate': char_error_rate}, prog_bar=False, on_epoch=True, on_step=False)
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=self.hparams['hparams']['lr'])
+
+
+def convert_to_torchscript(trained_path):
+    # Load the trained weights before scripting; previously this scripted a
+    # freshly initialised model and silently ignored trained_path.
+    model = HandwritingRecogTrainModule.load_from_checkpoint(trained_path)
+    script = model.to_torchscript()
+    torch.jit.save(script, './final-models/torchscript-model/handwritten-name_new.pt')
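+
+# Note (sketch): to_torchscript() defaults to method='script'; if scripting
+# ever fails on unsupported control flow, tracing is an alternative:
+#
+#     script = model.to_torchscript(method='trace',
+#                                   example_inputs=torch.randn(1, 1, 36, 324))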
+
+
+def test_model():
+    pl.seed_everything(2564)
+    train_img_path = './data/kaggle-handwriting-recognition/train_v2/train/'
+    model = HandwritingRecogTrainModule.load_from_checkpoint(
+        './lightning_logs/CNNR_run_new_version/108xqa9y/checkpoints/'
+        'epoch=21-val-loss=0.206-val-exact-match=81.46109771728516-val-char-error-rate=0.04727236181497574.ckpt')
+    input_image = Image.open(os.path.join(train_img_path, 'TRAIN_96628.jpg'))
+    # The module expects a batched tensor, not a PIL image, so apply the
+    # module's own transforms and add a batch dimension first.
+    transformed_image = torch.unsqueeze(model.transforms(input_image), 0)
+    model.eval()
+    with torch.no_grad():
+        output = model(transformed_image)
+    print(output)
+
+
+def test_inference():
+    transforms = Compose([Resize((36, 324)), Grayscale(), ToTensor()])
+    input_image = Image.open(os.path.join('./data/kaggle-handwriting-recognition/train_v2/train/', 'TRAIN_96628.jpg'))
+    transformed_image = transforms(input_image)
+    # Add the batch dimension the scripted model expects.
+    transformed_image = torch.unsqueeze(transformed_image, 0)
+    script_path = './final-models/torchscript-model/handwritten-name_new.pt'
+    scripted_module = torch.jit.load(script_path)
+    out = scripted_module(transformed_image)
+    print("The final out shape:", out.shape)
+    print("The final out is:", out)
+    out = out.cpu().detach().numpy()
+    chars = ' -ABCDEFGHIJKLMNOPQRSTUVWXYZ'
+    for sample in out:
+        predicted_string = beam_search(sample, chars, beam_width=2)
+        print(predicted_string)
+
+
+def test_convert_to_torchscript():
+    path = './lightning_logs/CNNR_run_new_version/108xqa9y/checkpoints/'
+    model_weights = 'epoch=21-val-loss=0.206-val-exact-match=81.46109771728516-val-char-error-rate=0.04727236181497574.ckpt'
+    trained_path = os.path.join(path, model_weights)
+    convert_to_torchscript(trained_path)
+
+
+if __name__ == '__main__':
+    test_inference()
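+
+# Typical manual flow (sketch): run test_convert_to_torchscript() once to
+# export the TorchScript model, then test_inference() to decode a sample image
+# with it; test_model() exercises the raw Lightning checkpoint instead.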