import os

import pandas as pd
import torch
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Trainer, TrainingArguments
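

# Assumption: the annotations CSV has two columns -- the image file name in
# the first column and the ground-truth transcription (e.g. LaTeX) in the second.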
class HandwrittenMathDataset(Dataset):
    """Dataset of handwritten math images and their transcriptions.

    The annotations CSV is split once into train/test subsets with a fixed
    random seed, so both subsets stay consistent across instantiations.
    """

    def __init__(self, annotations_file, img_dir, processor, subset="train"):
        """
        Parameters:
            annotations_file (str): Path to the annotations CSV file.
            img_dir (str): Directory containing the images.
            processor: TrOCRProcessor used to preprocess images and labels.
            subset (str): "train" or "test"; selects which split this instance serves.
        """
self.img_labels = pd.read_csv(annotations_file)
self.train_data, self.test_data = train_test_split(self.img_labels, test_size=0.2, random_state=42)
self.data = self.train_data if subset == "train" else self.test_data
self.img_dir = img_dir
self.processor = processor
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
        # Resolve the image path relative to img_dir; os.path.join leaves
        # absolute paths in the CSV untouched.
        img_path = os.path.join(self.img_dir, self.data.iloc[idx, 0])
        image = Image.open(img_path).convert("RGB")
        # Convert the image into normalized pixel tensors for the ViT encoder
        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values
label = self.data.iloc[idx, 1]
        # Tokenize the transcription, padded/truncated to a fixed length
        labels = self.processor.tokenizer(label, padding="max_length", max_length=128, truncation=True,
                                          return_tensors="pt").input_ids
        # Replace pad token ids with -100 so they are ignored by the loss
        labels[labels == self.processor.tokenizer.pad_token_id] = -100
return {"pixel_values": pixel_values.squeeze(), "labels": labels.squeeze()}
def main():
    """Fine-tune a TrOCR VisionEncoderDecoderModel on handwritten math data."""
annotations_file = './dataset/annotations.csv'
img_dir = './dataset/images/'
model_id = 'microsoft/trocr-base-handwritten'
processor = TrOCRProcessor.from_pretrained(model_id)
    # Fall back to CPU when no GPU is available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = VisionEncoderDecoderModel.from_pretrained(model_id).to(device)
    # Configure the decoder start and padding token ids required for training
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
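    # For TrOCR's RoBERTa-style decoder the CLS token doubles as the
    # start-of-sequence token, so it is reused as decoder_start_token_id.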
train_dataset = HandwrittenMathDataset(annotations_file=annotations_file, img_dir=img_dir, processor=processor,
subset="train")
test_dataset = HandwrittenMathDataset(annotations_file=annotations_file, img_dir=img_dir, processor=processor,
subset="test")
training_args = TrainingArguments(
output_dir='./model',
per_device_train_batch_size=2,
num_train_epochs=20,
logging_dir='./training_logs',
logging_steps=10,
save_strategy="epoch",
save_total_limit=1,
weight_decay=0.1,
learning_rate=1e-4,
gradient_checkpointing=True,
gradient_accumulation_steps=2,
evaluation_strategy="epoch"
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=test_dataset
)
trainer.train()
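    # Persist the fine-tuned weights and the processor together so the model
    # can be reloaded for inference ('./model/final' is an illustrative path).
    trainer.save_model('./model/final')
    processor.save_pretrained('./model/final')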
if __name__ == '__main__':
main()