import os
import json
from calendar import monthrange

import numpy as np
import pandas as pd
import torch

import utils


class LocationDataset(torch.utils.data.Dataset):
    """Dataset of observation locations, their encoded features, and class labels."""

    def __init__(self, locs, labels, classes, class_to_taxa, input_enc, device):

        # set up the coordinate encoder; environmental encodings need the raster
        self.input_enc = input_enc
        if self.input_enc in ['env', 'sin_cos_env']:
            raster = load_env()
        else:
            raster = None
        self.enc = utils.CoordEncoder(input_enc, raster)

        # store the data
        self.locs = locs
        self.loc_feats = self.enc.encode(self.locs)
        self.labels = labels
        self.classes = classes
        self.class_to_taxa = class_to_taxa

        # useful quantities
        self.num_classes = len(np.unique(labels))
        self.input_dim = self.loc_feats.shape[1]

        # move the raster to the target device once, up front
        if self.enc.raster is not None:
            self.enc.raster = self.enc.raster.to(device)

    def __len__(self):
        return self.loc_feats.shape[0]

    def __getitem__(self, index):
        loc_feat = self.loc_feats[index, :]
        loc = self.locs[index, :]
        class_id = self.labels[index]
        return loc_feat, loc, class_id
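
# Usage sketch (an illustration, not part of the original module): the dataset
# plugs into a standard torch DataLoader. The `locs`, `labels`, `classes`, and
# `class_to_taxa` values below are hypothetical placeholders, and 'sin_cos' is
# assumed to be a supported CoordEncoder mode.
#
#     ds = LocationDataset(locs, labels, classes, class_to_taxa,
#                          input_enc='sin_cos', device='cpu')
#     loader = torch.utils.data.DataLoader(ds, batch_size=32, shuffle=True)
#     loc_feats, locs_batch, class_ids = next(iter(loader))
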
def load_env():
    """Load the scaled bioclim + elevation raster used for environmental encodings."""
    with open('paths.json', 'r') as f:
        paths = json.load(f)
    raster = load_context_feats(os.path.join(paths['env'], 'bioclim_elevation_scaled.npy'))
    return raster


def load_context_feats(data_path):
    context_feats = np.load(data_path).astype(np.float32)
    context_feats = torch.from_numpy(context_feats)
    return context_feats


def load_inat_data(ip_file, taxa_of_interest=None):

    print('\nLoading ' + ip_file)
    data = pd.read_csv(ip_file)

    # remove entries with invalid coordinates
    num_obs = data.shape[0]
    data = data[(data['latitude'] <= 90) & (data['latitude'] >= -90) &
                (data['longitude'] <= 180) & (data['longitude'] >= -180)]
    if (num_obs - data.shape[0]) > 0:
        print(num_obs - data.shape[0], 'items filtered due to invalid locations')

    if 'accuracy' in data.columns:
        data.drop(['accuracy'], axis=1, inplace=True)

    if 'positional_accuracy' in data.columns:
        data.drop(['positional_accuracy'], axis=1, inplace=True)

    if 'geoprivacy' in data.columns:
        data.drop(['geoprivacy'], axis=1, inplace=True)

    if 'observed_on' in data.columns:
        data.rename(columns={'observed_on': 'date'}, inplace=True)

    # remove incomplete rows
    num_obs_orig = data.shape[0]
    data = data.dropna()
    size_diff = num_obs_orig - data.shape[0]
    if size_diff > 0:
        print(size_diff, 'observation(s) with a NaN entry out of', num_obs_orig, 'removed')

    # keep only taxa of interest
    if taxa_of_interest is not None:
        num_obs_orig = data.shape[0]
        data = data[data['taxon_id'].isin(taxa_of_interest)]
        print(num_obs_orig - data.shape[0], 'observation(s) out of', num_obs_orig, 'from different taxa removed')

    print('Number of unique classes {}'.format(np.unique(data['taxon_id'].values).shape[0]))

    locs = np.vstack((data['longitude'].values, data['latitude'].values)).T.astype(np.float32)
    taxa = data['taxon_id'].values.astype(np.int64)

    # map raw observer ids to a compact 0..num_users-1 range; -1 means unknown
    if 'user_id' in data.columns:
        users = data['user_id'].values.astype(np.int64)
        _, users = np.unique(users, return_inverse=True)
    elif 'observer_id' in data.columns:
        users = data['observer_id'].values.astype(np.int64)
        _, users = np.unique(users, return_inverse=True)
    else:
        users = np.ones(taxa.shape[0], dtype=np.int64) * -1

    # convert dates to day of year (2018 is used as a reference non-leap year),
    # scaled to [0, 1]
    years = np.array([int(d_str[:4]) for d_str in data['date'].values])
    months = np.array([int(d_str[5:7]) for d_str in data['date'].values])
    days = np.array([int(d_str[8:10]) for d_str in data['date'].values])
    days_per_month = np.cumsum([0] + [monthrange(2018, mm)[1] for mm in range(1, 12)])
    dates = days_per_month[months - 1] + days - 1
    dates = np.round(dates / 364.0, 4).astype(np.float32)

    if 'id' in data.columns:
        obs_ids = data['id'].values
    elif 'observation_uuid' in data.columns:
        obs_ids = data['observation_uuid'].values
    else:
        raise ValueError('No observation id column ("id" or "observation_uuid") found')

    return locs, taxa, users, dates, years, obs_ids
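
# Example (a sketch under assumed inputs): given a CSV with latitude,
# longitude, taxon_id, and date columns, the loader returns parallel arrays.
# The path and taxon ids below are hypothetical.
#
#     locs, taxa, users, dates, years, obs_ids = load_inat_data(
#         'data/train/geo_prior_train.csv', taxa_of_interest={5212, 12345})
#
# `locs` is an (N, 2) float32 array of [longitude, latitude]; `dates` holds
# the day of year scaled to [0, 1].
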
def choose_aux_species(current_species, num_aux_species, aux_species_seed):
    if num_aux_species == 0:
        return []
    with open('paths.json', 'r') as f:
        paths = json.load(f)
    data_dir = paths['train']
    taxa_file = os.path.join(data_dir, 'geo_prior_train_meta.json')
    with open(taxa_file, 'r') as f:
        inat_large_metadata = json.load(f)
    # candidates are all training taxa that are not already in use
    aux_species_candidates = [x['taxon_id'] for x in inat_large_metadata]
    aux_species_candidates = np.setdiff1d(aux_species_candidates, current_species)
    print(f'choosing {num_aux_species} species to add from {len(aux_species_candidates)} candidates')
    # draw a reproducible random subset of the candidates
    rng = np.random.default_rng(aux_species_seed)
    idx_rand_aux_species = rng.permutation(len(aux_species_candidates))
    aux_species = list(aux_species_candidates[idx_rand_aux_species[:num_aux_species]])
    return aux_species
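
# Example (sketch; the taxon ids are placeholders): choose_aux_species([1234,
# 5678], 3, 42) reproducibly draws three taxon ids from the training metadata,
# excluding any id already in the supplied list.
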
def get_taxa_of_interest(species_set='all', num_aux_species=0, aux_species_seed=123, taxa_file_snt=None):
    if species_set == 'all':
        return None
    if species_set == 'snt_birds':
        assert taxa_file_snt is not None
        with open(taxa_file_snt, 'r') as f:
            taxa_subsets = json.load(f)
        taxa_of_interest = list(taxa_subsets['snt_birds'])
    else:
        raise NotImplementedError(f'Unknown species set: {species_set}')

    # optionally expand the set with randomly chosen auxiliary species
    aux_species = choose_aux_species(taxa_of_interest, num_aux_species, aux_species_seed)
    taxa_of_interest.extend(aux_species)
    return taxa_of_interest


def get_idx_subsample_observations(labels, hard_cap=-1, hard_cap_seed=123):
    if hard_cap == -1:
        return np.arange(len(labels))
    print(f'subsampling (up to) {hard_cap} per class for the training set')
    class_counts = {cls_id: 0 for cls_id in np.unique(labels)}
    # visit observations in a seeded random order, keeping at most hard_cap per class
    ss_rng = np.random.default_rng(hard_cap_seed)
    idx_rand = ss_rng.permutation(len(labels))
    idx_ss = []
    for i in idx_rand:
        class_id = labels[i]
        if class_counts[class_id] < hard_cap:
            idx_ss.append(i)
            class_counts[class_id] += 1
    idx_ss = np.sort(idx_ss)
    print(f'final training set size: {len(idx_ss)}')
    return idx_ss
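
# Example (sketch): with labels [0, 0, 0, 1, 1] and hard_cap=2, the function
# keeps two seeded-random indices for class 0 and both indices for class 1,
# returned in ascending order.
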
def get_train_data(params):
    with open('paths.json', 'r') as f:
        paths = json.load(f)
    data_dir = paths['train']
    obs_file = os.path.join(data_dir, 'geo_prior_train.csv')
    taxa_file = os.path.join(data_dir, 'geo_prior_train_meta.json')
    taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json')

    taxa_of_interest = get_taxa_of_interest(params['species_set'], params['num_aux_species'],
                                            params['aux_species_seed'], taxa_file_snt)

    locs, labels, _, _, _, _ = load_inat_data(obs_file, taxa_of_interest)
    unique_taxa, class_ids = np.unique(labels, return_inverse=True)
    class_to_taxa = unique_taxa.tolist()

    # build a map from taxon id to latin name
    with open(taxa_file, 'r') as f:
        class_info = json.load(f)
    class_names = [cc['latin_name'] for cc in class_info]
    taxa_ids = [cc['taxon_id'] for cc in class_info]
    classes = dict(zip(taxa_ids, class_names))

    # optionally subsample the observations to a per-class cap
    idx_ss = get_idx_subsample_observations(labels, params['hard_cap_num_per_class'], params['hard_cap_seed'])

    locs = torch.from_numpy(np.array(locs)[idx_ss])
    labels = torch.from_numpy(np.array(class_ids)[idx_ss])

    ds = LocationDataset(locs, labels, classes, class_to_taxa, params['input_enc'], params['device'])

    return ds
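
if __name__ == '__main__':
    # Minimal usage sketch (an assumption, not the project's canonical
    # configuration): build the training dataset with illustrative defaults.
    # Running this requires the data files referenced in paths.json.
    params = {
        'species_set': 'all',
        'num_aux_species': 0,
        'aux_species_seed': 123,
        'hard_cap_num_per_class': -1,
        'hard_cap_seed': 123,
        'input_enc': 'sin_cos',  # assumed to be a supported CoordEncoder mode
        'device': 'cpu',
    }
    ds = get_train_data(params)
    print(f'dataset size: {len(ds)}, input dim: {ds.input_dim}, num classes: {ds.num_classes}')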