krplt's picture
fix: v1 model
0e12aee
raw
history blame
No virus
3.37 kB
import os

import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Trainer, TrainingArguments
class HandwrittenMathDataset(Dataset):
    """Image/label pairs for fine-tuning TrOCR on handwritten math.

    The annotations CSV is read once and split into a train and a test
    partition with ``train_test_split``; ``subset`` selects which partition
    this instance serves. Column 0 of the CSV is used as the image path and
    column 1 as the target text.

    Parameters:
        annotations_file (str): Path to the annotations CSV file.
        img_dir (str): Directory used to resolve image entries that do not
            exist as written in the CSV (e.g. bare file names).
        processor: Processor providing image preprocessing (called with
            ``images=``) and a ``tokenizer`` attribute, e.g. a TrOCRProcessor.
        subset (str): ``"train"`` selects the training split; any other
            value selects the held-out split.
        test_size (float): Fraction of rows held out for the test split.
        random_state (int): Seed for the split. Train and test instances
            built from the same CSV must share it to stay disjoint.
        max_target_length (int): Length labels are padded/truncated to.
    """

    def __init__(self, annotations_file, img_dir, processor, subset="train",
                 test_size=0.2, random_state=42, max_target_length=128):
        # Read every column as text so labels like "42" or "0.50" are not
        # coerced to numbers (the tokenizer expects str input), and keep empty
        # cells as "" instead of NaN.
        self.img_labels = pd.read_csv(annotations_file, dtype=str, keep_default_na=False)
        # A fixed random_state makes separately constructed "train" and
        # "test" instances see complementary partitions of the same rows.
        self.train_data, self.test_data = train_test_split(
            self.img_labels, test_size=test_size, random_state=random_state
        )
        self.data = self.train_data if subset == "train" else self.test_data
        self.img_dir = img_dir
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        """Return the number of rows in the selected split."""
        return len(self.data)

    def _resolve_image_path(self, name):
        """Return the path to open for the CSV image entry ``name``.

        The entry is used as written when it is an existing file (or when no
        ``img_dir`` was given); otherwise it is tried relative to ``img_dir``.
        If neither is a file, the original entry is returned so the open
        error reports it.
        """
        path = str(name)
        if not self.img_dir or os.path.isfile(path):
            return path
        candidate = os.path.join(self.img_dir, path)
        return candidate if os.path.isfile(candidate) else path

    def __getitem__(self, idx):
        """Return a dict with ``pixel_values`` and ``labels`` tensors for row ``idx``."""
        img_path = self._resolve_image_path(self.data.iloc[idx, 0])
        # The context manager closes the file handle; convert() produces a
        # separate image object that is used after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values
        label = self.data.iloc[idx, 1]
        labels = self.processor.tokenizer(
            label,
            padding="max_length",
            max_length=self.max_target_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids
        # Replace padding token ids with -100 so they are excluded from the loss.
        labels[labels == self.processor.tokenizer.pad_token_id] = -100
        # Drop the leading batch dimension added by return_tensors="pt".
        return {"pixel_values": pixel_values.squeeze(0), "labels": labels.squeeze(0)}
def main(annotations_file='./dataset/annotations.csv', img_dir='./dataset/images/',
         model_id='microsoft/trocr-base-handwritten', output_dir='./model'):
    """Fine-tune a TrOCR VisionEncoderDecoderModel on handwritten math images.

    Builds train/test datasets from ``annotations_file``, sets the decoder's
    start and padding token ids, and runs a ``Trainer`` that evaluates and
    saves a checkpoint once per epoch.

    Parameters:
        annotations_file (str): Path to the annotations CSV.
        img_dir (str): Directory used to resolve image entries from the CSV.
        model_id (str): Pretrained model id or local path to start from.
        output_dir (str): Directory the Trainer writes checkpoints to.
    """
    processor = TrOCRProcessor.from_pretrained(model_id)
    # No explicit .to("cuda"): Trainer moves the model to args.device itself,
    # so the script also runs on machines without a GPU.
    model = VisionEncoderDecoderModel.from_pretrained(model_id)
    # Token ids the model uses to build decoder inputs from `labels`
    # (start token prepended, -100 mapped back to padding).
    model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
    model.config.pad_token_id = processor.tokenizer.pad_token_id

    train_dataset = HandwrittenMathDataset(
        annotations_file=annotations_file, img_dir=img_dir, processor=processor, subset="train"
    )
    test_dataset = HandwrittenMathDataset(
        annotations_file=annotations_file, img_dir=img_dir, processor=processor, subset="test"
    )

    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=2,
        num_train_epochs=20,
        logging_dir='./training_logs',
        logging_steps=10,
        save_strategy="epoch",
        # Only the most recent checkpoint is kept on disk.
        save_total_limit=1,
        weight_decay=0.1,
        learning_rate=1e-4,
        gradient_checkpointing=True,
        # Effective train batch per device = 2 * 2 = 4.
        gradient_accumulation_steps=2,
        # NOTE(review): newer transformers releases name this argument
        # `eval_strategy`; confirm against the installed version.
        evaluation_strategy="epoch"
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset
    )
    trainer.train()
if __name__ == '__main__':
    # Start training only when run as a script; importing the module does not train.
    main()