import ctypes
import json
import multiprocessing as mp
import os
import pickle
import random
import time
from os.path import isfile, join
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import (
    CenterCrop,
    Compose,
    GaussianBlur,
    RandomCrop,
    RandomHorizontalFlip,
    ToTensor,
)
from tqdm import tqdm


def normalize(lat, lon):
    """Used to put all lat lon inside ±90 and ±180.

    Latitudes past a pole are reflected back into [-90, 90]; because that
    reflection lands on the opposite meridian, the longitude is shifted by 180
    degrees in that case. Longitude is finally wrapped into [-180, 180).

    Args:
        lat (float): latitude in degrees (any value).
        lon (float): longitude in degrees (any value).

    Returns:
        tuple: (lat, lon) in degrees.
    """
    lat = (lat + 90) % 360 - 90  # now in [-90, 270)
    if lat > 90:
        lat = 180 - lat
        lon += 180
    lon = (lon + 180) % 360 - 180
    return lat, lon


# Keys whose per-sample values are gathered into plain Python lists rather
# than stacked into a tensor (ids and area names are not tensors).
_LIST_KEYS = (
    "idx",
    "unique_country",
    "unique_region",
    "unique_sub-region",
    "unique_city",
    "img_idx",
)


def _sample_by_weight(batch):
    """Draw len(batch) indices with replacement, proportional to each sample's "weight"."""
    weights = np.array([sample["weight"] for sample in batch])
    probabilities = weights / np.sum(weights)
    return np.random.choice(
        len(batch), size=len(batch), p=probabilities, replace=True
    )


def _collate(batch, list_keys, order=None, pil_img_as_list=False):
    """Shared implementation of the collate functions below.

    Args:
        batch (list): list of per-sample dictionaries.
        list_keys (tuple): keys gathered into Python lists (in this order).
        order (sequence of int, optional): indices into ``batch`` selecting
            which samples (possibly repeated) form the output; all samples in
            their original order when None.
        pil_img_as_list (bool): gather "img" into a list when it holds PIL
            images (they cannot be stacked).

    Returns:
        dict: list-valued entries for ``list_keys``, stacked tensors for every
        other key; "weight" and any key containing "text" that is not in
        ``list_keys`` are dropped.
    """
    samples = batch if order is None else [batch[j] for j in order]
    keys = [key for key in batch[0] if key != "weight"]
    output = {}
    for key in list_keys:
        if key in keys:
            output[key] = [sample[key] for sample in samples]
            keys.remove(key)
    if (
        pil_img_as_list
        and "img" in keys
        and isinstance(batch[0]["img"], Image.Image)
    ):
        output["img"] = [sample["img"] for sample in samples]
        keys.remove("img")
    for key in keys:
        if "text" not in key:
            output[key] = torch.stack([sample[key] for sample in samples])
    return output


def collate_fn(batch):
    """Collate function for the dataloader.

    Args:
        batch (list): list of dictionaries with keys "img", "gps", "idx" and optionally "label"

    Returns:
        dict: dictionary with keys "img", "gps", "idx" and optionally "label".
        "weight" is dropped; ids, area names, "text" and PIL images are kept
        as lists; everything else is stacked.
    """
    return _collate(batch, _LIST_KEYS + ("text",), pil_img_as_list=True)


def collate_fn_streetclip(batch):
    """Collate function for the dataloader (StreetCLIP: "img" always kept as a list).

    Args:
        batch (list): list of dictionaries with keys "img", "gps", "idx" and optionally "label"

    Returns:
        dict: dictionary with keys "img", "gps", "idx" and optionally "label"
    """
    return _collate(batch, _LIST_KEYS + ("img", "text"))


def collate_fn_denstity(batch):
    """Collate function that resamples the batch proportionally to "weight".

    Args:
        batch (list): list of dictionaries with keys "img", "gps", "idx", "weight" and optionally "label"

    Returns:
        dict: dictionary with keys "img", "gps", "idx" and optionally "label",
        built from len(batch) samples drawn with replacement.
    """
    return _collate(batch, _LIST_KEYS + ("text",), order=_sample_by_weight(batch))


