# GuardrailDetection / FinetuneImageCaptioning.py
# Author: dentadelta123
# Commit message: Can we use image caption to estimate the photo
# Commit: c40dd83
import glob
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, random_split
from transformers import TrainingArguments, Trainer, ViTFeatureExtractor, BertTokenizer, VisionEncoderDecoderModel
import torch
import gc
import os
torch.manual_seed(42)
from pathlib import Path

# NOTE: `path` below, and the logging/output paths used later in this script,
# are absolute Linux paths from the original author's machine. Adjust them
# (e.g. to Windows paths) for your own setup.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Directory that holds the caption CSV files.
path = '/media/delta/S/Photos/Photo_Data'

# A pretrained ViT image encoder and a BERT text decoder are combined into one
# VisionEncoderDecoderModel. The ViT checkpoint's feature extractor encodes
# images and the BERT tokenizer encodes captions.
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "bert-base-uncased"
).to(device)
# Decoding starts from BERT's [CLS] token; padding uses BERT's [PAD] id.
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
# BertTokenizer appends [SEP] to every encoded caption by default. Recording it
# as the end-of-sequence id lets `generate()` on the saved model stop at the
# end of a caption instead of always running to max_length.
model.config.eos_token_id = tokenizer.sep_token_id
# Load every caption CSV in the data directory (change `path` / the pattern
# for your own data). Each file needs an IMAGEPATH and a CAPTION column.
# Sorting fixes the row order, so the seeded train/validation split below is
# reproducible regardless of the order the filesystem lists the files in.
list_of_csv = sorted(glob.glob(f'{path}/*.csv'))
if not list_of_csv:
    raise FileNotFoundError(f'No caption CSV files found in {path}')
DF = [pd.read_csv(csv_file) for csv_file in list_of_csv]
ds = pd.concat(DF, ignore_index=True)
class CustomDataset(Dataset):
    """Image/caption pairs, fully preprocessed up front, for VisionEncoderDecoder training.

    For every usable row of ``ds``:

    * the image is loaded, converted to RGB and encoded with ``feature_extractor``;
    * the caption is encoded with ``tokenizer``, truncated/padded to 128 tokens.

    Padding positions in the label ids are set to -100 so the training loss
    ignores them.

    Args:
        ds: pandas DataFrame with an ``IMAGEPATH`` column (the stored path of
            each image) and a ``CAPTION`` column.
        tokenizer: tokenizer used to encode captions.
        feature_extractor: image processor used to encode images.

    A row is skipped when either of these holds:

    * its image path is missing or not a string;
    * its path or caption text is shorter than 10 characters.

    NOTE: every sample is held in memory at once (``self.Pixel_Values`` and
    ``self.Labels``), so memory use grows with the dataset size.
    """

    def __init__(self, ds, tokenizer, feature_extractor):
        self.Pixel_Values = []
        self.Labels = []
        for _, row in ds.iterrows():
            raw_path = row['IMAGEPATH']
            # NaN captions become the string 'nan' and are filtered out below.
            caption = str(row['CAPTION'])
            # Empty CSV cells arrive as NaN (a float); calling len() on them raised.
            if not isinstance(raw_path, str):
                continue
            if len(raw_path) < 10 or len(caption) < 10:
                continue
            image_path = self._resolve_image_path(raw_path)
            # The context manager releases the file handle once pixels are extracted.
            with Image.open(str(image_path)) as image:
                pixel_values = feature_extractor(image.convert("RGB"), return_tensors="pt").pixel_values
            self.Pixel_Values.append(pixel_values)
            # NOTE(review): BERT's tokenizer also prepends [CLS], while the model's
            # decoder_start_token_id is set to [CLS] elsewhere in this script.
            # Confirm the leading [CLS] in the labels is intended.
            labels = tokenizer(caption, return_tensors="pt", truncation=True,
                               max_length=128, padding="max_length").input_ids
            labels[labels == tokenizer.pad_token_id] = -100
            self.Labels.append(labels)

    @staticmethod
    def _resolve_image_path(raw_path):
        """Re-root a stored image path at the current working directory.

        Only the last three path components (e.g. folder, subfolder, file name)
        are kept. Backslash and forward-slash separators are both accepted,
        so CSVs written on Windows or on POSIX systems work.
        """
        parts = [part for part in raw_path.replace('/', '\\').split('\\') if part]
        return Path(os.getcwd(), *parts[-3:])

    def __len__(self):
        return len(self.Pixel_Values)

    def __getitem__(self, idx):
        return {"pixel_values": self.Pixel_Values[idx], "labels": self.Labels[idx]}
