dentadelta123 committed
Commit
c40dd83
1 Parent(s): f995acb

Can we use image caption to estimate the photo

Files changed (1)
  1. FinetuneImageCaptioning.py +90 -0
FinetuneImageCaptioning.py ADDED
@@ -0,0 +1,90 @@
+ import gc
+ import glob
+ import os
+ import pandas as pd
+ import torch
+ from pathlib import Path
+ from PIL import Image
+ from torch.utils.data import Dataset, random_split
+ from transformers import TrainingArguments, Trainer, ViTFeatureExtractor, BertTokenizer, VisionEncoderDecoderModel
+
+ torch.manual_seed(42)  # fixed seed so the train/val split below is reproducible
+
+ # I'm on Linux, so convert the paths back to Windows style if you need to.
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ path = '/media/delta/S/Photos/Photo_Data'
+ feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+ model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained("google/vit-base-patch16-224-in21k", "bert-base-uncased").to(device)
+ model.config.decoder_start_token_id = tokenizer.cls_token_id  # BERT has no BOS token; [CLS] starts generation
+ model.config.pad_token_id = tokenizer.pad_token_id
+
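+ # Note: from_encoder_decoder_pretrained pairs the pretrained ViT encoder with
+ # BERT as the decoder; the decoder's cross-attention weights are freshly
+ # initialised, which is why the model must be fine-tuned before it can caption.
+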
+ list_of_csv = glob.glob(f'{path}/*.csv')  # change this to your own CSV directory
+
+ DF = []
+ for f in list_of_csv:
+     df = pd.read_csv(f)
+     DF.append(df)
+ ds = pd.concat(DF)
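+
+ # Each CSV is expected to have two columns, IMAGEPATH and CAPTION (see the
+ # dataset class below). An illustrative row, with a made-up Windows path:
+ #
+ #   IMAGEPATH,CAPTION
+ #   S:\Photos\Photo_Data\2021\IMG_0001.jpg,"A dog running on the beach"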
+
+ class CustomDataset(Dataset):
+     def __init__(self, ds, tokenizer, feature_extractor):
+         self.Pixel_Values = []
+         self.Labels = []
+         for i, r in ds.iterrows():
+             image_path = str(r['IMAGEPATH'])  # CSV columns: IMAGEPATH and CAPTION
+             labels = str(r['CAPTION'])
+             if len(image_path) >= 10 and len(labels) >= 10:  # skip rows with missing or too-short values
+                 # Keep the last three components of the Windows path and resolve
+                 # them against the current working directory.
+                 image_path = image_path.split('\\')[-3:]
+                 image_path = Path(os.getcwd(), *image_path)
+                 image = Image.open(str(image_path)).convert("RGB")
+                 pixel_values = feature_extractor(image, return_tensors="pt").pixel_values
+                 self.Pixel_Values.append(pixel_values)
+                 labels = tokenizer(labels, return_tensors="pt", truncation=True, max_length=128, padding="max_length").input_ids
+                 labels[labels == tokenizer.pad_token_id] = -100  # padding must not contribute to the loss
+                 self.Labels.append(labels)
+
+     def __len__(self):
+         return len(self.Pixel_Values)
+
+     def __getitem__(self, idx):
+         return {"pixel_values": self.Pixel_Values[idx], "labels": self.Labels[idx]}
+
+ dataset = CustomDataset(ds, tokenizer, feature_extractor)
+ train_size = int(0.9 * len(dataset))
+ train_dataset, val_dataset = random_split(dataset, [train_size, len(dataset) - train_size])
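+
+ # CustomDataset preprocesses every image up-front, which is fast to train from
+ # but can exhaust RAM on a large photo collection. A lazy alternative (a sketch
+ # under the same CSV assumptions; not wired into the Trainer below) defers the
+ # work to __getitem__:
+ class LazyCaptionDataset(Dataset):
+     def __init__(self, ds, tokenizer, feature_extractor):
+         self.rows = ds.reset_index(drop=True)
+         self.tokenizer = tokenizer
+         self.feature_extractor = feature_extractor
+
+     def __len__(self):
+         return len(self.rows)
+
+     def __getitem__(self, idx):
+         r = self.rows.iloc[idx]
+         parts = str(r['IMAGEPATH']).split('\\')[-3:]
+         image = Image.open(str(Path(os.getcwd(), *parts))).convert("RGB")
+         pixel_values = self.feature_extractor(image, return_tensors="pt").pixel_values
+         labels = self.tokenizer(str(r['CAPTION']), return_tensors="pt", truncation=True,
+                                 max_length=128, padding="max_length").input_ids
+         labels[labels == self.tokenizer.pad_token_id] = -100
+         return {"pixel_values": pixel_values, "labels": labels}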
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ training_args = TrainingArguments(output_dir=str(Path(os.getcwd(), 'results')),
+                                   num_train_epochs=6,
+                                   logging_steps=300,
+                                   save_steps=14770,
+                                   per_device_train_batch_size=16,
+                                   per_device_eval_batch_size=16,
+                                   gradient_accumulation_steps=1,
+                                   gradient_checkpointing=False,
+                                   fp16=False,  # mixed precision doesn't work for this model
+                                   optim="adamw_torch",  # change to adamw_torch if you have enough memory; options: ['adamw_hf', 'adamw_torch', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'sgd', 'adagrad']
+                                   warmup_steps=1,
+                                   weight_decay=0.05,
+                                   logging_dir='/home/delta/Downloads/logs',  # TensorBoard loss curves are written here
+                                   report_to='tensorboard',
+                                   )
+
+ def collate_fn(examples):
+     # [0] drops the leading batch dim, turning [1, 3, 224, 224] into
+     # [3, 224, 224]; torch.stack adds it back at the actual batch size.
+     pixel_values = torch.stack([example["pixel_values"][0] for example in examples])
+     labels = torch.stack([example["labels"][0] for example in examples])
+     return {"pixel_values": pixel_values, "labels": labels}
+
+ Trainer(model=model, args=training_args, train_dataset=train_dataset,
+         eval_dataset=val_dataset, data_collator=collate_fn).train()
+
+ model.save_pretrained('/media/delta/S/model_caption')
+ tokenizer.save_pretrained('/media/delta/S/tokenizer_caption')
+ feature_extractor.save_pretrained('/media/delta/S/feature_extractor_caption')
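+
+ # A minimal inference sketch: reload the saved artifacts and caption one photo
+ # ('test.jpg' is a placeholder path, not a file from the dataset).
+ model = VisionEncoderDecoderModel.from_pretrained('/media/delta/S/model_caption').to(device)
+ tokenizer = BertTokenizer.from_pretrained('/media/delta/S/tokenizer_caption')
+ feature_extractor = ViTFeatureExtractor.from_pretrained('/media/delta/S/feature_extractor_caption')
+
+ image = Image.open('test.jpg').convert("RGB")
+ pixel_values = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
+ generated_ids = model.generate(pixel_values, max_length=128)
+ print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))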