def collate_fn_streetclip_denstity(batch):
    """Weight-resampling collate function for StreetCLIP ("img" kept as a list).

    Args:
        batch (list): list of dictionaries with keys "img", "gps", "idx", "weight" and optionally "label"

    Returns:
        dict: dictionary with keys "img", "gps", "idx" and optionally "label"
    """
    return _collate(
        batch, _LIST_KEYS + ("img", "text"), order=_sample_by_weight(batch)
    )


def collate_fn_contrastive(batch):
    """Collate function for the dataloader, additionally stacking "pos_img".

    Args:
        batch (list): list of dictionaries with keys "img", "pos_img", "gps", "idx" and optionally "label"

    Returns:
        dict: dictionary with keys "img", "pos_img", "gps", "idx" and optionally "label"
    """
    output = collate_fn(batch)
    output["pos_img"] = torch.stack([x["pos_img"] for x in batch])
    return output


def collate_fn_contrastive_density(batch):
    """Weight-resampling collate function for contrastive batches.

    Args:
        batch (list): list of dictionaries with keys "img", "gps", "idx", "weight" and optionally "label"

    Returns:
        dict: dictionary with keys "img", "gps", "idx" and optionally "label";
        text-valued keys are dropped.
    """
    return _collate(batch, _LIST_KEYS, order=_sample_by_weight(batch))


def _index_files(folder):
    """Map every file name found (recursively) under ``folder`` to its full path."""
    return {
        name: os.path.join(root, name)
        for root, _, files in os.walk(folder)
        for name in files
    }


def _bin_indices(values, edges):
    """0-based bin index of each value, matching ``np.histogram2d``'s convention.

    Histogram bins are half-open except the last one, which also includes the
    right edge, so the maximum value is clipped into the last bin.
    """
    return np.clip(np.digitize(values, edges) - 1, 0, len(edges) - 2)


def _density_weights(df, num_bins=100, power=0.75):
    """Inverse-density sampling weights for the rows of ``df``.

    Coordinates are binned on a lon/lat grid built from ``num_bins`` evenly
    spaced edges per axis; each row gets 1 / count(bin) ** power, and the
    weights are normalized to sum to 1. Precomputed "lon_bin"/"lat_bin"
    columns are used when the CSV provides them.
    """
    longitude = df["longitude"].values
    latitude = df["latitude"].values
    lon_bins = np.linspace(longitude.min(), longitude.max(), num_bins)
    lat_bins = np.linspace(latitude.min(), latitude.max(), num_bins)
    hist, _, _ = np.histogram2d(longitude, latitude, bins=[lon_bins, lat_bins])
    if "lon_bin" in df.columns and "lat_bin" in df.columns:
        lon_idx = df["lon_bin"].to_numpy()
        lat_idx = df["lat_bin"].to_numpy()
    else:
        lon_idx = _bin_indices(longitude, lon_bins)
        lat_idx = _bin_indices(latitude, lat_bins)
    # every row falls in its own bin, so counts are >= 1 (no division by zero)
    weights = 1.0 / np.power(hist[lon_idx, lat_idx], power)
    return weights / np.sum(weights)


