|
import os |
|
import torch |
|
import evaluate |
|
import numpy as np |
|
import pandas as pd |
|
import glob as glob |
|
import torch.optim as optim |
|
import matplotlib.pyplot as plt |
|
import torchvision.transforms as transforms |
|
import subprocess |
|
|
|
from flask import Flask, request, jsonify |
|
from PIL import Image |
|
from zipfile import ZipFile |
|
from tqdm.notebook import tqdm |
|
from dataclasses import dataclass |
|
from torch.utils.data import Dataset |
|
from urllib.request import urlretrieve |
|
from transformers import ( |
|
VisionEncoderDecoderModel, |
|
TrOCRProcessor, |
|
Seq2SeqTrainer, |
|
Seq2SeqTrainingArguments, |
|
default_data_collator |
|
) |
|
from roboflow import Roboflow |
|
# SECURITY: the Roboflow API key was hard-coded here. Prefer the
# ROBOFLOW_API_KEY environment variable; the original key remains as a
# fallback so existing setups keep working — rotate the exposed key.
rf = Roboflow(api_key=os.environ.get("ROBOFLOW_API_KEY", "kGIFR6wPmDow2dHnoXoi"))

# Pull version 1 of the dataset as a plain folder export.
project = rf.workspace("capstone-design-oyzc3").project("dataset-train-test")

dataset = project.version(1).download("folder")
|
|
|
|
|
|
|
|
|
|
|
# Download the train/test annotation archive and unpack it in-place.
# check=True surfaces download/extraction failures immediately instead of
# letting the script continue without train.txt/test.txt.
# NOTE(review): --no-check-certificate disables TLS verification; kept for
# parity with the original command — confirm it is actually required.
subprocess.run(
    ['wget', '--no-check-certificate',
     'https://docs.google.com/uc?export=download&id=12reT7rxiRqTERYqeKYx7WGz5deMXjnEo',
     '-O', 'filetxt'],
    check=True,
)

subprocess.run(['unzip', 'filetxt'], check=True)
|
|
|
def seed_everything(seed_value):
    """Seed NumPy and PyTorch (CPU and all GPUs) for reproducible runs.

    Also pins cuDNN to deterministic kernels and disables its auto-tuner,
    trading some speed for repeatability.
    """
    for seeder in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed_value)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False
|
|
|
# Fix all RNG seeds up front so augmentation and weight init are repeatable.
seed_everything(42)

# Train on GPU when available; tensors below are moved via `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
def download_and_unzip(url, save_path):
    """Fetch *url* to *save_path* and extract the archive next to it.

    A download error propagates to the caller; an invalid/corrupt archive
    is reported on stdout instead of raising.
    """
    print(f"Downloading and extracting assets....", end="")

    # Retrieve the archive first; extraction happens in its own guard.
    urlretrieve(url, save_path)

    target_dir = os.path.split(save_path)[0]
    try:
        with ZipFile(save_path) as archive:
            archive.extractall(target_dir)
    except Exception as err:
        print("\nInvalid file.", err)
    else:
        print("Done")
|
|
|
# Direct Roboflow export link for the dataset archive.
URL = r"https://app.roboflow.com/ds/TZnI5u5spH?key=krcK5FWtuB"

asset_zip_path = os.path.join(os.getcwd(), "capstone-design-oyzc3.zip")

# Download and extract only once; re-runs skip when the zip is present.
if not os.path.exists(asset_zip_path):
    download_and_unzip(URL, asset_zip_path)
|
|
|
@dataclass(frozen=True)
class TrainingConfig:
    """Immutable hyper-parameters for the fine-tuning run."""
    BATCH_SIZE: int = 25         # per-device batch size (train and eval)
    EPOCHS: int = 20             # total training epochs
    LEARNING_RATE: float = 5e-5  # optimizer learning rate
|
|
|
@dataclass(frozen=True)
class DatasetConfig:
    """Location of the downloaded Roboflow dataset on disk."""
    DATA_ROOT: str = 'DATASET-TRAIN-TEST-1'
|
|
|
|
|
@dataclass(frozen=True)
class ModelConfig:
    """Hugging Face model id to fine-tune."""
    MODEL_NAME: str = 'microsoft/trocr-small-printed'
