import os |
import torch |
import evaluate |
import numpy as np |
import pandas as pd |
import glob as glob |
import torch.optim as optim |
import matplotlib.pyplot as plt |
import torchvision.transforms as transforms |
import subprocess |
from flask import Flask, request, jsonify |
from PIL import Image |
from zipfile import ZipFile |
from tqdm.notebook import tqdm |
from dataclasses import dataclass |
from torch.utils.data import Dataset |
from urllib.request import urlretrieve |
from transformers import ( |
VisionEncoderDecoderModel, |
TrOCRProcessor, |
Seq2SeqTrainer, |
Seq2SeqTrainingArguments, |
default_data_collator |
) |
from roboflow import Roboflow |
# Download the Roboflow project export in "folder" format.
# NOTE(review): this API key is committed in plain text; rotate it and supply
# it via the ROBOFLOW_API_KEY environment variable. The literal remains only
# as a fallback so existing setups keep working.
rf = Roboflow(api_key=os.environ.get("ROBOFLOW_API_KEY", "kGIFR6wPmDow2dHnoXoi"))
project = rf.workspace("capstone-design-oyzc3").project("dataset-train-test")
dataset = project.version(1).download("folder")
# Fetch a zip from Google Drive and unpack it into the working directory
# (presumably the train.txt / test.txt label files read later — confirm).
#  - --no-check-certificate is deliberately not used: it accepts any TLS
#    certificate, defeating HTTPS verification.
#  - `unzip -o` overwrites files from a previous run instead of stopping at an
#    interactive "replace?" prompt.
#  - check=True aborts right away if either command fails, rather than failing
#    later with a confusing missing-file error.
subprocess.run(
    [
        'wget',
        'https://docs.google.com/uc?export=download&id=12reT7rxiRqTERYqeKYx7WGz5deMXjnEo',
        '-O',
        'filetxt',
    ],
    check=True,
)
subprocess.run(['unzip', '-o', 'filetxt'], check=True)
def seed_everything(seed_value):
    """Seed the random number generators used by this script.

    Args:
        seed_value: Integer seed applied to Python's ``random`` module, NumPy,
            and PyTorch (CPU and all CUDA devices).
    """
    import random  # local import keeps this block self-contained

    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    # Give up cuDNN autotuning speed in exchange for repeatable results.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# Pin every RNG seed before anything random happens.
seed_everything(42)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def download_and_unzip(url, save_path):
    """Download a ZIP archive and extract it next to the saved file.

    Args:
        url: Location of the archive (any URL ``urlretrieve`` accepts).
        save_path: Local path the archive is written to; its contents are
            extracted into the same directory.

    If the download is not a valid ZIP archive, the problem is reported and
    the file is deleted. This keeps an ``os.path.exists(save_path)`` guard
    from treating it as a good download on the next run. Other errors
    (network, disk) propagate to the caller.
    """
    from zipfile import BadZipFile

    # flush=True: with end="" the message would otherwise sit in the stdout
    # buffer until the (possibly long) download had finished.
    print("Downloading and extracting assets....", end="", flush=True)
    urlretrieve(url, save_path)
    try:
        with ZipFile(save_path) as z:
            z.extractall(os.path.dirname(save_path))
    except BadZipFile as e:
        print("\nInvalid file.", e)
        os.remove(save_path)
    else:
        print("Done")
# Roboflow export archive (the link embeds an access key). It is downloaded
# into the working directory only once; later runs reuse the zip on disk.
URL = r"https://app.roboflow.com/ds/TZnI5u5spH?key=krcK5FWtuB"
asset_zip_path = os.path.join(os.getcwd(), "capstone-design-oyzc3.zip")
_archive_present = os.path.exists(asset_zip_path)
if not _archive_present:
    download_and_unzip(URL, asset_zip_path)
