import os

import pandas as pd
import torch
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Trainer, TrainingArguments


class HandwrittenMathDataset(Dataset):
    """
    Dataset of handwritten math images and their text labels.

    The annotations CSV is expected to hold the image file name (relative to
    img_dir) in its first column and the corresponding transcription in its
    second column. The data is split 80/20 into train and test subsets with a
    fixed random seed, so both subsets are reproducible across runs.

    Parameters:
        annotations_file (str): Path to the annotations CSV file.
        img_dir (str): Directory containing the images.
        processor: TrOCRProcessor used for image and text preprocessing.
        subset (str): Either "train" or "test"; selects which split to serve.
    """
    def __init__(self, annotations_file, img_dir, processor, subset="train"):
        self.img_labels = pd.read_csv(annotations_file)
        # Deterministic 80/20 split; both subset instances recompute the same split.
        self.train_data, self.test_data = train_test_split(self.img_labels, test_size=0.2, random_state=42)
        self.data = self.train_data if subset == "train" else self.test_data
        self.img_dir = img_dir
        self.processor = processor

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.data.iloc[idx, 0])
        image = Image.open(img_path).convert("RGB")
        # Convert the image into the pixel-value tensor expected by the encoder
        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values
        label = self.data.iloc[idx, 1]
        # Tokenize the label text, padding/truncating to a fixed length
        labels = self.processor.tokenizer(label, padding="max_length", max_length=128, truncation=True,
                                          return_tensors="pt").input_ids
        # Replace pad token ids with -100 so they are ignored by the loss
        labels[labels == self.processor.tokenizer.pad_token_id] = -100

        return {"pixel_values": pixel_values.squeeze(), "labels": labels.squeeze()}


def main():
    """
    Fine-tune a TrOCR VisionEncoderDecoderModel on the handwritten math dataset.
    """
    annotations_file = './dataset/annotations.csv'
    img_dir = './dataset/images/'
    model_id = 'microsoft/trocr-base-handwritten'

    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = TrOCRProcessor.from_pretrained(model_id)
    model = VisionEncoderDecoderModel.from_pretrained(model_id).to(device)

    # The decoder needs explicit start and pad token ids for training and generation
    model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
    model.config.pad_token_id = processor.tokenizer.pad_token_id

    train_dataset = HandwrittenMathDataset(annotations_file=annotations_file, img_dir=img_dir, processor=processor,
                                           subset="train")
    test_dataset = HandwrittenMathDataset(annotations_file=annotations_file, img_dir=img_dir, processor=processor,
                                          subset="test")

    # Effective train batch size is 4 per device
    # (per_device_train_batch_size=2 x gradient_accumulation_steps=2);
    # evaluation and checkpointing both run once per epoch.
    training_args = TrainingArguments(
        output_dir='./model',
        per_device_train_batch_size=2,
        num_train_epochs=20,
        logging_dir='./training_logs',
        logging_steps=10,
        save_strategy="epoch",
        save_total_limit=1,
        weight_decay=0.1,
        learning_rate=1e-4,
        gradient_checkpointing=True,
        gradient_accumulation_steps=2,
        evaluation_strategy="epoch"
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset
    )

    trainer.train()


if __name__ == '__main__':
    main()
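

# --- Inference sketch (illustrative; not part of the training run) ---
# A minimal example, under assumptions, of how a fine-tuned checkpoint could be
# used for prediction. The helper name `run_inference` and its arguments are
# hypothetical; `model_dir` must point at a saved checkpoint (e.g. a
# checkpoint-XXXX folder under ./model), and the processor is reloaded from the
# base model id because the Trainer above is not given a tokenizer to save.
def run_inference(image_path, model_dir):
    """Decode the text in a single handwritten-math image (illustrative sketch)."""
    processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
    model = VisionEncoderDecoderModel.from_pretrained(model_dir)
    model.eval()

    image = Image.open(image_path).convert("RGB")
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    with torch.no_grad():
        generated_ids = model.generate(pixel_values, max_length=128)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]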