import os
import torch
import evaluate
import numpy as np
import pandas as pd
import glob
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import subprocess

from flask import Flask, request, jsonify
from PIL import Image
from zipfile import ZipFile
from tqdm.notebook import tqdm
from dataclasses import dataclass
from torch.utils.data import Dataset
from urllib.request import urlretrieve
from transformers import (
    VisionEncoderDecoderModel,
    TrOCRProcessor,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator
)
from roboflow import Roboflow

# Download the dataset export from Roboflow.
rf = Roboflow(api_key="kGIFR6wPmDow2dHnoXoi")
project = rf.workspace("capstone-design-oyzc3").project("dataset-train-test")
dataset = project.version(1).download("folder")

# !wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=12reT7rxiRqTERYqeKYx7WGz5deMXjnEo' -O filetxt
# !unzip filetxt

# Use subprocess to execute the wget and unzip commands outside a notebook.
subprocess.run([
    'wget', '--no-check-certificate',
    'https://docs.google.com/uc?export=download&id=12reT7rxiRqTERYqeKYx7WGz5deMXjnEo',
    '-O', 'filetxt'
])
subprocess.run(['unzip', 'filetxt'])


def seed_everything(seed_value):
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def download_and_unzip(url, save_path):
    print("Downloading and extracting assets....", end="")
    # Download the zip file using the urllib package.
    urlretrieve(url, save_path)
    try:
        # Extract the zip file using the zipfile package.
        with ZipFile(save_path) as z:
            # Extract ZIP file contents in the same directory.
            z.extractall(os.path.split(save_path)[0])
        print("Done")
    except Exception as e:
        print("\nInvalid file.", e)


URL = r"https://app.roboflow.com/ds/TZnI5u5spH?key=krcK5FWtuB"
asset_zip_path = os.path.join(os.getcwd(), "capstone-design-oyzc3.zip")

# Download if the asset ZIP does not exist.
if not os.path.exists(asset_zip_path):
    download_and_unzip(URL, asset_zip_path)


@dataclass(frozen=True)
class TrainingConfig:
    BATCH_SIZE: int = 25
    EPOCHS: int = 20
    LEARNING_RATE: float = 0.00005


@dataclass(frozen=True)
class DatasetConfig:
    DATA_ROOT: str = 'DATASET-TRAIN-TEST-1'
    # TRAIN_DATA_ROOT: str = 'filetxt/DATASET TXT/train/'
    # TEST_DATA_ROOT: str = 'filetxt/DATASET TXT/test/'


@dataclass(frozen=True)
class ModelConfig:
    MODEL_NAME: str = 'microsoft/trocr-small-printed'


def visualize(dataset_path):
    plt.figure(figsize=(15, 3))
    for i in range(15):
        plt.subplot(3, 5, i + 1)
        all_images = os.listdir(f"{dataset_path}/train/train")
        image = plt.imread(f"{dataset_path}/train/train/{all_images[i]}")
        plt.imshow(image)
        plt.axis('off')
        plt.title(all_images[i].split('.')[0])
    plt.show()


visualize(DatasetConfig.DATA_ROOT)

# Load the image file name / label text pairs from the fixed-width text files.
train_df = pd.read_fwf('train.txt', header=None)
train_df.rename(columns={0: 'file_name', 1: 'text'}, inplace=True)
test_df = pd.read_fwf('test.txt', header=None)
test_df.rename(columns={0: 'file_name', 1: 'text'}, inplace=True)

# Augmentations.
train_transforms = transforms.Compose([
    transforms.ColorJitter(brightness=.5, hue=.3),
    transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
])


class CustomOCRDataset(Dataset):
    def __init__(self, root_dir, df, processor, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # The image file name.
        file_name = self.df['file_name'][idx]
        # The text (label).
        text = self.df['text'][idx]
        # Read the image, apply augmentations, and get the transformed pixels.
        image = Image.open(self.root_dir + file_name).convert('RGB')
        image = train_transforms(image)
        pixel_values = self.processor(image, return_tensors='pt').pixel_values
        # Pass the text through the tokenizer and get the labels,
        # i.e. tokenized labels.
        labels = self.processor.tokenizer(
            text,
            padding='max_length',
            max_length=self.max_target_length
        ).input_ids
        # We are using -100 as the padding token so that these positions are
        # ignored by the loss.
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100
                  for label in labels]
        encoding = {
            "pixel_values": pixel_values.squeeze(),
            "labels": torch.tensor(labels)
        }
        return encoding


processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)

train_dataset = CustomOCRDataset(
    root_dir=os.path.join(DatasetConfig.DATA_ROOT, 'train/train/'),
    df=train_df,
    processor=processor
)
valid_dataset = CustomOCRDataset(
    root_dir=os.path.join(DatasetConfig.DATA_ROOT, 'test/test/'),
    df=test_df,
    processor=processor
)

model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME)
model.to(device)
print(model)

# Total parameters and trainable parameters.
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")

# Set special tokens used for creating the decoder_input_ids from the labels.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# Set correct vocab size.
model.config.vocab_size = model.config.decoder.vocab_size
model.config.eos_token_id = processor.tokenizer.sep_token_id

# Generation settings.
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

# Note: this optimizer is not passed to the Trainer below, so the Trainer
# builds its own AdamW from the learning rate in Seq2SeqTrainingArguments
# unless it is supplied via `optimizers=(optimizer, None)`.
optimizer = optim.AdamW(
    model.parameters(),
    lr=TrainingConfig.LEARNING_RATE,
    weight_decay=0.0005
)

cer_metric = evaluate.load('cer')


def compute_cer(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}


training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy='epoch',
    per_device_train_batch_size=TrainingConfig.BATCH_SIZE,
    per_device_eval_batch_size=TrainingConfig.BATCH_SIZE,
    fp16=False,
    # fp16_full_eval=True,
    # fp16_backend='apex',
    output_dir='seq2seq_model_printed/',
    logging_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=5,
    report_to='tensorboard',
    num_train_epochs=TrainingConfig.EPOCHS
)
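# Optional sanity check before training -- a minimal sketch, not part of the
# original pipeline: pull one encoded sample from the custom dataset and
# confirm the tensor shapes the collator will receive. TrOCR processors
# typically resize images to 384x384, and labels are padded to
# max_target_length (128 here) with the pad positions replaced by -100.
sample = train_dataset[0]
print(sample['pixel_values'].shape)  # e.g. torch.Size([3, 384, 384])
print(sample['labels'].shape)        # e.g. torch.Size([128])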
# Initialize trainer.
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.image_processor,
    args=training_args,
    compute_metrics=compute_cer,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=default_data_collator
)

res = trainer.train()

# Reload the processor and the checkpoint saved at the final training step.
processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
trained_model = VisionEncoderDecoderModel.from_pretrained(
    'seq2seq_model_printed/checkpoint-' + str(res.global_step)
).to(device)


def read_and_show(image_path):
    """
    :param image_path: String, path to the input image.

    Returns:
        image: PIL Image.
    """
    image = Image.open(image_path).convert('RGB')
    return image


def ocr(image, processor, model):
    """
    :param image: PIL Image.
    :param processor: Huggingface OCR processor.
    :param model: Huggingface OCR model.

    Returns:
        generated_text: the OCR'd text string.
    """
    # We can directly perform OCR on cropped images.
    pixel_values = processor(image, return_tensors='pt').pixel_values.to(device)
    generated_ids = model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text


def eval_new_data(
    data_path=os.path.join(DatasetConfig.DATA_ROOT, 'test/test', '*'),
    num_samples=50
):
    image_paths = glob.glob(data_path)
    for i, image_path in tqdm(enumerate(image_paths), total=len(image_paths)):
        if i == num_samples:
            break
        image = read_and_show(image_path)
        text = ocr(image, processor, trained_model)
        plt.figure(figsize=(7, 4))
        plt.imshow(image)
        plt.title(text)
        plt.axis('off')
        plt.show()


eval_new_data(
    data_path=os.path.join(DatasetConfig.DATA_ROOT, 'test/test/', '*'),
    num_samples=100
)
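# The Flask imports at the top of the script are otherwise unused here. Below
# is a minimal serving sketch built on the ocr() helper above; the route name
# ('/ocr'), the multipart field name ('image'), and the port are assumptions
# for illustration, not part of the original training script.
app = Flask(__name__)


@app.route('/ocr', methods=['POST'])
def ocr_endpoint():
    # Expect an image file uploaded in the multipart form field 'image'.
    uploaded = request.files['image']
    image = Image.open(uploaded.stream).convert('RGB')
    text = ocr(image, processor, trained_model)
    return jsonify({'text': text})


# To serve the fine-tuned model, run for example:
# app.run(host='0.0.0.0', port=5000)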