class iNaturalist(Dataset):
    """iNaturalist images (or precomputed embeddings) with their GPS position.

    Only entries of ``<split>.json`` whose "latitude" is present and non-null
    are kept.
    """

    def __init__(
        self,
        path,
        transforms,
        split="train",
        output_type="image",
        embedding_name="dinov2",
    ):
        super().__init__()
        self.split = split
        with open(Path(path) / f"{split}.json", "r", encoding="utf-8") as f:
            metadata = json.load(f)
        self.metadata = [
            datapoint
            for datapoint in metadata["images"]
            if datapoint.get("latitude") is not None
        ]
        self.path = path
        self.transforms = transforms
        self.output_type = output_type
        self.embedding_name = embedding_name
        self.collate_fn = collate_fn

    def __getitem__(self, i):
        """Return "img" and/or "emb" (per output_type), "gps" in radians, "idx" and "img_idx"."""
        datapoint = self.metadata[i]
        output = {}
        if "image" in self.output_type:
            image_path = Path(self.path) / "images" / datapoint["file_name"]
            output["img"] = self.transforms(Image.open(image_path))
        if "emb" in self.output_type:
            emb_path = (
                Path(self.path)
                / "embeddings"
                / self.embedding_name
                / datapoint["file_name"].replace(".jpg", ".npy")
            )
            output["emb"] = torch.tensor(np.load(emb_path))
        lat, lon = normalize(datapoint["latitude"], datapoint["longitude"])
        output["gps"] = torch.tensor(
            [np.radians(lat), np.radians(lon)], dtype=torch.float
        )
        output["idx"] = i
        output["img_idx"] = datapoint["id"]
        return output

    def __len__(self):
        return len(self.metadata)