@dataclass(frozen=True)
class TrainingConfig:
    """Immutable hyper-parameters for the fine-tuning run."""

    # Per-device batch size.
    BATCH_SIZE: int = 25
    # Number of passes over the training set.
    EPOCHS: int = 20
    # Step size for the AdamW optimizer.
    LEARNING_RATE: float = 5e-5
@dataclass(frozen=True)
class DatasetConfig:
    """Dataset location settings.

    ``DATA_ROOT`` is the directory expected to contain the ``train/train`` and
    ``test/test`` image folders that the rest of the script reads.
    """

    # Defaults to the working directory, which is where download_and_unzip
    # extracts the Roboflow archive (asset_zip_path is built from os.getcwd()).
    # NOTE(review): if the images come from the project.version(1).download
    # export instead, point this at that download's directory — confirm.
    DATA_ROOT: str = os.getcwd()
@dataclass(frozen=True)
class ModelConfig:
    """Which pretrained TrOCR checkpoint to start fine-tuning from."""

    # Hugging Face Hub identifier passed to from_pretrained().
    MODEL_NAME: str = "microsoft/trocr-small-printed"
def visualize(dataset_path, num_samples=15):
    """Plot a grid of training images, each titled with its file-name stem.

    Args:
        dataset_path: Dataset root; images are listed from
            ``<dataset_path>/train/train``.
        num_samples: Maximum number of images to plot (default 15). Images are
            laid out five per row; fewer are plotted if the folder holds fewer
            files.
    """
    image_dir = f"{dataset_path}/train/train"
    # List the folder once rather than on every loop iteration.
    all_images = os.listdir(image_dir)
    # Cap at the number of files actually present; indexing past the end of
    # the listing would raise IndexError on small folders.
    n_shown = max(0, min(num_samples, len(all_images)))
    n_cols = 5
    n_rows = max(1, -(-n_shown // n_cols))  # ceiling division
    # One inch of height per row: (15, 3) for the default 3 x 5 grid.
    plt.figure(figsize=(15, n_rows))
    for i, name in enumerate(all_images[:n_shown]):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(plt.imread(f"{image_dir}/{name}"))
        plt.axis('off')
        plt.title(name.split('.')[0])
    plt.show()
visualize(DatasetConfig.DATA_ROOT)


def _read_labels(label_path):
    """Read a whitespace-separated ``<file_name> <text>`` label file.

    Each non-blank line is split at its first run of whitespace. The first
    field is the image file name; everything after it is the transcription,
    which may itself contain spaces. All values are kept as ``str``.

    Args:
        label_path: Path to the label file.

    Returns:
        DataFrame with ``file_name`` and ``text`` columns. A line with no
        transcription gets an empty ``text``.
    """
    rows = []
    # utf-8-sig also accepts files saved with a leading byte-order mark.
    with open(label_path, encoding='utf-8-sig') as fh:
        for line in fh:
            fields = line.strip().split(maxsplit=1)
            if not fields:
                continue  # blank line
            rows.append((fields[0], fields[1] if len(fields) > 1 else ''))
    return pd.DataFrame(rows, columns=['file_name', 'text'])


# Parsed by hand rather than with pd.read_fwf for two reasons. read_fwf infers
# fixed column widths, which can merge the two columns when file names differ
# in length. It also converts all-digit names or transcriptions to integers,
# which the tokenizer and path building cannot use.
train_df = _read_labels('train.txt')
test_df = _read_labels('test.txt')
# Photometric augmentation for training images.
_augmentations = [
    # Random brightness / hue shift.
    transforms.ColorJitter(brightness=0.5, hue=0.3),
    # Gaussian blur whose sigma is drawn from the given range.
    transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
]
train_transforms = transforms.Compose(_augmentations)
class CustomOCRDataset(Dataset):
    """Image/transcription pairs for TrOCR fine-tuning.

    Each row of ``df`` names an image file (``file_name``, relative to
    ``root_dir``) and its transcription (``text``). Items are returned as a
    dict with ``pixel_values`` and ``labels``, the shape expected by
    ``default_data_collator``.

    Args:
        root_dir: Directory holding the images.
        df: DataFrame with ``file_name`` and ``text`` columns.
        processor: TrOCR processor providing the image processor and tokenizer.
        max_target_length: Length every label sequence is padded/truncated to.
        transform: Callable applied to the PIL image before processing.
            Defaults to ``train_transforms``; pass ``None`` to disable
            augmentation (e.g. for evaluation data).
    """

    def __init__(self, root_dir, df, processor, max_target_length=128,
                 transform=train_transforms):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Positional lookup: correct even if df's index is not 0..n-1.
        row = self.df.iloc[idx]
        file_name = row['file_name']
        text = row['text']
        # os.path.join works whether or not root_dir ends with a separator.
        # The context manager closes the file once the RGB copy exists.
        with Image.open(os.path.join(self.root_dir, file_name)) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        pixel_values = self.processor(image, return_tensors='pt').pixel_values
        # truncation=True keeps every label exactly max_target_length long, so
        # default_data_collator can stack them. Without it, texts that tokenize
        # past max_target_length yield longer lists.
        labels = self.processor.tokenizer(
            text,
            padding='max_length',
            max_length=self.max_target_length,
            truncation=True,
        ).input_ids
        # Replace padding with -100, the index PyTorch's cross-entropy loss
        # ignores by default.
        pad_id = self.processor.tokenizer.pad_token_id
        labels = [tok if tok != pad_id else -100 for tok in labels]
        return {
            "pixel_values": pixel_values.squeeze(),
            "labels": torch.tensor(labels),
        }


processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
train_dataset = CustomOCRDataset(
    root_dir=os.path.join(DatasetConfig.DATA_ROOT, 'train/train/'),
    df=train_df,
    processor=processor,
)
# No augmentation on the evaluation split: random jitter/blur there would make
# the reported CER noisy and non-reproducible.
valid_dataset = CustomOCRDataset(
    root_dir=os.path.join(DatasetConfig.DATA_ROOT, 'test/test/'),
    df=test_df,
    processor=processor,
    transform=None,
)
model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME)
model.to(device)
print(model)

# Report model size: every parameter, then only those with requires_grad.
_params = list(model.parameters())
total_params = sum(p.numel() for p in _params)
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(p.numel() for p in _params if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")

# Special-token ids taken from the processor's tokenizer.
_tok = processor.tokenizer
_cfg = model.config
_cfg.decoder_start_token_id = _tok.cls_token_id
_cfg.pad_token_id = _tok.pad_token_id
_cfg.vocab_size = _cfg.decoder.vocab_size
_cfg.eos_token_id = _tok.sep_token_id

# Beam-search decoding settings, applied in the order listed.
# NOTE(review): newer transformers releases read these from
# model.generation_config rather than model.config — confirm they take effect
# with the installed version.
for _key, _value in {
    'max_length': 64,
    'early_stopping': True,
    'no_repeat_ngram_size': 3,
    'length_penalty': 2.0,
    'num_beams': 4,
}.items():
    setattr(_cfg, _key, _value)
# AdamW over all model weights with the configured learning rate.
# NOTE(review): this optimizer only takes effect if it is handed to the
# trainer (optimizers=...); otherwise the Trainer builds its own.
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=TrainingConfig.LEARNING_RATE,
    weight_decay=5e-4,
)
cer_metric = evaluate.load('cer')


def compute_cer(pred):
    """Compute the character error rate for Seq2SeqTrainer evaluation.

    Args:
        pred: Evaluation output whose ``predictions`` are generated token ids
            (``predict_with_generate=True``) and whose ``label_ids`` are the
            reference ids; either may contain -100 padding.

    Returns:
        ``{"cer": <float>}``.
    """
    pad_id = processor.tokenizer.pad_token_id
    # -100 is not a valid token id and cannot be decoded. Labels carry it from
    # the dataset, and the trainer can also use it to pad generated sequences
    # of different lengths. Map it to the pad token, which skip_special_tokens
    # then drops. np.where returns new arrays instead of mutating the
    # trainer's buffers in place.
    pred_ids = np.where(pred.predictions != -100, pred.predictions, pad_id)
    labels_ids = np.where(pred.label_ids != -100, pred.label_ids, pad_id)
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}
training_args = Seq2SeqTrainingArguments(
    output_dir='seq2seq_model_printed/',
    num_train_epochs=TrainingConfig.EPOCHS,
    per_device_train_batch_size=TrainingConfig.BATCH_SIZE,
    per_device_eval_batch_size=TrainingConfig.BATCH_SIZE,
    # Evaluate by generating sequences, so compute_metrics receives token ids.
    predict_with_generate=True,
    # Evaluate, log and checkpoint once per epoch; keep at most 5 checkpoints.
    # NOTE(review): recent transformers releases renamed evaluation_strategy
    # to eval_strategy — confirm against the installed version.
    evaluation_strategy='epoch',
    logging_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=5,
    report_to='tensorboard',
)
# Hand the AdamW optimizer built earlier to the trainer. Without optimizers=,
# the Trainer creates its own optimizer from training_args, which sets no
# weight_decay here. The optimizer's weight_decay=0.0005, and its learning
# rate taken from TrainingConfig, would then never be used. The scheduler slot
# is None, so the Trainer still builds its default scheduler around this
# optimizer.
trainer = Seq2SeqTrainer(
    model=model,
    # The image processor is passed as `tokenizer` so it is saved with checkpoints.
    tokenizer=processor.feature_extractor,
    args=training_args,
    compute_metrics=compute_cer,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=default_data_collator,
    optimizers=(optimizer, None),
)
res = trainer.train()

