triopood committed on
Commit
381bd48
1 Parent(s): 668cf56

Create app.py

Files changed (1)
  1. app.py +285 -0
app.py ADDED
@@ -0,0 +1,285 @@
+ import os
+
+ # Fetch and extract the label text files (the original used notebook-style
+ # "!" shell commands, which are not valid in a plain .py file; os.system is
+ # used here so the script runs as ordinary Python).
+ os.system("wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=12reT7rxiRqTERYqeKYx7WGz5deMXjnEo' -O filetxt")
+ os.system("unzip filetxt")
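+
+ # Pull the train/test image dataset for this project from Roboflow in "folder" format.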
+ from roboflow import Roboflow
+ rf = Roboflow(api_key="kGIFR6wPmDow2dHnoXoi")
+ project = rf.workspace("capstone-design-oyzc3").project("dataset-train-test")
+ dataset = project.version(1).download("folder")
+
+ import os
+ import torch
+ import evaluate
+ import numpy as np
+ import pandas as pd
+ import glob as glob
+ import torch.optim as optim
+ import matplotlib.pyplot as plt
+ import torchvision.transforms as transforms
+
+ from PIL import Image
+ from zipfile import ZipFile
+ from tqdm.notebook import tqdm
+ from dataclasses import dataclass
+ from torch.utils.data import Dataset
+ from urllib.request import urlretrieve
+ from transformers import (
+     VisionEncoderDecoderModel,
+     TrOCRProcessor,
+     Seq2SeqTrainer,
+     Seq2SeqTrainingArguments,
+     default_data_collator,
+     AutoModel  # imported but unused
+ )
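+
+ # Seed NumPy and PyTorch (CPU and all CUDA devices) and make cuDNN
+ # deterministic so runs are reproducible.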
+ def seed_everything(seed_value):
+     np.random.seed(seed_value)
+     torch.manual_seed(seed_value)
+     torch.cuda.manual_seed_all(seed_value)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+ seed_everything(42)
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ def download_and_unzip(url, save_path):
+     print("Downloading and extracting assets....", end="")
+     # Download the zip file using the urllib package.
+     urlretrieve(url, save_path)
+     try:
+         # Extract the zip file using the zipfile package.
+         with ZipFile(save_path) as z:
+             # Extract ZIP file contents in the same directory.
+             z.extractall(os.path.split(save_path)[0])
+         print("Done")
+     except Exception as e:
+         print("\nInvalid file.", e)
+
+ URL = r"https://app.roboflow.com/ds/TZnI5u5spH?key=krcK5FWtuB"
+ asset_zip_path = os.path.join(os.getcwd(), "capstone-design-oyzc3.zip")
+
+ # Download only if the asset ZIP does not exist yet.
+ if not os.path.exists(asset_zip_path):
+     download_and_unzip(URL, asset_zip_path)
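+
+ # Hyperparameters grouped into frozen dataclasses: training settings,
+ # dataset root, and model name.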
+ @dataclass(frozen=True)
+ class TrainingConfig:
+     BATCH_SIZE: int = 25
+     EPOCHS: int = 20
+     LEARNING_RATE: float = 0.00005
+
+ @dataclass(frozen=True)
+ class DatasetConfig:
+     DATA_ROOT: str = 'DATASET-TRAIN-TEST-1'
+
+ @dataclass(frozen=True)
+ class ModelConfig:
+     MODEL_NAME: str = 'microsoft/trocr-small-printed'
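+
+ # Show a 3x5 grid of training samples, titled with each file name (minus extension).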
+ def visualize(dataset_path):
+     all_images = os.listdir(f"{dataset_path}/train/train")
+     plt.figure(figsize=(15, 3))
+     for i in range(15):
+         plt.subplot(3, 5, i+1)
+         image = plt.imread(f"{dataset_path}/train/train/{all_images[i]}")
+         plt.imshow(image)
+         plt.axis('off')
+         plt.title(all_images[i].split('.')[0])
+     plt.show()
+
+ visualize(DatasetConfig.DATA_ROOT)
+
+ # NOTE: the second argument to os.path.join below is absolute, so the
+ # DATA_ROOT prefix is discarded and the files are read from /content/.
+ train_df = pd.read_fwf(
+     os.path.join(DatasetConfig.DATA_ROOT, '/content/DATASET TXT/train/train.txt'), header=None
+ )
+ train_df.rename(columns={0: 'file_name', 1: 'text'}, inplace=True)
+ test_df = pd.read_fwf(
+     os.path.join(DatasetConfig.DATA_ROOT, '/content/DATASET TXT/test/test.txt'), header=None
+ )
+ test_df.rename(columns={0: 'file_name', 1: 'text'}, inplace=True)
+
+ # Augmentations: random color jitter and Gaussian blur on the training images.
+ train_transforms = transforms.Compose([
+     transforms.ColorJitter(brightness=.5, hue=.3),
+     transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
+ ])
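+
+ # Dataset of (image, text) pairs: augment each image and convert it to pixel
+ # values, tokenize the text, and mask padding positions with -100 so the
+ # loss ignores them.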
+ class CustomOCRDataset(Dataset):
+     def __init__(self, root_dir, df, processor, max_target_length=128):
+         self.root_dir = root_dir
+         self.df = df
+         self.processor = processor
+         self.max_target_length = max_target_length
+
+     def __len__(self):
+         return len(self.df)
+
+     def __getitem__(self, idx):
+         # The image file name.
+         file_name = self.df['file_name'][idx]
+         # The text (label).
+         text = self.df['text'][idx]
+         # Read the image, apply augmentations, and get the transformed pixels.
+         image = Image.open(self.root_dir + file_name).convert('RGB')
+         image = train_transforms(image)
+         pixel_values = self.processor(image, return_tensors='pt').pixel_values
+         # Pass the text through the tokenizer and get the labels,
+         # i.e. tokenized labels.
+         labels = self.processor.tokenizer(
+             text,
+             padding='max_length',
+             max_length=self.max_target_length
+         ).input_ids
+         # We are using -100 as the padding token.
+         labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
+         encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+         return encoding
+
+ processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
+ train_dataset = CustomOCRDataset(
+     root_dir=os.path.join(DatasetConfig.DATA_ROOT, 'train/train/'),
+     df=train_df,
+     processor=processor
+ )
+ valid_dataset = CustomOCRDataset(
+     root_dir=os.path.join(DatasetConfig.DATA_ROOT, 'test/test/'),
+     df=test_df,
+     processor=processor
+ )
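+
+ # Load the pre-trained TrOCR model, move it to the device, and report parameter counts.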
+ model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME)
+ model.to(device)
+ print(model)
+ # Total parameters and trainable parameters.
+ total_params = sum(p.numel() for p in model.parameters())
+ print(f"{total_params:,} total parameters.")
+ total_trainable_params = sum(
+     p.numel() for p in model.parameters() if p.requires_grad)
+ print(f"{total_trainable_params:,} training parameters.")
+
+ # Set special tokens used for creating the decoder_input_ids from the labels.
+ model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
+ model.config.pad_token_id = processor.tokenizer.pad_token_id
+ # Set the correct vocab size.
+ model.config.vocab_size = model.config.decoder.vocab_size
+ model.config.eos_token_id = processor.tokenizer.sep_token_id
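+
+ # Generation settings for evaluation and inference: beam search with 4 beams,
+ # max length 64, early stopping, and a no-repeat n-gram constraint.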
+ model.config.max_length = 64
+ model.config.early_stopping = True
+ model.config.no_repeat_ngram_size = 3
+ model.config.length_penalty = 2.0
+ model.config.num_beams = 4
+
+ # NOTE: this optimizer is defined but never passed to the Trainer below,
+ # which builds its own AdamW internally.
+ optimizer = optim.AdamW(
+     model.parameters(), lr=TrainingConfig.LEARNING_RATE, weight_decay=0.0005
+ )
+
+ cer_metric = evaluate.load('cer')
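+
+ # Character Error Rate (CER) metric: decode predictions and labels (restoring
+ # the pad token where labels are -100) and compare the strings.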
+ def compute_cer(pred):
+     labels_ids = pred.label_ids
+     pred_ids = pred.predictions
+
+     pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
+     labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
+     label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
+
+     cer = cer_metric.compute(predictions=pred_str, references=label_str)
+     return {"cer": cer}
+
+ training_args = Seq2SeqTrainingArguments(
+     predict_with_generate=True,
+     evaluation_strategy='epoch',
+     per_device_train_batch_size=TrainingConfig.BATCH_SIZE,
+     per_device_eval_batch_size=TrainingConfig.BATCH_SIZE,
+     fp16=True,
+     output_dir='seq2seq_model_printed/',
+     logging_strategy='epoch',
+     save_strategy='epoch',
+     save_total_limit=5,
+     report_to='tensorboard',
+     num_train_epochs=TrainingConfig.EPOCHS
+ )
+
+ # Initialize the trainer.
+ trainer = Seq2SeqTrainer(
+     model=model,
+     tokenizer=processor.feature_extractor,
+     args=training_args,
+     compute_metrics=compute_cer,
+     train_dataset=train_dataset,
+     eval_dataset=valid_dataset,
+     data_collator=default_data_collator
+ )
+
+ res = trainer.train()
+
+ # Reload the processor and the last saved checkpoint for inference.
+ processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
+ trained_model = VisionEncoderDecoderModel.from_pretrained(
+     'seq2seq_model_printed/checkpoint-' + str(res.global_step)
+ ).to(device)
+
+ def read_and_show(image_path):
+     """
+     :param image_path: String, path to the input image.
+
+     Returns:
+         image: PIL Image.
+     """
+     image = Image.open(image_path).convert('RGB')
+     return image
+
+ def ocr(image, processor, model):
+     """
+     :param image: PIL Image.
+     :param processor: Hugging Face OCR processor.
+     :param model: Hugging Face OCR model.
+
+     Returns:
+         generated_text: the OCR'd text string.
+     """
+     # We can directly perform OCR on cropped images.
+     pixel_values = processor(image, return_tensors='pt').pixel_values.to(device)
+     generated_ids = model.generate(pixel_values)
+     generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+     return generated_text
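+
+ # Run the trained model on up to num_samples test images, plotting each image
+ # with its predicted text as the title.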
+ def eval_new_data(
+     data_path=os.path.join(DatasetConfig.DATA_ROOT, 'test/test', '*'),
+     num_samples=50
+ ):
+     image_paths = glob.glob(data_path)
+     for i, image_path in tqdm(enumerate(image_paths), total=len(image_paths)):
+         if i == num_samples:
+             break
+         image = read_and_show(image_path)
+         text = ocr(image, processor, trained_model)
+         plt.figure(figsize=(7, 4))
+         plt.imshow(image)
+         plt.title(text)
+         plt.axis('off')
+         plt.show()
+
+ eval_new_data(
+     data_path=os.path.join(DatasetConfig.DATA_ROOT, 'test/test/', '*'),
+     num_samples=100
+ )