class OSV5M(Dataset):
    csv_dtype = {"category": str, "country": str, "city": str}  # Don't remove.

    def __init__(
        self,
        path,
        transforms,
        split="train",
        class_name=None,
        aux_data=None,
        is_baseline=False,
        areas=("country", "region", "sub-region", "city"),
        streetclip=False,
        suff="",
        blur=False,
        output_type="image",
        embedding_name="dinov2",
    ):
        """Initializes the dataset.

        Args:
            path (str): path to the dataset
            transforms (torchvision.transforms): transforms to apply to the images
            split (str): split to use (train, val, test, select)
            class_name (str): category to use (e.g. "city"); names containing
                "quadtree" are used as-is, others get a "unique_" prefix
            aux_data (list of str): auxiliary data columns to return
            is_baseline (bool): also read test.csv when building the category list
            areas (list of str): regions to perform accuracy
            streetclip (bool): if the model is streetclip, do not use transform
            suff (str): suffix of the train csv (train_<suff>.csv)
            blur (bool): blur bottom of images or not
            output_type (str): type of output (image and/or emb)
            embedding_name (str): sub-folder of "embeddings" to read from
        """
        aux_data = [] if aux_data is None else aux_data
        self.suff = suff
        self.path = path
        self.aux = len(aux_data) > 0
        self.aux_list = aux_data
        self.split = split
        self.df = self.load_split(split)
        if split == "select":
            # "select" has its own CSV but otherwise behaves like the test split
            split = "test"
        self.split = split

        # the validation split is carved out of train, so it reads train files
        folder_split = "train" if split == "val" else split
        if "image" in output_type:
            self.image_data_folder = join(path, "images", folder_split)
            self.image_dict_names = _index_files(self.image_data_folder)
        if "emb" in output_type:
            self.emb_data_folder = join(
                path, "embeddings", embedding_name, folder_split
            )
            self.emb_dict_names = _index_files(self.emb_data_folder)
        self.output_type = output_type
        self.is_baseline = is_baseline

        if self.aux:
            self.aux_data = {col: self._build_aux(col) for col in self.aux_list}

        self.areas = [f"unique_{area}" for area in areas]
        if class_name is None or "quadtree" in class_name:
            self.class_name = class_name
        else:
            self.class_name = f"unique_{class_name}"
        ex = self.extract_classes(self.class_name)
        self.df = self.df[
            ["id", "latitude", "longitude", "weight"] + self.areas + ex
        ].fillna("NaN")
        if self.class_name in self.areas:
            # the label column duplicates an area column; rename the copy
            self.df.columns = list(self.df.columns)[:-1] + [self.class_name + "_2"]

        if self.aux:
            # aux features were built before extract_classes may have dropped
            # rows; realign them so positional lookups in __getitem__ match.
            self.aux_data = {
                col: data.loc[self.df.index].reset_index(drop=True)
                for col, data in self.aux_data.items()
            }

        self.transforms = transforms
        self.collate_fn = collate_fn
        self.collate_fn_density = collate_fn_denstity
        self.blur = blur
        self.streetclip = streetclip
        if self.streetclip:
            self.collate_fn = collate_fn_streetclip
            self.collate_fn_density = collate_fn_streetclip_denstity

    def _build_aux(self, col):
        """One-hot encode categorical aux columns; wrap other values in 1-element lists."""
        if col not in ("land_cover", "climate", "soil"):
            return self.df[col].apply(lambda value: [value])
        dummies = pd.get_dummies(self.df[col], dtype=float)
        if col == "climate":
            # fixed layout: codes 0..30 except 20, missing codes filled with 0
            for code in range(31):
                if code not in list(dummies.columns):
                    dummies[code] = 0
            dummies = dummies[[code for code in range(31) if code != 20]]
        return dummies

    def load_split(self, split):
        """Load the dataframe for ``split`` with an inverse-density "weight" column.

        "test" and "select" read their own CSV. Other splits read train.csv
        (or train_<suff>.csv): "val" is a weighted 10% sample of it
        (random_state=42) and any other split is the remaining rows.
        """
        start_time = time.time()
        if split in ("test", "select"):
            df = pd.read_csv(join(self.path, f"{split}.csv"), dtype=self.csv_dtype)
            df["weight"] = _density_weights(df)
            return df

        if len(self.suff) == 0:
            csv_name = "train.csv"
        else:
            csv_name = "train" + "_" + self.suff + ".csv"
        df = pd.read_csv(join(self.path, csv_name), dtype=self.csv_dtype)
        normalized_weights = _density_weights(df)
        df["weight"] = normalized_weights
        test_df = df.sample(
            n=int(0.1 * len(df)),
            weights=normalized_weights,
            replace=False,
            random_state=42,
        )
        end_time = time.time()
        print(f"Loading {split} dataset took {(end_time - start_time):.2f} seconds")
        if split == "val":
            return test_df
        return df.drop(test_df.index)

    def extract_classes(self, tag=None):
        """Extracts the categories from the dataset.

        Sets ``has_labels``; when ``tag`` is given also sets ``categories``,
        ``num_classes`` and ``category_to_index`` and, outside the test split,
        drops rows whose ``tag`` is missing.

        Returns:
            list: ``[tag]`` (the extra column to keep) or ``[]`` when tag is None.
        """
        if tag is None:
            self.has_labels = False
            return []
        splits = ["train", "test"] if self.is_baseline else ["train"]
        print(f"Loading categories from {splits}")

        # concatenate all categories from relevant splits to find the unique ones.
        self.categories = sorted(
            pd.concat(
                [pd.read_csv(join(self.path, f"{split}.csv"))[tag] for split in splits]
            )
            .fillna("NaN")
            .unique()
            .tolist()
        )

        if "NaN" in self.categories:
            self.categories.remove("NaN")
            if self.split != "test":
                self.df = self.df.dropna(subset=[tag])
        # compute the total number of categories - this name is fixed and will be used as a lookup during init
        self.num_classes = len(self.categories)

        # create a mapping from category to index
        self.category_to_index = {
            category: i for i, category in enumerate(self.categories)
        }
        self.has_labels = True
        return [tag]

    def __getitem__(self, i):
        """Returns an item from the dataset.

        Args:
            i (int): index of the item
        Returns:
            dict: "img" and/or "emb" (per output_type), "gps" (radians),
            "idx", "img_idx", "weight", one entry per area, "label" when the
            dataset has labels (-1 for unknown categories) and one tensor per
            aux column.
        """
        x = list(self.df.iloc[i])  # id, latitude, longitude, weight, *areas, [label]
        image_id = int(x[0])
        output = {}
        if "image" in self.output_type:
            image_path = self.image_dict_names[f"{image_id}.jpg"]
            if self.streetclip:
                img = Image.open(image_path)
            elif self.blur:
                img = transforms.ToTensor()(Image.open(image_path))
                blur = GaussianBlur(kernel_size=13, sigma=2.0)
                bottom_part = img[:, -14:, :].unsqueeze(0)
                # squeeze(0) only removes the batch dim, even for 1-channel images
                img[:, -14:, :] = blur(bottom_part).squeeze(0)
                img = self.transforms(transforms.ToPILImage()(img))
            else:
                img = self.transforms(Image.open(image_path))
            output["img"] = img

        if "emb" in self.output_type:
            output["emb"] = torch.FloatTensor(
                np.load(self.emb_dict_names[f"{image_id}.npy"])
            )

        lat, lon = normalize(x[1], x[2])
        gps = torch.FloatTensor([np.radians(lat), np.radians(lon)]).squeeze(0)

        output.update(
            {
                "gps": gps,
                "idx": i,
                "img_idx": image_id,
                "weight": x[3],
            }
        )

        for count, area in enumerate(self.areas):
            output[area] = x[count + 4]

        if self.has_labels:
            label = self.category_to_index.get(x[-1], -1)
            output["label"] = torch.LongTensor([label]).squeeze(-1)
        if self.aux:
            for col in self.aux_list:
                # go through numpy: torch's sequence path indexes a pandas
                # Series by label, which breaks on non-0..n column labels
                value = self.aux_data[col].iloc[i]
                output[col] = torch.tensor(np.asarray(value, dtype=np.float32))
        return output

    def __len__(self):
        return len(self.df)


