ibrim commited on
Commit
88147e5
·
verified ·
1 Parent(s): 44443bf

Upload dataset.py

Browse files
Files changed (1) hide show
  1. dataset.py +60 -0
dataset.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import albumentations as A
5
+
6
+ import config as CFG
7
+
8
+
9
+ class CLIPDataset(torch.utils.data.Dataset):
10
+ def __init__(self, image_filenames, captions, tokenizer, transforms):
11
+ """
12
+ image_filenames and cpations must have the same length; so, if there are
13
+ multiple captions for each image, the image_filenames must have repetitive
14
+ file names
15
+ """
16
+
17
+ self.image_filenames = image_filenames
18
+ self.captions = list(captions)
19
+ self.encoded_captions = tokenizer(
20
+ list(captions), padding=True, truncation=True, max_length=CFG.max_length
21
+ )
22
+ self.transforms = transforms
23
+
24
+ def __getitem__(self, idx):
25
+ item = {
26
+ key: torch.tensor(values[idx])
27
+ for key, values in self.encoded_captions.items()
28
+ }
29
+
30
+ image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[idx]}")
31
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
32
+ image = self.transforms(image=image)['image']
33
+ item['image'] = torch.tensor(image).permute(2, 0, 1).float()
34
+ item['caption'] = self.captions[idx]
35
+
36
+ return item
37
+
38
+
39
+ def __len__(self):
40
+ return len(self.captions)
41
+
42
+
43
+
44
+ def get_transforms(mode="train"):
45
+ if mode == "train":
46
+ return A.Compose(
47
+ [
48
+ A.Resize(CFG.size, CFG.size, always_apply=True),
49
+ A.Normalize(max_pixel_value=255.0, always_apply=True),
50
+ ]
51
+ )
52
+ else:
53
+ return A.Compose(
54
+ [
55
+ A.Resize(CFG.size, CFG.size, always_apply=True),
56
+ A.Normalize(max_pixel_value=255.0, always_apply=True),
57
+ ]
58
+ )
59
+
60
+