File size: 2,376 Bytes
2c8b554
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78

import os
import time
import random

import h5py
import numpy as np
from PIL import Image
from tqdm import tqdm
import joblib

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from lib.utils import preprocess_image
from lib.utils import preprocess_image, grid_positions, upscale_positions
from lib.dataloaders.datasetPhotoTourism_ipr import PhotoTourismIPR
from lib.dataloaders.datasetPhotoTourism_real import PhotoTourism

from sys import exit, argv
import cv2
import csv

# Seed NumPy's global RNG (np.random.*) so NumPy-driven randomness is repeatable.
# Note: this does NOT seed Python's built-in `random` module, which
# PhotoTourismCombined.__getitem__ uses to choose between datasets.
np.random.seed(seed=0)


class PhotoTourismCombined(Dataset):
    """Randomly mix samples from the real PhotoTourism dataset and its IPR variant.

    Each ``__getitem__`` call draws from the IPR dataset with probability
    ``ipr_pref`` and from the real dataset otherwise. It returns
    ``(sample, flag)``, where ``flag`` is ``1`` for an IPR sample and ``0``
    for a real one.

    Args:
        base_path: Root directory passed to both underlying datasets.
        preprocessing: Preprocessing mode forwarded to both datasets
            (``'caffe'`` in this module's ``__main__`` example).
        ipr_pref: Probability of sampling from the IPR dataset; expected to
            lie in ``[0, 1]``.
        train: Forwarded as ``train`` to both underlying datasets.
        cropSize: Forwarded as ``image_size`` to the real dataset and as
            ``cropSize`` to the IPR dataset.
    """

    def __init__(self, base_path, preprocessing, ipr_pref=0.5, train=True, cropSize=256):
        self.base_path = base_path
        self.preprocessing = preprocessing
        self.cropSize = cropSize
        self.ipr_pref = ipr_pref

        print("[INFO] Building Original Dataset")
        self.PTReal = PhotoTourism(base_path, preprocessing=preprocessing, train=train, image_size=cropSize)
        self.PTReal.build_dataset()
        # Length is taken after build_dataset(), which presumably populates the
        # dataset; TODO confirm against PhotoTourism.
        self.dataset_len1 = len(self.PTReal)

        print("[INFO] Building IPR Dataset")
        self.PTipr = PhotoTourismIPR(base_path, preprocessing=preprocessing, train=train, cropSize=cropSize)
        self.PTipr.build_dataset()
        self.dataset_len2 = len(self.PTipr)

    def __getitem__(self, idx):
        """Return ``(ipr_sample, 1)`` or ``(real_sample, 0)``.

        ``idx`` is wrapped modulo the length of whichever dataset is chosen,
        so every index in ``range(len(self))`` is valid. If one underlying
        dataset is empty, samples always come from the other one.

        Raises:
            IndexError: if both underlying datasets are empty.
        """
        if self.dataset_len1 == 0 and self.dataset_len2 == 0:
            raise IndexError("PhotoTourismCombined: both underlying datasets are empty")

        # Only consult ipr_pref when both datasets can actually serve a sample;
        # this avoids a modulo-by-zero on an empty dataset.
        use_ipr = self.dataset_len2 > 0 and (
            self.dataset_len1 == 0 or random.random() < self.ipr_pref
        )
        if use_ipr:
            # Wrap by the length of the dataset actually being indexed.
            return (self.PTipr[idx % self.dataset_len2], 1)
        return (self.PTReal[idx % self.dataset_len1], 0)

    def __len__(self):
        """Return the combined number of samples in both underlying datasets."""
        return self.dataset_len1 + self.dataset_len2


if __name__ == "__main__":
    # Smoke test: build the combined dataset and iterate over it once.
    # An optional first command-line argument overrides the default dataset root.
    data_root = argv[1] if len(argv) > 1 else "/scratch/udit/phototourism/"
    # cropSize must be passed by keyword: the third positional parameter is
    # ipr_pref, and passing 256 there would make every sample an IPR sample.
    pt = PhotoTourismCombined(data_root, 'caffe', cropSize=256)
    dl = DataLoader(pt, batch_size=1, num_workers=2)
    for _ in dl:
        pass