|
|
|
def visualize(dataset_path):
    """Plot the first 15 training images in a 3x5 grid.

    Titles each tile with the file name stem, which appears to encode the
    ground-truth text — TODO confirm against the dataset naming scheme.

    :param dataset_path: root folder containing ``train/train`` images.
    """
    # Fix: the directory was re-scanned on every loop iteration; one
    # os.listdir call is enough (loop-invariant hoisted).
    all_images = os.listdir(f"{dataset_path}/train/train")
    plt.figure(figsize=(15, 3))
    for i in range(15):
        plt.subplot(3, 5, i+1)
        image = plt.imread(f"{dataset_path}/train/train/{all_images[i]}")
        plt.imshow(image)
        plt.axis('off')
        plt.title(all_images[i].split('.')[0])
    plt.show()
|
|
|
|
|
# Preview a few samples before training.
visualize(DatasetConfig.DATA_ROOT)

# train.txt / test.txt are header-less fixed-width listings; column 0 is
# the image file name and column 1 the transcription.
# (Fix: single-argument os.path.join was a no-op wrapper — removed.)
train_df = pd.read_fwf('train.txt', header=None)
train_df.rename(columns={0: 'file_name', 1: 'text'}, inplace=True)
test_df = pd.read_fwf('test.txt', header=None)
test_df.rename(columns={0: 'file_name', 1: 'text'}, inplace=True)
|
|
|
|
|
# Training-time augmentation: random color jitter plus Gaussian blur.
# NOTE(review): CustomOCRDataset applies this pipeline unconditionally, so
# the validation split is augmented too — confirm that is intended.
train_transforms = transforms.Compose([
    transforms.ColorJitter(brightness=.5, hue=.3),
    transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
])
|
|
|
class CustomOCRDataset(Dataset):
    """Image/transcription dataset for TrOCR-style seq2seq fine-tuning.

    Each item is a dict with:
      * ``pixel_values`` — processor output for the (augmented) RGB image
      * ``labels`` — token ids padded to ``max_target_length``, with pad
        positions replaced by -100 so the loss ignores them
    """

    def __init__(self, root_dir, df, processor, max_target_length=128):
        # root_dir must end with a path separator: image paths are built by
        # plain string concatenation in __getitem__.
        self.root_dir = root_dir
        self.df = df                               # columns: file_name, text
        self.processor = processor                 # TrOCR processor (image + tokenizer)
        self.max_target_length = max_target_length # label padding/truncation length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        file_name = self.df['file_name'][idx]
        text = self.df['text'][idx]

        # NOTE(review): module-level train_transforms (augmentation) is
        # applied unconditionally, so validation images are augmented too —
        # confirm this is intended.
        image = Image.open(self.root_dir + file_name).convert('RGB')
        image = train_transforms(image)
        pixel_values = self.processor(image, return_tensors='pt').pixel_values

        # Fix: truncate transcriptions longer than max_target_length; the
        # original padded but never truncated, so an over-long label list
        # could exceed the padded length and break batching.
        labels = self.processor.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_target_length
        ).input_ids

        # Mask padding positions so they do not contribute to the loss.
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding
|
|
|
# Processor bundles the image feature extractor and the text tokenizer.
processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)

# Train/validation datasets; root_dir ends with '/' because the dataset
# class concatenates it directly with file names.
train_dataset = CustomOCRDataset(
    root_dir=os.path.join(DatasetConfig.DATA_ROOT, 'train/train/'),
    df=train_df,
    processor=processor
)

valid_dataset = CustomOCRDataset(
    root_dir=os.path.join(DatasetConfig.DATA_ROOT, 'test/test/'),
    df=test_df,
    processor=processor
)
|
|
|
# Load the pretrained encoder-decoder OCR model and move it to the device.
model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME)
model.to(device)
print(model)

# Report total vs. trainable parameter counts for a quick sanity check.
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")
|
|
|
|
|
# Wire the tokenizer's special-token ids into the model config so that
# generation starts, pads, and stops with the right tokens.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id

model.config.vocab_size = model.config.decoder.vocab_size
model.config.eos_token_id = processor.tokenizer.sep_token_id

