
ByT5 base English fine-tuned for OCR Correction

This model is a fine-tuned version of byt5-base for OCR correction. ByT5 was introduced in this paper, and the idea and code for fine-tuning the model for OCR correction were taken from here.

Model description

byt5-base-english-ocr-correction is the byt5-base model fine-tuned on an OCR correction dataset. It takes as input a sentence that has been incorrectly transcribed by an OCR model and outputs a version of that sentence with the errors corrected.

The model was trained by taking the wikitext dataset and adding synthetic OCR errors using nlpaug.
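As a rough illustration of how such training pairs can be built, the sketch below corrupts clean sentences with OCR-style character look-alikes and pairs each noisy sentence with its clean original. It does not use nlpaug: the confusion table `OCR_CONFUSIONS` and the helpers `add_ocr_noise` and `make_pairs` are hypothetical, and unlike nlpaug's `OcrAug` it applies only a per-character probability, with no word-level selection.

```python
import random
from typing import List, Optional, Tuple

# Hypothetical table of OCR-style confusions (digit/letter look-alikes).
# nlpaug's OcrAug uses its own, larger mapping; this is only an illustration.
OCR_CONFUSIONS = {
    "o": ["0"], "O": ["0"], "l": ["1"], "I": ["1"],
    "s": ["5"], "S": ["5"], "B": ["8"], "g": ["9"], "Z": ["2"],
}


def add_ocr_noise(text: str, char_p: float = 0.4,
                  rng: Optional[random.Random] = None) -> str:
    """Replace each confusable character with probability char_p."""
    if rng is None:
        rng = random.Random()
    out = []
    for ch in text:
        if ch in OCR_CONFUSIONS and rng.random() < char_p:
            out.append(rng.choice(OCR_CONFUSIONS[ch]))
        else:
            out.append(ch)
    return "".join(out)


def make_pairs(sentences: List[str], char_p: float = 0.4,
               seed: int = 0) -> List[Tuple[str, str]]:
    """Build (noisy input, clean target) pairs for seq2seq fine-tuning."""
    rng = random.Random(seed)  # seeded so the corruption is reproducible
    return [(add_ocr_noise(s, char_p, rng), s) for s in sentences]
```

For example, `make_pairs(["Life is like a box of chocolates"], char_p=1.0)` replaces every confusable character and yields the input `"Life i5 1ike a b0x 0f ch0c01ate5"` paired with the clean sentence as the target.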

Intended uses & limitations

You can use the model for text-to-text generation to correct errors introduced by an OCR model.

How to use

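To compute the fine-tuning loss for a single noisy/clean sentence pair, feed the OCR-corrupted text as the input and the clean text as the labels:

from transformers import T5ForConditionalGeneration
import torch
import nlpaug.augmenter.char as nac

# Corrupt a clean sentence with synthetic OCR errors
aug = nac.OcrAug(aug_char_p=0.4, aug_word_p=0.6)
corrected_text = "Life is like a box of chocolates"
augmented_text = aug.augment(corrected_text)
if isinstance(augmented_text, list):  # some nlpaug versions return a list of strings
    augmented_text = augmented_text[0]

model = T5ForConditionalGeneration.from_pretrained('yelpfeast/byt5-base-english-ocr-correction')

# ByT5 works on raw UTF-8 bytes; add 3 to skip the special-token ids
input_ids = torch.tensor([list(augmented_text.encode("utf-8"))]) + 3  # noisy OCR text as input
labels = torch.tensor([list(corrected_text.encode("utf-8"))]) + 3  # clean text as target

loss = model(input_ids, labels=labels).loss  # forward pass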
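The `+ 3` above comes from ByT5's byte-level vocabulary: ids 0, 1 and 2 are reserved for `<pad>`, `</s>` and `<unk>`, so each UTF-8 byte b is encoded as id b + 3. The pure-Python sketch below mirrors that mapping; `encode` and `decode` are illustrative helpers, not part of transformers, and `decode` simply drops special ids and any sentinel ids above the byte range. The `add_eos` flag mimics the end-of-sequence token that the ByT5 tokenizer appends; in practice `AutoTokenizer` handles all of this for you.

```python
# Special-token ids reserved at the bottom of the ByT5 vocabulary.
PAD_ID, EOS_ID, UNK_ID = 0, 1, 2
OFFSET = 3  # byte value b is stored as token id b + 3


def encode(text: str, add_eos: bool = True) -> list[int]:
    """Map text to ByT5-style ids: one id per UTF-8 byte, plus optional </s>."""
    ids = [b + OFFSET for b in text.encode("utf-8")]
    return ids + [EOS_ID] if add_eos else ids


def decode(ids) -> str:
    """Invert encode(), skipping special ids and ids outside the byte range."""
    data = bytes(i - OFFSET for i in ids if OFFSET <= i < OFFSET + 256)
    return data.decode("utf-8", errors="ignore")
```

Because non-ASCII characters take several bytes (for example "î" is two bytes in UTF-8), the number of ids can exceed the number of characters.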

To correct OCR output, tokenize the noisy text and call generate():

from transformers import T5ForConditionalGeneration, AutoTokenizer
import nlpaug.augmenter.char as nac

aug = nac.OcrAug(aug_char_p=0.4, aug_word_p=0.6)
corrected_text = "Life is like a box of chocolates"
augmented_text = aug.augment(corrected_text)  # simulated OCR output
print(augmented_text)

model = T5ForConditionalGeneration.from_pretrained('yelpfeast/byt5-base-english-ocr-correction')
tokenizer = AutoTokenizer.from_pretrained("yelpfeast/byt5-base-english-ocr-correction")

inputs = tokenizer(augmented_text, return_tensors="pt", padding=True)

output_sequences = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    do_sample=False,  # greedy decoding for deterministic output
    max_new_tokens=128,  # one token per byte, so the default length limit can truncate output
)

print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))
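Because the model was fine-tuned on single sentences and ByT5 spends one token per byte, a full page of OCR text becomes a long input. One option is to split the text on sentence boundaries and pack sentences into chunks under a byte budget before tokenizing. The sketch below is a hypothetical helper: `chunk_sentences`, its naive regex sentence splitter, and the default `max_bytes=512` are assumptions, not values taken from this model card.

```python
import re


def chunk_sentences(text: str, max_bytes: int = 512) -> list[str]:
    """Greedily pack sentences into chunks whose UTF-8 length stays <= max_bytes.

    A single sentence longer than max_bytes is kept as its own oversized chunk.
    """
    # Naive split after ., ! or ? followed by whitespace.
    sentences = [s for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s]
    chunks, current = [], ""
    for sent in sentences:
        candidate = f"{current} {sent}" if current else sent
        if len(candidate.encode("utf-8")) <= max_bytes:
            current = candidate  # still fits: extend the current chunk
        else:
            if current:
                chunks.append(current)
            current = sent  # start a new chunk with this sentence
    if current:
        chunks.append(current)
    return chunks
```

The resulting list can be passed directly to `tokenizer(chunks, return_tensors="pt", padding=True)` so that all chunks are corrected in one batched `generate()` call.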

Limitations

The model was trained on text that was artificially corrupted to resemble OCR errors. Real OCR engines make different kinds of errors, which may not match these synthetic ones, so the model may not fully correct the output of every OCR system.
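One way to check how well the model handles your own OCR engine is to compare character error rate (CER) before and after correction on a few pages with known ground truth. The sketch below computes CER as Levenshtein edit distance divided by reference length; `levenshtein` and `cer` are illustrative helpers written for this example, not part of transformers or nlpaug.

```python
def levenshtein(a: str, b: str) -> int:
    """Edit distance (insertions, deletions, substitutions) between two strings."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,               # delete ca
                curr[j - 1] + 1,           # insert cb
                prev[j - 1] + (ca != cb),  # substitute (free if equal)
            ))
        prev = curr
    return prev[-1]


def cer(hypothesis: str, reference: str) -> float:
    """Character error rate of hypothesis against a non-empty reference."""
    if not reference:
        raise ValueError("reference must be non-empty")
    return levenshtein(hypothesis, reference) / len(reference)
```

If `cer(model_output, ground_truth)` is not clearly lower than `cer(raw_ocr_text, ground_truth)` on your data, the synthetic training errors probably do not match your OCR engine's errors well.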
