Spaces:
Sleeping
Sleeping
Create preprocessing.py
Browse files- preprocessing.py +75 -0
preprocessing.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import io
|
3 |
+
import json
|
4 |
+
from torch.utils.data import Dataset, DataLoader
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
import torch
|
8 |
+
from torch.autograd import Variable
|
9 |
+
import pdb
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
class Text2ImageDataset(Dataset):
    """Paired (caption, image) dataset for text-to-image GAN training.

    Expects ``dataset_dir`` to contain:
      * ``descriptions.json`` — a JSON list of records, each with at least
        ``file_name``, ``text`` and ``class`` keys (schema inferred from the
        lookups below — TODO confirm against the actual data file);
      * ``CUHKSZ_Photos/`` — the image files named by ``file_name``.

    Each item yields the matching ("right") image, its caption embedding,
    and a mismatched ("wrong") image drawn from a *different* class, with
    pixel values normalised to [-1, 1].
    """

    def __init__(self, dataset_dir):
        self.dataset_dir = dataset_dir
        with open(os.path.join(self.dataset_dir, 'descriptions.json'), 'r') as file:
            self.dataset = json.load(file)
        self.images_path = os.path.join(dataset_dir, 'CUHKSZ_Photos')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return a dict with 'right_images', 'right_embed', 'wrong_images'."""
        # Lazily reload the descriptions if they were cleared (e.g. dropped
        # before forking DataLoader workers).
        if self.dataset is None:
            with open(os.path.join(self.dataset_dir, 'descriptions.json'), 'r') as file:
                self.dataset = json.load(file)

        item = self.dataset[idx]
        # BUG FIX: the label used to pick a mismatched image must come from
        # the 'class' field, not 'text'. Caption text never equals a class
        # label, so find_wrong_image accepted the first random draw even when
        # it belonged to the same class as the right image.
        examples_class = item['class']
        examples_text = item['text']

        image_path = os.path.join(self.images_path, item['file_name'])
        # Open via a context manager so the file handle is released promptly
        # instead of leaking until GC; resize() materialises the pixel data.
        with Image.open(image_path) as im:
            right_image = im.resize((128, 128))
        # process_caption is expected to be provided by the surrounding
        # module / training script — TODO confirm it is in scope at runtime.
        right_embed = np.array(process_caption(examples_text), dtype=float)
        wrong_image = self.find_wrong_image(examples_class)

        right_image = self.validate_image(right_image)
        wrong_image = self.validate_image(wrong_image)

        sample = {
            'right_images': torch.FloatTensor(right_image),
            'right_embed': torch.FloatTensor(right_embed),
            'wrong_images': torch.FloatTensor(wrong_image),
        }

        # Map pixel values from [0, 255] to [-1, 1] (tanh-compatible range).
        sample['right_images'] = sample['right_images'].sub_(127.5).div_(127.5)
        sample['wrong_images'] = sample['wrong_images'].sub_(127.5).div_(127.5)

        return sample

    def find_wrong_image(self, category):
        """Return a random 128x128 image whose 'class' differs from *category*.

        Rewritten from unbounded recursion to a loop so that a long run of
        same-class draws cannot overflow the stack. NOTE(review): if every
        record shares *category* this still never terminates, exactly like
        the original recursion — the caller must guarantee >= 2 classes.
        """
        while True:
            idx = np.random.randint(len(self.dataset))
            item = self.dataset[idx]
            if item['class'] != category:
                image_path = os.path.join(self.images_path, item['file_name'])
                with Image.open(image_path) as im:
                    return im.resize((128, 128))

    def validate_image(self, img):
        """Convert a PIL image to a float CHW (3, H, W) numpy array.

        ``convert('RGB')`` already collapses RGBA to three channels and
        expands grayscale to three channels, so the array is guaranteed
        HxWx3. The previous post-hoc shape fix-ups were unreachable (and the
        grayscale branch hard-coded a 64x64 buffer against the 128x128
        resize), so they have been removed.
        """
        img = np.array(img.convert('RGB'), dtype=float)
        return img.transpose(2, 0, 1)