# Beam-search generation settings used during evaluation and inference.
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4
|
|
|
# NOTE(review): this optimizer is never passed to Seq2SeqTrainer below, so
# the trainer builds its own optimizer and the lr/weight_decay configured
# here have no effect as written — confirm and either pass it via
# `optimizers=` or drop it.
optimizer = optim.AdamW(
    model.parameters(), lr=TrainingConfig.LEARNING_RATE, weight_decay=0.0005
)

# Character Error Rate metric (lower is better) from the `evaluate` hub.
cer_metric = evaluate.load('cer')
|
|
|
|
|
def compute_cer(pred):
    """Character Error Rate between generated ids and reference labels.

    :param pred: HF EvalPrediction with ``predictions`` and ``label_ids``.
    Returns a dict ``{"cer": float}``.
    """
    predictions = pred.predictions
    references = pred.label_ids

    # Decode hypotheses, then restore pad ids where the loss mask (-100)
    # was written before decoding the references.
    pred_str = processor.batch_decode(predictions, skip_special_tokens=True)
    references[references == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(references, skip_special_tokens=True)

    return {"cer": cer_metric.compute(predictions=pred_str, references=label_str)}
|
|
|
# Trainer configuration. predict_with_generate makes evaluation run full
# beam-search generation, which compute_cer needs.
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy='epoch',
    per_device_train_batch_size=TrainingConfig.BATCH_SIZE,
    per_device_eval_batch_size=TrainingConfig.BATCH_SIZE,
    # Fix: TrainingConfig.LEARNING_RATE previously never reached the
    # trainer (the standalone AdamW above is not passed in), so the trainer
    # fell back to its default — wire the configured value through here.
    learning_rate=TrainingConfig.LEARNING_RATE,
    output_dir='seq2seq_model_printed/',
    logging_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=5,          # keep at most 5 checkpoints on disk
    report_to='tensorboard',
    num_train_epochs=TrainingConfig.EPOCHS
)
|
|
|
|
|
# Assemble the seq2seq trainer; CER is computed on each epoch's eval pass.
trainer = Seq2SeqTrainer(
    model=model,
    # Passing the feature extractor lets the trainer save it alongside
    # model checkpoints.
    tokenizer=processor.feature_extractor,
    args=training_args,
    compute_metrics=compute_cer,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=default_data_collator
)
|
|
|
# Run fine-tuning; `res` holds the final TrainOutput (incl. global_step).
res = trainer.train()

# Reload a fresh processor and the last checkpoint for inference below.
processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
trained_model = VisionEncoderDecoderModel.from_pretrained('seq2seq_model_printed/checkpoint-'+str(res.global_step)).to(device)
|
|
|
def read_and_show(image_path):
    """Load an image from disk as RGB.

    :param image_path: String, path to the input image.

    Returns:
        image: PIL Image.
    """
    return Image.open(image_path).convert('RGB')
|
|
|
def ocr(image, processor, model):
    """Run the OCR model on one image and return the decoded text.

    :param image: PIL Image.
    :param processor: Huggingface OCR processor.
    :param model: Huggingface OCR model.

    Returns:
        generated_text: the OCR'd text string.
    """
    inputs = processor(image, return_tensors='pt')
    generated_ids = model.generate(inputs.pixel_values.to(device))
    decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)
    return decoded[0]
|
|
|
def eval_new_data(
    data_path=os.path.join(DatasetConfig.DATA_ROOT, 'test/test', '*'),
    num_samples=50
):
    """OCR up to *num_samples* images matching *data_path*, showing each
    image with its predicted text as the plot title."""
    image_paths = glob.glob(data_path)
    progress = tqdm(enumerate(image_paths), total=len(image_paths))
    for index, path in progress:
        if index == num_samples:
            break
        sample = read_and_show(path)
        prediction = ocr(sample, processor, trained_model)
        plt.figure(figsize=(7, 4))
        plt.imshow(sample)
        plt.title(prediction)
        plt.axis('off')
        plt.show()
|
|
|
# Run inference on up to 100 held-out test images with the trained model.
eval_new_data(
    data_path=os.path.join(DatasetConfig.DATA_ROOT, 'test/test/', '*'),
    num_samples=100
)
|
|
|
|