File size: 4,344 Bytes
c40dd83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import glob
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, random_split
from transformers import  TrainingArguments, Trainer, ViTFeatureExtractor, BertTokenizer, VisionEncoderDecoderModel
import torch
import gc
import os
torch.manual_seed(42)
from pathlib import Path

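# NOTE: this script was written on Linux, while the CSV files store Windows-style
# image paths (backslash separated); CustomDataset rebuilds them relative to the
# current working directory. When running on Windows, adjust the hard-coded
# Linux locations (`path`, the logging directory and the save directories).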

# Use the GPU when torch can see one, otherwise stay on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Directory that is searched for the caption CSV files.
path = '/media/delta/S/Photos/Photo_Data'

# Checkpoints used for the image encoder and the text decoder.
encoder_checkpoint = "google/vit-base-patch16-224-in21k"
decoder_checkpoint = "bert-base-uncased"

# Image preprocessor matching the encoder and tokenizer matching the decoder.
feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
tokenizer = BertTokenizer.from_pretrained(decoder_checkpoint)

# Combine both pretrained checkpoints into one encoder-decoder model and move
# it to the selected device.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_checkpoint,
    decoder_checkpoint,
)
model = model.to(device)

# The decoder start token is the tokenizer's CLS token id and the padding id
# is the tokenizer's pad token id.
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# Collect every caption CSV found directly under `path`.  glob returns its
# matches in arbitrary, filesystem-dependent order, so sort them: together with
# the fixed torch seed this keeps the row order, and therefore the random
# train/validation split, reproducible between runs and machines.
list_of_csv = sorted(glob.glob(f'{path}/*.csv'))  # to change

if not list_of_csv:
    # Without this check pd.concat fails with the opaque
    # "No objects to concatenate" message.
    raise FileNotFoundError(f'No CSV files found in {path!r}')

# Read each file into its own frame, then stack all frames into one table.
DF = [pd.read_csv(csv_file) for csv_file in list_of_csv]
ds = pd.concat(DF)

class CustomDataset(Dataset):
    """In-memory image-captioning dataset built from a table of paths and captions.

    Every usable row of ``ds`` is processed eagerly in ``__init__``: the image
    is opened and passed through ``feature_extractor``, the caption is
    tokenized, and both results are kept in the parallel lists
    ``Pixel_Values`` and ``Labels``.

    Args:
        ds: DataFrame with the two columns ``IMAGEPATH`` and ``CAPTION``.
        tokenizer: Callable tokenizer; its ``.input_ids`` result and its
            ``pad_token_id`` attribute are used.
        feature_extractor: Callable image preprocessor whose result exposes
            ``.pixel_values``.
        max_length: Length every caption is padded/truncated to.
        min_chars: Rows whose image path or caption has fewer characters
            than this are skipped.
    """

    def __init__(self, ds, tokenizer, feature_extractor, max_length=128, min_chars=10):
        self.Pixel_Values = []
        self.Labels = []
        for _, row in ds.iterrows():
            image_path = row['IMAGEPATH']
            # Empty CSV cells are read by pandas as float NaN; calling len()
            # on such a value would raise TypeError, so skip non-string paths.
            if not isinstance(image_path, str):
                continue
            caption = str(row['CAPTION'])
            if len(image_path) < min_chars or len(caption) < min_chars:
                continue

            # The table stores Windows-style paths.  Keep (at most) the last
            # three backslash-separated components and resolve them against
            # the current working directory.  Unpacking the slice avoids an
            # IndexError for paths with fewer than three components.
            tail = image_path.split('\\')[-3:]
            local_path = Path(os.getcwd(), *tail)

            # convert() returns an independent, fully loaded copy, so the
            # file handle can be released right away.
            with Image.open(str(local_path)) as image:
                rgb_image = image.convert("RGB")
            pixel_values = feature_extractor(rgb_image, return_tensors="pt").pixel_values
            self.Pixel_Values.append(pixel_values)

            token_ids = tokenizer(
                caption,
                return_tensors="pt",
                truncation=True,
                max_length=max_length,
                padding="max_length",
            ).input_ids
            # Replace padding ids with -100, presumably so the loss ignores
            # padded positions (-100 is PyTorch's default ignore_index).
            token_ids[token_ids == tokenizer.pad_token_id] = -100
            self.Labels.append(token_ids)

    def __len__(self):
        """Return the number of rows that were kept."""
        return len(self.Pixel_Values)

    def __getitem__(self, idx):
        """Return the stored pixel values and label ids for position ``idx``."""
        return {"pixel_values": self.Pixel_Values[idx], "labels": self.Labels[idx]}

# Build the full in-memory dataset, then hold out the last 10% of the count
# (after rounding the 90% share down) as a randomly chosen validation subset.
dataset = CustomDataset(ds, tokenizer, feature_extractor)
total_examples = len(dataset)
train_size = int(0.9 * total_examples)
val_size = total_examples - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Run a garbage-collection pass and ask torch to release its cached CUDA
# memory before training starts.
gc.collect()
torch.cuda.empty_cache()

# Hyper-parameters and output locations handed to the Hugging Face Trainer.
training_options = dict(
    # Checkpoints are written to ./results under the current working directory.
    output_dir=str(Path(os.getcwd()) / 'results'),
    num_train_epochs=6,
    logging_steps=300,
    save_steps=14770,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=1,
    gradient_checkpointing=False,
    # Author's note: fp16 does not work for this model.
    fp16=False,
    # Author's note: use adamw_torch if you have enough memory.  Listed
    # alternatives: adamw_hf, adamw_torch, adamw_torch_xla, adamw_apex_fused,
    # adafactor, adamw_bnb_8bit, sgd, adagrad.
    optim="adamw_torch",
    warmup_steps=1,
    weight_decay=0.05,
    # TensorBoard event files (loss graph) go here.
    logging_dir='/home/delta/Downloads/logs',
    report_to='tensorboard',
)
training_args = TrainingArguments(**training_options)

def collate_fn(examples):
    """Merge per-example dicts into one batch dict of stacked tensors.

    Each example carries ``"pixel_values"`` and ``"labels"`` tensors with a
    leading axis of size one (e.g. [1, 3, 224, 224] for the image).  That axis
    is dropped by taking element 0, and ``torch.stack`` then adds a new
    leading axis whose size is the number of examples.
    """

    def _stack_field(key):
        # Strip the per-example leading axis, then stack along a new one.
        return torch.stack([sample[key][0] for sample in examples])

    return {
        "pixel_values": _stack_field("pixel_values"),
        "labels": _stack_field("labels"),
    }


# Fine-tune the model on the training subset; the validation subset and the
# custom collate function are handed to the Trainer as well.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,
)
trainer.train()

# Save the model together with the tokenizer and the feature extractor, each
# into its own directory.
for artifact, destination in (
    (model, '/media/delta/S/model_caption'),
    (tokenizer, '/media/delta/S/tokenizer_caption'),
    (feature_extractor, '/media/delta/S/feature_extractor_caption'),
):
    artifact.save_pretrained(destination)