class ContrastiveOSV5M(OSV5M):
    def __init__(
        self,
        path,
        transforms,
        split="train",
        class_name=None,
        aux_data=None,
        class_name2=None,
        blur=False,
    ):
        """
        class_name2 (str): if not None, we do contrastive an other class than the one specified for classif
        """
        super().__init__(
            path,
            transforms,
            split=split,
            class_name=class_name,
            aux_data=aux_data,
            blur=blur,
        )
        self.add_label = False
        if class_name2 is not None and split not in ("test", "select"):
            self.add_label = True
            self.class_name = class_name2
            self.extract_classes_contrastive(tag=class_name2)
        self.df = self.df.reset_index(drop=True)
        # class value -> positional indices of the rows in that class
        self.dict_classes = {
            value: indices.tolist()
            for value, indices in self.df.groupby(self.class_name).groups.items()
        }
        self.collate_fn = collate_fn_contrastive
        self.random_crop = RandomCrop(224)  # use when no positive image is available

    def sample_positive(self, i):
        """
        sample positive image from the same city, country if it is available
        otherwise, apply different crop to the image
        """
        anchor = self.df.iloc[i]
        pool = self.dict_classes[anchor[self.class_name]]
        if len(pool) > 1:
            # rejection-sample a different row of the same class without
            # mutating the shared index list
            j = i
            while j == i:
                j = random.choice(pool)
            positive = self.df.iloc[j]
            return self.transforms(
                Image.open(self.image_dict_names[f"{int(positive['id'])}.jpg"])
            )
        return self.random_crop(
            self.transforms(
                Image.open(self.image_dict_names[f"{int(anchor['id'])}.jpg"])
            )
        )

    def extract_classes_contrastive(self, tag=None):
        """Build ``contrastive_category_to_index`` from the categories of ``tag``."""
        if tag is None:
            self.has_labels = False
            return []
        splits = ["train", "test"] if self.is_baseline else ["train"]
        print(f"Loading categories from {splits}")

        # concatenate all categories from relevant splits to find the unique ones.
        categories = sorted(
            pd.concat(
                [pd.read_csv(join(self.path, f"{split}.csv"))[tag] for split in splits]
            )
            .fillna("NaN")
            .unique()
            .tolist()
        )
        # create a mapping from category to index
        self.contrastive_category_to_index = {
            category: i for i, category in enumerate(categories)
        }

    def __getitem__(self, i):
        output = super().__getitem__(i)
        output["pos_img"] = self.sample_positive(i)
        if self.add_label:
            output["label_contrastive"] = torch.LongTensor(
                [self.contrastive_category_to_index[self.df[self.class_name].iloc[i]]]
            ).squeeze(-1)
        return output