# Reload the processor, then load the weights from the checkpoint saved at the
# final training step.
processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
_final_checkpoint = os.path.join('seq2seq_model_printed/', f'checkpoint-{res.global_step}')
trained_model = VisionEncoderDecoderModel.from_pretrained(_final_checkpoint).to(device)
def read_and_show(image_path):
    """Load an image from disk as an RGB PIL image.

    Despite its name, this function only reads the image; callers such as
    eval_new_data display it.

    :param image_path: String, path to the input image.
    Returns:
        image: PIL Image in RGB mode.
    """
    # The context manager closes the file handle deterministically. convert()
    # returns an independent in-memory image, so the result stays usable.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    return image
def ocr(image, processor, model):
    """Transcribe a single image with a TrOCR model.

    :param image: PIL Image.
    :param processor: Huggingface OCR processor.
    :param model: Huggingface OCR model.
    Returns:
        generated_text: the decoded text of the first (only) generated sequence.
    """
    inputs = processor(image, return_tensors='pt')
    pixel_values = inputs.pixel_values.to(device)
    output_ids = model.generate(pixel_values)
    decoded = processor.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0]
def eval_new_data(
    data_path=os.path.join(DatasetConfig.DATA_ROOT, 'test/test', '*'),
    num_samples=50
):
    """Run OCR on images matched by a glob pattern and plot each prediction.

    Args:
        data_path: Glob pattern selecting the images. The default is evaluated
            once, when the function is defined.
        num_samples: Maximum number of images to process. ``None`` or a
            negative value processes every match.
    """
    image_paths = glob.glob(data_path)
    # Slice up front so the progress bar's total matches the work actually
    # done, instead of counting every file and stopping early.
    if num_samples is not None and num_samples >= 0:
        image_paths = image_paths[:num_samples]
    for image_path in tqdm(image_paths):
        image = read_and_show(image_path)
        text = ocr(image, processor, trained_model)
        plt.figure(figsize=(7, 4))
        plt.imshow(image)
        plt.title(text)
        plt.axis('off')
        plt.show()
# Plot OCR predictions for up to 100 images from the test split.
_test_images_pattern = os.path.join(DatasetConfig.DATA_ROOT, 'test/test/', '*')
eval_new_data(data_path=_test_images_pattern, num_samples=100)