dataset = CustomDataset(ds, tokenizer, feature_extractor)
# Hold out 10% of the samples for validation.
train_size = int(0.9 * len(dataset))
if train_size == 0:
    raise ValueError(
        f'Need at least 2 usable image/caption rows for a train/validation '
        f'split, found {len(dataset)}; check the IMAGEPATH and CAPTION columns'
    )
# A dedicated, seeded generator keeps the split reproducible independently of
# any global-RNG draws made earlier in the script (e.g. during model setup).
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, len(dataset) - train_size],
    generator=torch.Generator().manual_seed(42),
)
# Release memory left over from preprocessing before training starts.
gc.collect()
torch.cuda.empty_cache()
# Hugging Face Trainer configuration. Checkpoints go to ./results under the
# current working directory. TensorBoard logs go to an absolute,
# machine-specific directory that will likely need changing on your machine.
training_args = TrainingArguments(
    # Checkpoint location and frequency (in optimizer steps).
    output_dir=str(Path(os.getcwd(), 'results')),
    save_steps=14770,
    # Schedule and batching: 6 epochs, 16 samples per device, no accumulation.
    num_train_epochs=6,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=1,
    # Optimizer. Other accepted `optim` values include 'adamw_hf',
    # 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit',
    # 'sgd' and 'adagrad'; e.g. 'adamw_bnb_8bit' or 'adafactor' if memory is tight.
    optim="adamw_torch",
    warmup_steps=1,
    weight_decay=0.05,
    # Precision/memory trade-offs, both disabled. The original author noted
    # that fp16 did not work for this model.
    gradient_checkpointing=False,
    fp16=False,
    # Log the training loss every 300 steps for TensorBoard (loss graph).
    logging_steps=300,
    logging_dir='/home/delta/Downloads/logs',
    report_to='tensorboard',
)
def collate_fn(examples):
    """Batch dataset samples for the Trainer.

    Each sample's ``pixel_values`` and ``labels`` tensors carry a leading
    dimension of size 1 (e.g. ``[1, 3, 224, 224]`` for images). Taking element
    ``[0]`` drops that dimension, and ``torch.stack`` adds the batch dimension
    in its place.
    """
    return {
        key: torch.stack([sample[key][0] for sample in examples])
        for key in ("pixel_values", "labels")
    }
# Fine-tune, then score the held-out split.
# training_args sets no evaluation strategy, and the TrainingArguments default
# is "no", so the Trainer would never evaluate eval_dataset on its own. The
# validation loss is therefore computed explicitly after training.
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset,
                  eval_dataset=val_dataset, data_collator=collate_fn)
trainer.train()
if len(val_dataset) > 0:
    eval_metrics = trainer.evaluate()
    trainer.log_metrics("eval", eval_metrics)
# Persist the fine-tuned model plus the tokenizer and feature extractor used to
# preprocess its inputs. Each goes into its own directory so it can be reloaded
# with the matching `from_pretrained` call.
for _artifact, _target_dir in (
    (model, '/media/delta/S/model_caption'),
    (tokenizer, '/media/delta/S/tokenizer_caption'),
    (feature_extractor, '/media/delta/S/feature_extractor_caption'),
):
    _artifact.save_pretrained(_target_dir)