import os

import pandas as pd
import torch
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from transformers import Trainer, TrainingArguments, TrOCRProcessor, VisionEncoderDecoderModel


class HandwrittenMathDataset(Dataset):
    """Dataset of handwritten math images and their text labels.

    Parameters:
        annotations_file (str): Path to a CSV file whose first column holds
            image file names and whose second column holds transcriptions.
        img_dir (str): Directory containing the images.
        processor: TrOCRProcessor used for image preprocessing and tokenization.
        subset (str): "train" or "test"; selects the corresponding split.
    """

    def __init__(self, annotations_file, img_dir, processor, subset="train"):
        self.img_labels = pd.read_csv(annotations_file)
        # Deterministic 80/20 split so the "train" and "test" subsets never overlap.
        self.train_data, self.test_data = train_test_split(
            self.img_labels, test_size=0.2, random_state=42
        )
        self.data = self.train_data if subset == "train" else self.test_data
        self.img_dir = img_dir
        self.processor = processor

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Resolve the image path relative to img_dir and load the image as RGB.
        img_path = os.path.join(self.img_dir, self.data.iloc[idx, 0])
        image = Image.open(img_path).convert("RGB")
        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values

        label = self.data.iloc[idx, 1]
        labels = self.processor.tokenizer(
            label,
            padding="max_length",
            max_length=128,
            truncation=True,
            return_tensors="pt",
        ).input_ids
        # Replace pad token ids with -100 so they are ignored by the loss.
        labels[labels == self.processor.tokenizer.pad_token_id] = -100

        return {"pixel_values": pixel_values.squeeze(), "labels": labels.squeeze()}


def main():
    """Fine-tune a TrOCR VisionEncoderDecoderModel for handwritten text recognition."""
    annotations_file = './dataset/annotations.csv'
    img_dir = './dataset/images/'
    model_id = 'microsoft/trocr-base-handwritten'

    processor = TrOCRProcessor.from_pretrained(model_id)
    # Fall back to CPU when no GPU is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = VisionEncoderDecoderModel.from_pretrained(model_id).to(device)

    # The decoder needs explicit start and pad token ids for generation and loss.
    model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
    model.config.pad_token_id = processor.tokenizer.pad_token_id

    train_dataset = HandwrittenMathDataset(
        annotations_file=annotations_file, img_dir=img_dir, processor=processor, subset="train"
    )
    test_dataset = HandwrittenMathDataset(
        annotations_file=annotations_file, img_dir=img_dir, processor=processor, subset="test"
    )

    training_args = TrainingArguments(
        output_dir='./model',
        per_device_train_batch_size=2,
        num_train_epochs=20,
        logging_dir='./training_logs',
        logging_steps=10,
        save_strategy="epoch",
        save_total_limit=1,
        weight_decay=0.1,
        learning_rate=1e-4,
        gradient_checkpointing=True,
        gradient_accumulation_steps=2,
        evaluation_strategy="epoch",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
    )
    trainer.train()


if __name__ == '__main__':
    main()
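

# --- Inference sketch (not invoked during training) -------------------------
# A minimal example of using a fine-tuned checkpoint, assuming training above
# has finished. The checkpoint directory name and the image path passed in are
# hypothetical; with save_strategy="epoch" the checkpoints land under
# ./model/checkpoint-<step>, so point model_dir at one of those.
def predict(image_path, model_dir):
    """Transcribe a single image with a fine-tuned TrOCR checkpoint."""
    processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
    model = VisionEncoderDecoderModel.from_pretrained(model_dir)
    image = Image.open(image_path).convert("RGB")
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]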