"""Fine-tune TrOCR on the Latin / Caroline subset of CATMuS/medieval and report validation CER."""
import inspect
import io
import re

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrOCRProcessor,
    VisionEncoderDecoderModel,
)

MODEL_NAME = "microsoft/trocr-base-handwritten"
# Used both for padding/truncating labels and as the generation length cap,
# so the model is never asked to produce less text than it was trained on.
MAX_TARGET_LENGTH = 128

# Pick the best available accelerator instead of hard-coding Apple's MPS,
# which made the script crash on CUDA/CPU-only machines.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Load the dataset and keep only Latin lines written in Caroline minuscule.
# To train on every Latin script, drop the `script_type` condition.
dataset = load_dataset("CATMuS/medieval", split="train")
latin_dataset = dataset.filter(
    lambda example: example["language"] == "Latin" and example["script_type"] == "Caroline"
)
print(latin_dataset)

# Convert to a pandas DataFrame for easier manipulation.
df = pd.DataFrame(latin_dataset)

# Split into training and testing sets; a fixed seed keeps runs comparable.
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
train_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)


class HandwrittenTextDataset(Dataset):
    """Turn rows of an (image, transcription) DataFrame into TrOCR training examples.

    Each item is a dict with:
      * ``pixel_values`` - the processed image tensor with the leading batch
        dimension removed;
      * ``labels`` - tokenised transcription padded/truncated to
        ``max_target_length``, with pad tokens replaced by -100 so the loss
        ignores them.

    Args:
        df: DataFrame with an ``im`` column (a PIL image or array-like pixel
            data) and a ``text`` column (the transcription).
        processor: TrOCRProcessor supplying the image processor and tokenizer.
        max_target_length: label sequence length after padding/truncation.
    """

    def __init__(self, df, processor, max_target_length=MAX_TARGET_LENGTH):
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image_data = self.df["im"].iloc[idx]
        text = self.df["text"].iloc[idx]

        # Convert PIL images directly: the old np.array -> Image.fromarray
        # round-trip turned palette ("P") images into palette *indices*.
        if isinstance(image_data, Image.Image):
            image = image_data.convert("RGB")
        else:
            image = Image.fromarray(np.asarray(image_data)).convert("RGB")

        # Resize + normalise; squeeze only the batch dim the processor adds.
        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.squeeze(0)

        labels = self.processor.tokenizer(
            text,
            padding="max_length",
            max_length=self.max_target_length,
            truncation=True,
        ).input_ids
        # Make sure PAD tokens are ignored by the loss function.
        pad_id = self.processor.tokenizer.pad_token_id
        labels = [token if token != pad_id else -100 for token in labels]

        return {"pixel_values": pixel_values, "labels": torch.tensor(labels)}


# Instantiate processor and datasets.
processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
train_dataset = HandwrittenTextDataset(df=train_df, processor=processor)
eval_dataset = HandwrittenTextDataset(df=test_df, processor=processor)

# Create the corresponding dataloaders.
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=4)

# Load the model.
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)

# Special tokens used to build decoder_input_ids from the labels during training.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.eos_token_id = processor.tokenizer.sep_token_id
# Make sure the vocab size is set correctly.
model.config.vocab_size = model.config.decoder.vocab_size

# Decoding settings belong on generation_config: recent transformers read
# generate() parameters from there, not from model.config. The start token
# must match the one used to shift the labels above.
generation_config = model.generation_config
generation_config.decoder_start_token_id = processor.tokenizer.cls_token_id
generation_config.pad_token_id = processor.tokenizer.pad_token_id
generation_config.eos_token_id = processor.tokenizer.sep_token_id
generation_config.max_length = MAX_TARGET_LENGTH
generation_config.early_stopping = True
generation_config.no_repeat_ngram_size = 3
generation_config.length_penalty = 2.0
generation_config.num_beams = 4

# `evaluation_strategy` was renamed to `eval_strategy` in newer transformers
# (and later removed); pick whichever keyword this installation accepts.
_eval_strategy_kwarg = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(Seq2SeqTrainingArguments).parameters
    else "evaluation_strategy"
)

