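# testocr / app.py
# Author: triopood (profile picture in the original page) — last commit "Update app.py" (f7def6a, verified).
# These lines were pasted from a repository web view; they are kept as comments so the file parses.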
import os
import torch
import evaluate
import numpy as np
import pandas as pd
import glob as glob
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import subprocess
from flask import Flask, request, jsonify
from PIL import Image
from zipfile import ZipFile
from tqdm.notebook import tqdm
from dataclasses import dataclass
from torch.utils.data import Dataset
from urllib.request import urlretrieve
from transformers import (
VisionEncoderDecoderModel,
TrOCRProcessor,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
default_data_collator
)
from roboflow import Roboflow
from zipfile import BadZipFile

# Fetch the training data (this script therefore needs network access).
#
# SECURITY(review): this Roboflow API key is committed in plain text. It can be
# supplied through the ROBOFLOW_API_KEY environment variable instead; the literal
# is only kept as a fallback so existing runs keep working — rotate and delete it.
roboflow_api_key = os.environ.get("ROBOFLOW_API_KEY", "kGIFR6wPmDow2dHnoXoi")
rf = Roboflow(api_key=roboflow_api_key)
project = rf.workspace("capstone-design-oyzc3").project("dataset-train-test")
# Version 1 of the project in "folder" format; presumably this is the
# DatasetConfig.DATA_ROOT directory used below — confirm.
dataset = project.version(1).download("folder")

# Label archive; presumably it holds train.txt / test.txt, which are read from the
# working directory further down — confirm.
LABELS_URL = "https://docs.google.com/uc?export=download&id=12reT7rxiRqTERYqeKYx7WGz5deMXjnEo"
LABELS_ARCHIVE = "filetxt"
wget_status = subprocess.run(
    ['wget', '--no-check-certificate', LABELS_URL, '-O', LABELS_ARCHIVE]
).returncode
if wget_status != 0:
    print(f"wget exited with status {wget_status}; trying to extract {LABELS_ARCHIVE!r} anyway.")

# Extract with zipfile rather than the `unzip` CLI: `unzip` asks before overwriting
# files left by a previous run (and skips them when stdin gives no answer), whereas
# extractall() always overwrites, so re-runs pick up fresh label files.
try:
    with ZipFile(LABELS_ARCHIVE) as labels_zip:
        labels_zip.extractall()
except (BadZipFile, FileNotFoundError) as err:
    # Best effort, like the original unchecked `unzip` call: report and continue.
    print(f"Could not extract {LABELS_ARCHIVE!r}: {err}")
def seed_everything(seed_value):
    """
    Seed every random number generator this script relies on.

    :param seed_value: Int, seed applied to Python's ``random`` module, NumPy,
        and PyTorch (CPU and all CUDA devices).
    """
    import random

    # Python's built-in RNG was previously left unseeded.
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    # Prefer deterministic cuDNN kernels over autotuned (possibly faster) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# Fix all RNGs before any data or model work so runs are repeatable.
seed_everything(42)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def download_and_unzip(url, save_path):
    """
    Download a ZIP archive and extract it into the directory it was saved to.

    :param url: String, URL of the ZIP archive.
    :param save_path: String, local file path the archive is written to.

    If the download is not a valid ZIP, it is reported and deleted. A later run,
    which skips the download while ``save_path`` exists, will then retry it.
    Any other error propagates.
    """
    from zipfile import BadZipFile

    print("Downloading and extracting assets....", end="")
    # Network / HTTP errors are deliberately not caught here.
    urlretrieve(url, save_path)
    try:
        with ZipFile(save_path) as archive:
            # Extract ZIP file contents in the same directory as the archive.
            archive.extractall(os.path.dirname(save_path))
    except BadZipFile as err:
        # Only the failure the message describes is handled; disk or permission
        # errors are no longer swallowed and mislabelled as "Invalid file".
        print("\nInvalid file.", err)
        os.remove(save_path)
    else:
        print("Done")
# Roboflow export link. The `key` query parameter presumably grants access —
# treat it like a credential.
URL = r"https://app.roboflow.com/ds/TZnI5u5spH?key=krcK5FWtuB"
asset_zip_path = os.path.join(os.getcwd(), "capstone-design-oyzc3.zip")

# An archive already on disk from an earlier run means there is nothing to fetch.
archive_present = os.path.exists(asset_zip_path)
if not archive_present:
    download_and_unzip(URL, asset_zip_path)
@dataclass(frozen=True)
class TrainingConfig:
    """Immutable fine-tuning hyper-parameters."""

    # Used for both the per-device train and eval batch size.
    BATCH_SIZE: int = 25
    # Number of passes over the training set.
    EPOCHS: int = 20
    # AdamW learning rate (5e-5 == 0.00005).
    LEARNING_RATE: float = 5e-5
@dataclass(frozen=True)
class DatasetConfig:
    """
    Dataset location.

    DATA_ROOT is the folder that holds the ``train/train`` and ``test/test`` image
    directories. Commented-out alternatives previously pointed at
    ``filetxt/DATASET TXT/train/`` and ``filetxt/DATASET TXT/test/``.
    """

    DATA_ROOT: str = 'DATASET-TRAIN-TEST-1'
@dataclass(frozen=True)
class ModelConfig:
    """Which pretrained checkpoint the processor and model are loaded from."""

    MODEL_NAME: str = 'microsoft/trocr-small-printed'
def visualize(dataset_path, num_images=15):
    """
    Plot a grid of sample training images, each titled with its file name stem.

    :param dataset_path: String, dataset root that contains ``train/train``.
    :param num_images: Int, maximum number of images to plot (default 15,
        shown as a 3 x 5 grid).
    """
    image_dir = os.path.join(dataset_path, 'train', 'train')
    # List the directory once (not once per subplot). Sorting gives a stable
    # selection, because os.listdir order is arbitrary.
    file_names = sorted(os.listdir(image_dir))[:num_images]
    if not file_names:
        print(f"No images found in {image_dir}")
        return
    cols = 5
    rows = (len(file_names) + cols - 1) // cols
    # Give each row its own height so titles and images do not overlap.
    plt.figure(figsize=(15, 3 * rows))
    for position, file_name in enumerate(file_names, start=1):
        plt.subplot(rows, cols, position)
        plt.imshow(plt.imread(os.path.join(image_dir, file_name)))
        plt.axis('off')
        plt.title(file_name.split('.')[0])
    plt.show()
# Preview a few training images.
visualize(DatasetConfig.DATA_ROOT)


def _load_label_table(label_path):
    """
    Read a label file into a DataFrame with ``file_name`` and ``text`` columns.

    NOTE(review): read_fwf infers fixed-width column boundaries; if a label can
    contain spaces, it may be split across extra columns — confirm the file format.
    """
    table = pd.read_fwf(label_path, header=None)
    return table.rename(columns={0: 'file_name', 1: 'text'})


train_df = _load_label_table('train.txt')
test_df = _load_label_table('test.txt')
# Training-time augmentation: random colour jitter, then a random Gaussian blur.
_augmentation_steps = [
    transforms.ColorJitter(brightness=0.5, hue=0.3),
    transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
]
train_transforms = transforms.Compose(_augmentation_steps)
class CustomOCRDataset(Dataset):
    """
    Image/label pairs for TrOCR fine-tuning.

    Each item is a dict with ``pixel_values`` (processor output for one RGB image,
    singleton dimensions squeezed) and ``labels`` (token ids padded/truncated to
    ``max_target_length``, with padding positions set to -100).
    """

    def __init__(self, root_dir, df, processor, max_target_length=128,
                 transform=train_transforms):
        """
        :param root_dir: String, directory containing the image files.
        :param df: DataFrame with ``file_name`` and ``text`` columns.
        :param processor: TrOCR processor (image processor + tokenizer).
        :param max_target_length: Int, length labels are padded/truncated to.
        :param transform: Callable applied to each PIL image before the processor.
            Defaults to the training augmentations; pass ``None`` to feed images
            unchanged (e.g. for validation data).
        """
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Positional access, so a filtered or re-indexed DataFrame still works.
        row = self.df.iloc[idx]
        file_name = row['file_name']
        text = row['text']
        # Read the image, apply the (optional) transform, get the pixel tensor.
        with Image.open(os.path.join(self.root_dir, file_name)) as source:
            image = source.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        pixel_values = self.processor(image, return_tensors='pt').pixel_values
        # truncation=True keeps every label exactly max_target_length long. Without
        # it, an overlong label gives a longer list, and default_data_collator
        # cannot stack the batch.
        labels = self.processor.tokenizer(
            text,
            padding='max_length',
            max_length=self.max_target_length,
            truncation=True,
        ).input_ids
        # We are using -100 as the padding token.
        pad_id = self.processor.tokenizer.pad_token_id
        labels = [label if label != pad_id else -100 for label in labels]
        return {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}


processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
train_dataset = CustomOCRDataset(
    root_dir=os.path.join(DatasetConfig.DATA_ROOT, 'train/train/'),
    df=train_df,
    processor=processor,
)
# Validation images are not augmented. Otherwise the random jitter/blur would
# change the evaluation inputs every time and distort the reported CER.
valid_dataset = CustomOCRDataset(
    root_dir=os.path.join(DatasetConfig.DATA_ROOT, 'test/test/'),
    df=test_df,
    processor=processor,
    transform=None,
)
model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME)
model.to(device)
print(model)
# Total parameters and trainable parameters.
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")

# Set special tokens used for creating the decoder_input_ids from the labels.
# These stay on model.config, where the model reads them during training.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# Set correct vocab size.
model.config.vocab_size = model.config.decoder.vocab_size
model.config.eos_token_id = processor.tokenizer.sep_token_id

# Decoding (beam-search) settings go on the generation config. Storing them on
# model.config is deprecated in recent transformers releases. Depending on the
# version, that either warns or fails when the Trainer saves a checkpoint.
# The special token ids are mirrored so generate() uses the same ones.
generation_config = model.generation_config
generation_config.decoder_start_token_id = model.config.decoder_start_token_id
generation_config.pad_token_id = model.config.pad_token_id
generation_config.eos_token_id = model.config.eos_token_id
generation_config.max_length = 64
generation_config.early_stopping = True
generation_config.no_repeat_ngram_size = 3
generation_config.length_penalty = 2.0
generation_config.num_beams = 4
# AdamW over every model parameter, with the configured learning rate and a
# small weight decay (5e-4 == 0.0005).
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=TrainingConfig.LEARNING_RATE,
    weight_decay=5e-4,
)
cer_metric = evaluate.load('cer')


def compute_cer(pred):
    """
    Character error rate metric for Seq2SeqTrainer evaluation.

    :param pred: Evaluation prediction whose ``predictions`` are generated token
        ids and whose ``label_ids`` are reference ids, with -100 at ignored positions.
    Returns:
        dict: ``{"cer": <character error rate>}``.
    """
    pad_id = processor.tokenizer.pad_token_id
    # Replace the -100 ignore index with the pad id (dropped by
    # skip_special_tokens) in both arrays. Predictions may contain -100 too,
    # when batches of different lengths are padded. np.where returns copies,
    # so the Trainer's arrays are no longer mutated in place.
    pred_ids = np.where(pred.predictions != -100, pred.predictions, pad_id)
    label_ids = np.where(pred.label_ids != -100, pred.label_ids, pad_id)
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}
import dataclasses

# transformers renamed `evaluation_strategy` to `eval_strategy` and later removed
# the old name. Pass whichever name the installed version accepts.
_training_arg_names = {f.name for f in dataclasses.fields(Seq2SeqTrainingArguments)}
_eval_strategy_kwarg = (
    'eval_strategy' if 'eval_strategy' in _training_arg_names else 'evaluation_strategy'
)

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    per_device_train_batch_size=TrainingConfig.BATCH_SIZE,
    per_device_eval_batch_size=TrainingConfig.BATCH_SIZE,
    # Same value as the optimizer's lr; recorded so the logged hyper-parameters match.
    learning_rate=TrainingConfig.LEARNING_RATE,
    fp16=False,
    output_dir='seq2seq_model_printed/',
    logging_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=5,
    report_to='tensorboard',
    num_train_epochs=TrainingConfig.EPOCHS,
    **{_eval_strategy_kwarg: 'epoch'},
)

# Initialize trainer.
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.image_processor,
    args=training_args,
    compute_metrics=compute_cer,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=default_data_collator,
    # Hand over the AdamW optimizer built earlier (weight_decay=0.0005). Before,
    # it was created but never passed anywhere, so it went unused. None lets the
    # Trainer build its default LR scheduler around it.
    optimizers=(optimizer, None),
)
res = trainer.train()

# Reload a fresh processor and the weights from the checkpoint directory named
# after the final global step.
processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
final_checkpoint_dir = f'seq2seq_model_printed/checkpoint-{res.global_step}'
trained_model = VisionEncoderDecoderModel.from_pretrained(final_checkpoint_dir)
trained_model = trained_model.to(device)
def read_and_show(image_path):
    """
    Load an image from disk and convert it to RGB.

    Despite the name, nothing is displayed here; callers do the plotting.

    :param image_path: String, path to the input image.
    Returns:
        image: PIL Image in RGB mode.
    """
    # The context manager closes the file as soon as convert() has copied the
    # pixels, instead of leaving it to garbage collection (relevant for
    # multi-frame formats).
    with Image.open(image_path) as source:
        return source.convert('RGB')
def ocr(image, processor, model):
    """
    Recognise the text in a single (cropped) image.

    :param image: PIL Image.
    :param processor: Huggingface OCR processor.
    :param model: Huggingface OCR model.
    Returns:
        generated_text: the OCR'd text string for the (single) input image.
    """
    # Send the inputs to wherever the model lives rather than relying on the
    # module-level `device`, so this works for any model placement.
    pixel_values = processor(image, return_tensors='pt').pixel_values.to(model.device)
    generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
def eval_new_data(
    data_path=os.path.join(DatasetConfig.DATA_ROOT, 'test/test', '*'),
    num_samples=50
):
    """
    Run OCR on up to ``num_samples`` images and plot each one titled with its prediction.

    :param data_path: String, glob pattern selecting the images.
    :param num_samples: Int, maximum number of images to process (``None`` = all).
    """
    # Sort for a reproducible selection (glob order is arbitrary). Slicing up
    # front makes the progress-bar total match the work actually done.
    image_paths = sorted(glob.glob(data_path))[:num_samples]
    for image_path in tqdm(image_paths, total=len(image_paths)):
        image = read_and_show(image_path)
        text = ocr(image, processor, trained_model)
        plt.figure(figsize=(7, 4))
        plt.imshow(image)
        plt.title(text)
        plt.axis('off')
        plt.show()
        # Release the figure; otherwise (e.g. with non-interactive backends)
        # every iteration leaves another open figure behind.
        plt.close()
# Visual spot-check of the fine-tuned model on up to 100 test images.
test_image_pattern = os.path.join(DatasetConfig.DATA_ROOT, 'test/test/', '*')
eval_new_data(data_path=test_image_pattern, num_samples=100)