import glob import pandas as pd from PIL import Image from torch.utils.data import Dataset, random_split from transformers import TrainingArguments, Trainer, ViTFeatureExtractor, BertTokenizer, VisionEncoderDecoderModel import torch import gc import os torch.manual_seed(42) from pathlib import Path # I'm on Linux so you need to convert back to Windows device = 'cuda' if torch.cuda.is_available() else 'cpu' path = '/media/delta/S/Photos/Photo_Data' feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k") tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained("google/vit-base-patch16-224-in21k", "bert-base-uncased").to(device) model.config.decoder_start_token_id = tokenizer.cls_token_id model.config.pad_token_id = tokenizer.pad_token_id list_of_csv = glob.glob(f'{path}/*.csv') # to change DF = [] for f in list_of_csv: df = pd.read_csv(f) DF.append(df) ds = pd.concat(DF) class CustomDataset(Dataset): def __init__(self,ds, tokenizer,feature_extractor): self.Pixel_Values = [] self.Labels = [] for i,r in ds.iterrows(): image_path = r['IMAGEPATH'] #A table in csv format with 2 columns IMAGEPATH and CAPTION labels = r['CAPTION'] labels = str(labels) if len(image_path) >=10 and len(labels)>=10: image_path = image_path.split('\\') image_path = image_path[-3:] image_path = Path(os.getcwd(),image_path[0],image_path[1],image_path[2]) image = Image.open(str(image_path)).convert("RGB") pixel_values = feature_extractor(image, return_tensors="pt").pixel_values self.Pixel_Values.append(pixel_values) labels = tokenizer(labels,return_tensors="pt", truncation=True, max_length=128, padding="max_length").input_ids labels[labels == tokenizer.pad_token_id] = -100 self.Labels.append(labels) def __len__(self): return len(self.Pixel_Values) def __getitem__(self, idx): return {"pixel_values": self.Pixel_Values[idx], "labels": self.Labels[idx]} dataset = CustomDataset(ds,tokenizer,feature_extractor) train_size = int(0.9 * len(dataset)) train_dataset, val_dataset = random_split(dataset, [train_size, len(dataset) - train_size]) gc.collect() torch.cuda.empty_cache() training_args = TrainingArguments(output_dir=str(Path(os.getcwd(),'results')), num_train_epochs=6, logging_steps=300, save_steps=14770, per_device_train_batch_size=16, per_device_eval_batch_size=16, gradient_accumulation_steps=1, gradient_checkpointing=False, fp16=False, #doesnt work for this model optim="adamw_torch", #change to adamw_torch if you have have enough memory['adamw_hf', 'adamw_torch', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'sgd', 'adagrad'] warmup_steps=1, weight_decay=0.05, logging_dir='/home/delta/Downloads/logs', # loss graph report_to = 'tensorboard', ) def collate_fn(examples): pixel_values = torch.stack([example["pixel_values"][0] for example in examples]) #0 to change from [1,3,224,224] to [3,224,224] torch stack will add it back depends on the batch size, labels = torch.stack([example["labels"][0] for example in examples]) return {"pixel_values": pixel_values, "labels": labels} Trainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=val_dataset, data_collator=collate_fn).train() model.save_pretrained('/media/delta/S/model_caption') tokenizer.save_pretrained('/media/delta/S/tokenizer_caption') feature_extractor.save_pretrained('/media/delta/S/feature_extractor_caption')