# Training arguments.
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=4,
    num_train_epochs=10,
    logging_steps=1000,
    save_steps=1000,
    save_total_limit=2,
    predict_with_generate=True,
    fp16=False,  # Set to True if using a compatible (CUDA) GPU.
    **{_eval_strategy_kwarg: "steps"},
)

# Trainer.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Train the model.
trainer.train()

# After training, save both the model and the processor.
model.save_pretrained("./finetuned_model")
processor.save_pretrained("./finetuned_model")


# --- Character error rate ------------------------------------------------
# Mirrors the Hugging Face `cer` metric (jiwer-based): runs of whitespace are
# collapsed to one space, text is stripped, and CER is
# (substitutions + deletions + insertions) / reference characters, summed
# over the corpus. Implemented here because `datasets.load_metric` was
# removed in datasets 3.0 and the metric also required `jiwer`.
_MULTIPLE_WHITESPACE = re.compile(r"\s\s+")


def _normalize_for_cer(text):
    """Collapse whitespace runs to a single space and strip the ends."""
    return _MULTIPLE_WHITESPACE.sub(" ", text).strip()


def _levenshtein(source, target):
    """Return the minimum number of single-character edits turning *source* into *target*."""
    if len(source) < len(target):
        source, target = target, source  # keep the DP row as short as possible
    previous = list(range(len(target) + 1))
    for i, source_char in enumerate(source, start=1):
        current = [i]
        for j, target_char in enumerate(target, start=1):
            current.append(
                min(
                    previous[j] + 1,  # deletion
                    current[j - 1] + 1,  # insertion
                    previous[j - 1] + (source_char != target_char),  # substitution
                )
            )
        previous = current
    return previous[-1]


def _char_errors(predictions, references):
    """Return (total character edits, total reference characters) over paired strings."""
    total_errors = 0
    total_ref_chars = 0
    for prediction, reference in zip(predictions, references):
        reference = _normalize_for_cer(reference)
        total_errors += _levenshtein(_normalize_for_cer(prediction), reference)
        total_ref_chars += len(reference)
    return total_errors, total_ref_chars


def _decode(pred_ids, label_ids):
    """Decode generated ids and (-100-masked) label ids into lists of strings.

    Does not modify *label_ids*: a masked copy is built instead.
    """
    pad_id = processor.tokenizer.pad_token_id
    if isinstance(label_ids, torch.Tensor):
        label_ids = label_ids.masked_fill(label_ids == -100, pad_id)
    else:
        label_ids = np.where(np.asarray(label_ids) == -100, pad_id, label_ids)
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)
    return pred_str, label_str


def compute_cer(pred_ids, label_ids):
    """Return the character error rate of generated ids against label ids.

    Args:
        pred_ids: generated token ids (tensor or array), one row per sample.
        label_ids: reference token ids where -100 marks ignored padding; the
            caller's object is left unmodified.

    Returns:
        Total character edits divided by total reference characters, or NaN
        when the references contain no characters at all.
    """
    errors, ref_chars = _char_errors(*_decode(pred_ids, label_ids))
    return errors / ref_chars if ref_chars else float("nan")


# Evaluation: accumulate edits and reference lengths over the whole split,
# so that every character counts equally. Averaging per-batch CERs would
# over-weight small or short batches.
model.to(device)  # be explicit about where generate() runs
model.eval()
total_errors = 0
total_ref_chars = 0
with torch.no_grad():
    for batch in eval_dataloader:
        # Run batch generation.
        outputs = model.generate(batch["pixel_values"].to(device))
        # Accumulate metric counts.
        batch_errors, batch_ref_chars = _char_errors(*_decode(outputs, batch["labels"]))
        total_errors += batch_errors
        total_ref_chars += batch_ref_chars

valid_cer = total_errors / total_ref_chars if total_ref_chars else float("nan")
print("Validation CER:", valid_cer)