## Welcome!
This is the repository for

*Learning the Legibility of Visual Text Perturbations* (EACL 2023)

by Dev Seth, Rickard Stureborg, Danish Pruthi and Bhuwan Dhingra

### A `LEGIT` Introduction
This notebook is a starting point for working with the datasets and models presented in the Learning Legibility paper.

All assets are hosted on the HuggingFace Hub and can be used with the `transformers` and `datasets` libraries: 
  - TrOCR-MT Model: https://huggingface.co/dvsth/LEGIT-TrOCR-MT 
  - LEGIT Dataset: https://huggingface.co/datasets/dvsth/LEGIT
  - Perturbed Jigsaw Dataset: https://huggingface.co/datasets/dvsth/LEGIT-VIPER-Jigsaw-Toxic-Comment-Perturbed

**For an interactive preview of the perturbation process and the legibility assessment model, run `python demo.py`, which opens a browser-based interface. The demo lets you perturb a word with your chosen attack parameters and then see the model's legibility estimate for each generated perturbation.**

#### Setup

In [1]:
# external imports -- use pip or conda to install these packages
import torch
from transformers import TrOCRProcessor, AutoModel, TrainingArguments
from datasets import load_dataset

# local imports
from classes.LegibilityModel import LegibilityModel
from classes.Trainer import MultiTaskTrainer
from classes.Metrics import binary_classification_metric, ranking_metric

#### Loading the Model and Dataset

In [2]:
# load the model schema and pretrained weights
# (this may take some time to download)
model = AutoModel.from_pretrained("dvsth/LEGIT-TrOCR-MT", revision='main', trust_remote_code=True)

An interactive preview of the dataset is available [here](https://huggingface.co/datasets/dvsth/LEGIT/viewer/dvsth--LEGIT/test).

In [3]:
dataset = load_dataset('dvsth/LEGIT').with_format('torch')

Using custom data configuration dvsth--LEGIT-d84a4d72774d3652
Found cached dataset parquet (/Users/dvsth/.cache/huggingface/datasets/dvsth___parquet/dvsth--LEGIT-d84a4d72774d3652/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
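
Once the dataset has loaded, it can help to check which splits and columns it contains. The helper below is a minimal sketch and not part of the repository: `summarize_splits` is a hypothetical name, and it assumes only that each split supports `len()` and exposes a `column_names` attribute, as `datasets.Dataset` objects do. The `StubSplit` class stands in for a real split so the snippet runs without downloading anything.

```python
def summarize_splits(splits):
    """Map each split name to (number of examples, sorted column names)."""
    return {
        name: (len(split), sorted(split.column_names))
        for name, split in splits.items()
    }


class StubSplit:
    """Stand-in for a dataset split (hypothetical, for illustration only)."""

    def __init__(self, num_rows, column_names):
        self._num_rows = num_rows
        self.column_names = column_names

    def __len__(self):
        return self._num_rows


# Made-up size; the column names are a subset of those used in this notebook.
example = {"train": StubSplit(3, ["img1", "choice", "img0"])}
summary = summarize_splits(example)
print(summary)
```

With the real data, calling `summarize_splits(dataset)` should report the `train`, `valid`, and `test` splits that this notebook uses.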



#### Training/Eval Loop

##### Trainer setup

In [4]:
# preprocessor provides image normalization and resizing
preprocessor = TrOCRProcessor.from_pretrained(
    "microsoft/trocr-base-handwritten")

# apply preprocessing batch-wise
def collate_fn(data):
    return {
        'choice': torch.tensor([d['choice'].item() for d in data]),
        'img0': preprocessor([d['img0'] for d in data], return_tensors='pt')['pixel_values'],
        'img1': preprocessor([d['img1'] for d in data], return_tensors='pt')['pixel_values']
    }


train_args = TrainingArguments(
    output_dir='runs',              # change this to a unique path for each run, e.g. f'runs/{run_id}'
    overwrite_output_dir=True,
    num_train_epochs=5,             # we found 3 epochs to be sufficient for convergence on the base models
    per_device_train_batch_size=26, # fits on 1 x NVIDIA A6000, 48GB VRAM
    per_device_eval_batch_size=26,  # can be increased to 32
    gradient_accumulation_steps=2,  # increase this to fit on a smaller GPU
    warmup_steps=0,             
    weight_decay=0.0,
    learning_rate=1e-5,             # we found this to be the best initial learning rate for the base models
    save_strategy="steps",
    save_steps=200,
    eval_steps=200,
    evaluation_strategy="steps",
    logging_strategy='steps',
    logging_steps=50,
    fp16=False,                     
    load_best_model_at_end=True,    # load the best model at the end of training based on validation F1
    metric_for_best_model='f1_score')

trainer = MultiTaskTrainer(
    model=model,
    compute_metrics=binary_classification_metric, # see classes/Metrics.py for the available metrics (e.g. ranking_metric)
    args=train_args,
    data_collator=collate_fn,
    train_dataset=dataset['train'],
    eval_dataset=dataset['valid'])


Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
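
The comments on `per_device_train_batch_size` and `gradient_accumulation_steps` above suggest trading one for the other on a smaller GPU. The sketch below shows the arithmetic, assuming a single device by default: the number of examples contributing to each optimizer step is the product of the two settings. Both functions are hypothetical helpers written for illustration, not part of the repository or of `transformers`.

```python
import math


def effective_batch_size(per_device_batch_size, accumulation_steps, num_devices=1):
    """Examples contributing to each optimizer step."""
    return per_device_batch_size * accumulation_steps * num_devices


def accumulation_steps_for(target_effective_size, per_device_batch_size, num_devices=1):
    """Smallest accumulation step count that reaches at least the target size."""
    return math.ceil(target_effective_size / (per_device_batch_size * num_devices))


# Settings from the TrainingArguments above: 26 per device, 2 accumulation steps.
target = effective_batch_size(26, 2)

# Hypothetical smaller GPU that only fits 13 examples per forward pass.
smaller_gpu_steps = accumulation_steps_for(target, 13)
print(target, smaller_gpu_steps)
```

When the per-device batch size does not divide the target evenly, the helper rounds up, so the effective batch size may end up slightly larger than the original.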


##### Generate predictions and compute metrics

In [5]:
predictions = trainer.predict(dataset['test'].select(range(100))) # takes ~1-2 minutes on a laptop CPU
print(predictions.metrics)

The following columns in the test set don't have a corresponding argument in `LegibilityModel.forward` and have been ignored: k, word1, word0, n, model1, word, n1, k1, model0. If k, word1, word0, n, model1, word, n1, k1, model0 are not expected by `LegibilityModel.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 100
  Batch size = 26



{'test_loss': 0.5344929695129395, 'test_precision': 0.9479166567925349, 'test_recall': 0.8921568539984622, 'test_accuracy': 0.8787878721303949, 'test_f1_score': 0.9191914103665608, 'test_runtime': 47.6671, 'test_samples_per_second': 2.098, 'test_steps_per_second': 0.084}
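
As a sanity check on the reported metrics, F1 is the harmonic mean of precision and recall, and the throughput figures follow from the example count, batch size, and runtime. The sketch below recomputes them from the values printed above. It is plain arithmetic and does not call the repository's `binary_classification_metric`, whose implementation is not shown in this notebook, so small differences from the reported F1 are possible.

```python
import math

# Values copied from the printed metrics and prediction log above.
precision = 0.9479166567925349
recall = 0.8921568539984622
reported_f1 = 0.9191914103665608
runtime_seconds = 47.6671
num_examples = 100
batch_size = 26

# F1 is the harmonic mean of precision and recall.
f1 = 2 * precision * recall / (precision + recall)

# The final batch is partial, so round up when counting prediction steps.
num_steps = math.ceil(num_examples / batch_size)

samples_per_second = num_examples / runtime_seconds
steps_per_second = num_steps / runtime_seconds

print(f"f1={f1:.4f} steps={num_steps}")
print(f"samples/s={samples_per_second:.3f} steps/s={steps_per_second:.3f}")
```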