class TextContrastiveOSV5M(OSV5M):
    def __init__(
        self,
        path,
        transforms,
        split="train",
        class_name=None,
        aux_data=None,
        blur=False,
    ):
        super().__init__(
            path,
            transforms,
            split=split,
            class_name=class_name,
            aux_data=aux_data,
            blur=blur,
        )
        self.df = self.df.reset_index(drop=True)

    def get_text(self, i):
        """Describe the location of row ``i`` as an English sentence.

        Uses the part after the last "_" of the unique city, sub-region,
        region and country names, skipping any that are "NaN", e.g.
        "An image of the city of X, in the area of Y, in the region of Z, in W".
        """
        row = self.df.iloc[i]
        prefixes = ("the city of ", "the area of ", "the region of ", "")
        columns = ("unique_city", "unique_sub-region", "unique_region", "unique_country")
        names = [row[column].split("_")[-1] for column in columns]
        parts = [
            prefix + name for prefix, name in zip(prefixes, names) if name != "NaN"
        ]
        return "An image of " + ", in ".join(parts)

    def __getitem__(self, i):
        output = super().__getitem__(i)
        output["text"] = self.get_text(i)
        return output


class Baseline(Dataset):
    def __init__(
        self,
        path,
        which,
        transforms,
    ):
        """Initializes the dataset.

        Args:
            path (str): path to the dataset
            which (str): which baseline to use (im2gps, im2gps3k, yfcc4k)
            transforms (torchvision.transforms): transforms to apply to the images
        """
        baselines = {
            "im2gps": self.load_im2gps,
            "im2gps3k": self.load_im2gps,
            "yfcc4k": self.load_yfcc4k,
        }
        self.path = path
        self.samples = baselines[which]()
        self.transforms = transforms
        self.collate_fn = collate_fn
        self.class_name = which

    def load_im2gps(
        self,
    ):
        """Read (file, lat, lon) from info.json for every file in images/ that has metadata.

        Files are listed in sorted order so sample indices are reproducible.
        """
        with open(join(self.path, "info.json"), encoding="utf-8") as f:
            data = json.load(f)

        samples = []
        for name in sorted(os.listdir(join(self.path, "images"))):
            info = data[name]
            if len(info):
                lat = float(info[-4].replace("latitude: ", ""))
                lon = float(info[-3].replace("longitude: ", ""))
                samples.append((name, lat, lon))
        return samples

    def load_yfcc4k(
        self,
    ):
        """Read (file, lat, lon) from the tab-separated info.txt.

        Column 1 is the image name (without ".jpg"), columns 12 and 13 are
        longitude and latitude. Blank lines are skipped.
        """
        samples = []
        with open(join(self.path, "info.txt"), encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                fields = line.split("\t")
                name, lon, lat = fields[1], fields[12], fields[13]
                samples.append((name + ".jpg", float(lat), float(lon)))
        return samples

    def __getitem__(self, i):
        """Returns an item from the dataset.

        Args:
            i (int): index of the item
        Returns:
            dict: dictionary with keys "img", "gps" (radians) and "idx"
        """
        img_path, lat, lon = self.samples[i]
        img = self.transforms(
            Image.open(join(self.path, "images", img_path)).convert("RGB")
        )
        lat, lon = normalize(lat, lon)
        gps = torch.FloatTensor([np.radians(lat), np.radians(lon)]).squeeze(0)
        return {
            "img": img,
            "gps": gps,
            "idx": i,
        }

    def __len__(self):
        return len(self.samples)


def null_transform(x):
    """Identity transform (a named function so DataLoader workers can pickle it)."""
    return x