|
import numpy as np |
|
import pandas as pd |
|
import random |
|
import torch |
|
import time |
|
import os |
|
import copy |
|
import json |
|
import tifffile |
|
import h3 |
|
import setup |
|
|
|
from sklearn.linear_model import RidgeCV |
|
from sklearn.preprocessing import MinMaxScaler |
|
from sklearn.metrics import average_precision_score |
|
|
|
import utils |
|
import models |
|
import datasets |
|
|
|
class EvaluatorSNT:
    """Average-precision evaluation on the SNT presence/absence benchmark.

    For every benchmark taxon, the model's location embedding is scored against
    that taxon's class embedding at the taxon's labelled locations. The average
    precision (AP) of those scores against the stored labels is reported. Taxa
    the model has no class for are recorded as NaN and excluded from the mean.
    """

    def __init__(self, train_params, eval_params):
        """Load the benchmark file referenced by ``paths.json``.

        Args:
            train_params: training parameters; ``class_to_taxa`` is read during
                evaluation.
            eval_params: evaluation settings. ``seed``, ``device`` and
                ``split`` are read during evaluation. For the 'val' and 'test'
                splits, ``val_frac`` and ``split_seed`` are also read.
        """
        self.train_params = train_params
        self.eval_params = eval_params
        with open('paths.json', 'r') as f:
            paths = json.load(f)
        # The .npy holds a pickled dict, hence allow_pickle=True and .item().
        # NOTE(review): unpickling can run arbitrary code; only load trusted files.
        D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True)
        D = D.item()
        # Per-taxon sequences; entry i belongs to self.taxa[i].
        self.loc_indices_per_species = D['loc_indices_per_species']
        self.labels_per_species = D['labels_per_species']
        self.taxa = D['taxa']
        # Rows of obs_locs are indexed by the per-taxon location indices.
        self.obs_locs = D['obs_locs']
        self.obs_locs_idx = D['obs_locs_idx']

    def get_labels(self, species):
        """Build (location, label) arrays for one species from ``self.data``.

        NOTE(review): ``self.data`` is never assigned by this class, so this
        method only works if a caller sets it first. ``run_evaluation`` does
        not use it. From the accesses below, ``self.data`` is expected to map
        H3 cell ids to dicts keyed by species id (as a string); confirm.

        Args:
            species: species id; converted to ``str`` for the lookup.

        Returns:
            ``(obs_locs, gt)``: a float32 ``(N, 2)`` array of ``(lon, lat)``
            cell centres, and a float32 ``(N,)`` array of 0/1 labels. There is
            one row per cell that has an entry for ``species``.

        Raises:
            AttributeError: if ``self.data`` has not been set.
        """
        data = getattr(self, 'data', None)
        if data is None:
            raise AttributeError('EvaluatorSNT.get_labels requires self.data to be set by the caller.')
        species = str(species)
        lat, lon, gt = [], [], []
        for hx in data:
            if species not in data[hx]:
                continue
            # h3_to_geo is the h3-py v3 API (renamed cell_to_latlng in v4).
            cur_lat, cur_lon = h3.h3_to_geo(hx)
            gt.append(int(len(data[hx][species]) > 0))
            lat.append(cur_lat)
            lon.append(cur_lon)
        lat = np.array(lat).astype(np.float32)
        lon = np.array(lon).astype(np.float32)
        obs_locs = np.vstack((lon, lat)).T
        gt = np.array(gt).astype(np.float32)
        return obs_locs, gt

    @staticmethod
    def _taxa_to_class_indices(class_to_taxa, taxa):
        """Map each taxon in ``taxa`` to its model class index, or -1 if absent.

        If a taxon appears more than once in ``class_to_taxa``, its first
        occurrence is used.
        """
        first_index = {}
        for class_idx, taxon in enumerate(class_to_taxa):
            first_index.setdefault(taxon, class_idx)
        return np.array([first_index.get(tt, -1) for tt in taxa], dtype=np.int64)

    def run_evaluation(self, model, enc):
        """Compute per-taxon and mean average precision.

        Args:
            model: network called as ``model(feats, return_feats=True)`` to get
                location embeddings. ``model.class_emb.weight`` holds one row
                per class.
            enc: location encoder exposing ``encode(locs)``.

        Returns:
            dict with these keys:
            ``mean_average_precision``: mean over taxa with a non-NaN AP.
            ``per_species_average_precision_all``: float32 array, NaN where no
            AP was computed.
            ``num_eval_species_w_valid_ap``: number of taxa with a valid AP.
            ``num_eval_species_total``: total number of taxa.

        Raises:
            ValueError: if ``eval_params['split']`` is not 'all', 'val' or 'test'.
        """
        split = self.eval_params['split']
        if split not in ('all', 'val', 'test'):
            raise ValueError(f"split must be 'all', 'val' or 'test', got {split!r}")

        # Seed the global RNGs in case the model or encoder draws from them.
        np.random.seed(self.eval_params['seed'])
        random.seed(self.eval_params['seed'])

        class_idx = self._taxa_to_class_indices(self.train_params['class_to_taxa'], self.taxa)
        has_class = class_idx >= 0
        # Column of pred_mtx that holds each covered taxon's scores.
        pred_col = np.cumsum(has_class) - 1

        obs_locs = torch.from_numpy(self.obs_locs).to(self.eval_params['device'])
        loc_feat = enc.encode(obs_locs)
        with torch.no_grad():
            loc_emb = model(loc_feat, return_feats=True)
            wt = model.class_emb.weight[torch.from_numpy(class_idx[has_class]), :]
            pred_mtx = torch.matmul(loc_emb, wt.T).cpu().numpy()

        # A single generator is used for all taxa, and a permutation is drawn
        # for every taxon. This keeps the split reproducible and independent of
        # which taxa the model covers.
        split_rng = None if split == 'all' else np.random.default_rng(self.eval_params['split_seed'])

        per_species_ap = np.full(len(self.taxa), np.nan, dtype=np.float32)
        for tt_id in range(len(self.taxa)):
            cur_loc_indices = np.array(self.loc_indices_per_species[tt_id])
            cur_labels = np.array(self.labels_per_species[tt_id])
            if split_rng is not None:
                # The first val_frac of a random permutation is 'val'; the rest is 'test'.
                num_val = int(np.floor(len(cur_labels) * self.eval_params['val_frac']))
                idx_rand = split_rng.permutation(len(cur_labels))
                idx_sel = idx_rand[:num_val] if split == 'val' else idx_rand[num_val:]
                cur_loc_indices = cur_loc_indices[idx_sel]
                cur_labels = cur_labels[idx_sel]
            if not has_class[tt_id]:
                continue  # The model has no class for this taxon, so its AP stays NaN.
            pred = pred_mtx[cur_loc_indices, pred_col[tt_id]]
            per_species_ap[tt_id] = average_precision_score((cur_labels > 0).astype(np.int32), pred)

        # AP may also be NaN when sklearn cannot compute it, e.g. no positives.
        valid_taxa = ~np.isnan(per_species_ap)
        return {
            'mean_average_precision': per_species_ap[valid_taxa].mean() if valid_taxa.any() else np.nan,
            'per_species_average_precision_all': per_species_ap,
            'num_eval_species_w_valid_ap': valid_taxa.sum(),
            'num_eval_species_total': len(self.taxa),
        }

    def report(self, results):
        """Print the summary fields of a ``run_evaluation`` result."""
        for field in ['mean_average_precision', 'num_eval_species_w_valid_ap', 'num_eval_species_total']:
            print(f'{field}: {results[field]}')
|
|
|
class EvaluatorIUCN:
    """Average-precision evaluation against IUCN-style range data.

    For each taxon, AP is computed over every location in ``locs``. Indices
    listed in ``taxa_presence[str(taxon)]`` are positives and all other
    locations are negatives. Taxa the model has no class for get NaN and are
    excluded from the mean.
    """

    def __init__(self, train_params, eval_params):
        """Load the range data referenced by ``paths.json``.

        Args:
            train_params: training parameters; ``class_to_taxa`` is read during
                evaluation.
            eval_params: evaluation settings; ``device`` is read during
                evaluation.
        """
        self.train_params = train_params
        self.eval_params = eval_params
        with open('paths.json', 'r') as f:
            paths = json.load(f)
        with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
            self.data = json.load(f)
        self.obs_locs = np.array(self.data['locs'], dtype=np.float32)
        # JSON object keys are strings; the taxon ids are kept as ints.
        self.taxa = [int(tt) for tt in self.data['taxa_presence'].keys()]

    @staticmethod
    def _taxa_to_class_indices(class_to_taxa, taxa):
        """Map each taxon in ``taxa`` to its model class index, or -1 if absent.

        If a taxon appears more than once in ``class_to_taxa``, its first
        occurrence is used.
        """
        first_index = {}
        for class_idx, taxon in enumerate(class_to_taxa):
            first_index.setdefault(taxon, class_idx)
        return np.array([first_index.get(tt, -1) for tt in taxa], dtype=np.int64)

    def run_evaluation(self, model, enc):
        """Compute per-taxon and mean average precision.

        Args:
            model: network called as ``model(feats, return_feats=True)`` to get
                location embeddings. ``model.class_emb.weight`` holds one row
                per class.
            enc: location encoder exposing ``encode(locs)``.

        Returns:
            dict with these keys:
            ``per_species_average_precision_all``: float32 array, NaN for taxa
            without a model class.
            ``mean_average_precision``: mean over the non-NaN entries.
            ``num_eval_species_w_valid_ap``: number of taxa with a valid AP.
            ``num_eval_species_total``: total number of taxa.
        """
        class_idx = self._taxa_to_class_indices(self.train_params['class_to_taxa'], self.taxa)
        has_class = class_idx >= 0
        per_species_ap = np.full(len(self.taxa), np.nan, dtype=np.float32)

        obs_locs = torch.from_numpy(self.obs_locs).to(self.eval_params['device'])
        loc_feat = enc.encode(obs_locs)
        with torch.no_grad():
            loc_emb = model(loc_feat, return_feats=True)
            wt = model.class_emb.weight[torch.from_numpy(class_idx[has_class]), :]
            # Move the scores to host numpy: sklearn cannot consume CUDA tensors.
            pred_mtx = torch.matmul(loc_emb, wt.T).cpu().numpy()

        num_locs = self.obs_locs.shape[0]
        # Column `col` of pred_mtx belongs to the col-th taxon that has a class.
        for col, tt_id in enumerate(np.flatnonzero(has_class)):
            gt = np.zeros(num_locs, dtype=np.float32)
            gt[self.data['taxa_presence'][str(self.taxa[tt_id])]] = 1.0
            per_species_ap[tt_id] = average_precision_score(gt, pred_mtx[:, col])

        # AP may also be NaN when sklearn cannot compute it.
        valid_taxa = ~np.isnan(per_species_ap)
        return {
            'per_species_average_precision_all': per_species_ap,
            'mean_average_precision': per_species_ap[valid_taxa].mean() if valid_taxa.any() else np.nan,
            'num_eval_species_w_valid_ap': valid_taxa.sum(),
            'num_eval_species_total': len(self.taxa),
        }

    def report(self, results):
        """Print the summary fields of a ``run_evaluation`` result."""
        for field in ['mean_average_precision', 'num_eval_species_w_valid_ap', 'num_eval_species_total']:
            print(f'{field}: {results[field]}')
|
|
|
class EvaluatorGeoPrior:
    """Measure the gain from combining stored vision predictions with the geo model.

    The vision model's top-k probabilities are stored on disk. Combined
    predictions are ``argmax(vision_prob * geo_prob)`` in the vision model's
    class order. Classes the geo model does not know get a geo weight of 1.
    """

    def __init__(self, train_params, eval_params):
        """Load stored vision predictions and observation locations.

        Args:
            train_params: training parameters; ``class_to_taxa`` gives the geo
                model's class order.
            eval_params: evaluation settings; ``batch_size`` and ``device`` are
                read during evaluation.
        """
        self.train_params = train_params
        self.eval_params = eval_params
        with open('paths.json', 'r') as f:
            paths = json.load(f)

        # np.load on an .npz returns a lazy NpzFile. It re-reads and
        # decompresses an array on every key access, and keeps the file open.
        # Load every array once and close the archive.
        with np.load(os.path.join(paths['geo_prior'], 'geo_prior_model_preds.npz')) as npz:
            self.data = {key: npz[key] for key in npz.files}
        print('\n', self.data['probs'].shape[0], 'total test observations')

        meta = pd.read_csv(os.path.join(paths['geo_prior'], 'geo_prior_model_meta.csv'))
        # Locations as (longitude, latitude) rows.
        self.obs_locs = np.vstack((meta['longitude'].values, meta['latitude'].values)).T.astype(np.float32)

        self.taxon_map = self.find_mapping_between_models(self.data['model_to_taxa'], self.train_params['class_to_taxa'])
        print(self.taxon_map.shape[0], 'out of', len(self.data['model_to_taxa']), 'taxa in both vision and geo models')

    def find_mapping_between_models(self, vision_taxa, geo_taxa):
        """Pair vision-model class indices with geo-model class indices.

        Args:
            vision_taxa: taxon id of each vision-model class.
            geo_taxa: taxon id of each geo-model class.

        Returns:
            int32 array of shape ``(K, 2)``. Each row is ``(vision_idx, geo_idx)``
            for a taxon present in both models, ordered by ``vision_idx``. If a
            taxon repeats in ``geo_taxa``, its first occurrence is used.
        """
        geo_index = {}
        for geo_idx, taxon in enumerate(geo_taxa):
            geo_index.setdefault(taxon, geo_idx)
        rows = [(vision_idx, geo_index[taxon])
                for vision_idx, taxon in enumerate(vision_taxa) if taxon in geo_index]
        return np.array(rows, dtype=np.int32).reshape(-1, 2)

    def convert_to_inat_vision_order(self, geo_pred_ip, vision_top_k_prob, vision_top_k_inds, vision_taxa, taxon_map):
        """Scatter geo and top-k vision predictions into dense vision-class order.

        Returns:
            ``(geo_pred, vision_pred)``: float32 arrays of shape
            ``(batch, len(vision_taxa))``. ``vision_pred`` is zero outside the
            top-k entries. ``geo_pred`` is 1 for classes missing from the geo
            model, so those classes are left unweighted.
        """
        vision_pred = np.zeros((geo_pred_ip.shape[0], len(vision_taxa)), dtype=np.float32)
        geo_pred = np.ones((geo_pred_ip.shape[0], len(vision_taxa)), dtype=np.float32)
        vision_pred[np.arange(vision_pred.shape[0])[..., np.newaxis], vision_top_k_inds] = vision_top_k_prob
        geo_pred[:, taxon_map[:, 0]] = geo_pred_ip[:, taxon_map[:, 1]]
        return geo_pred, vision_pred

    def run_evaluation(self, model, enc):
        """Compare vision-only and vision x geo top-1 accuracy.

        Args:
            model: geo network; ``model(feats)`` returns per-class scores in
                ``class_to_taxa`` order (assumed, given ``taxon_map``).
            enc: location encoder exposing ``encode(locs)``.

        Returns:
            dict with float ``vision_only_top_1`` and ``vision_geo_top_1``.
        """
        # Bind the arrays once; they are already in memory (see __init__).
        probs = self.data['probs']
        inds = self.data['inds']
        labels = self.data['labels']
        num_obs = probs.shape[0]
        batch_size = self.eval_params['batch_size']
        correct_pred = np.zeros(num_obs)

        for start in range(0, num_obs, batch_size):
            batch_inds = np.arange(start, min(start + batch_size, num_obs))
            obs_locs_batch = torch.from_numpy(self.obs_locs[batch_inds, :]).to(self.eval_params['device'])
            loc_feat = enc.encode(obs_locs_batch)
            with torch.no_grad():
                geo_pred = model(loc_feat).cpu().numpy()
            geo_pred, vision_pred = self.convert_to_inat_vision_order(
                geo_pred, probs[batch_inds, :], inds[batch_inds, :],
                self.data['model_to_taxa'], self.taxon_map)
            # Reweight vision probabilities by the geo prior and take the top class.
            comb_pred = np.argmax(vision_pred * geo_pred, 1)
            correct_pred[batch_inds] = comb_pred == labels[batch_inds]

        results = {}
        # NOTE(review): assumes the last column of 'inds' is the vision model's top-1 class.
        results['vision_only_top_1'] = float((inds[:, -1] == labels).mean())
        results['vision_geo_top_1'] = float(correct_pred.mean())
        return results

    def report(self, results):
        """Print both accuracies and their difference."""
        print('\nOverall accuracy vision only model', round(results['vision_only_top_1'], 3))
        print('Overall accuracy of geo model     ', round(results['vision_geo_top_1'], 3))
        print('Gain                              ', round(results['vision_geo_top_1'] - results['vision_only_top_1'], 3))
|
|
|
class EvaluatorGeoFeature:
    """Linear-probe evaluation: regress environmental rasters from location features.

    For each raster, ridge regression is fit on the model's location features
    over one spatial split. R^2 is then reported on the training split and on
    a held-out split.
    """

    def __init__(self, train_params, eval_params):
        """Load the country mask and set the list of rasters to evaluate.

        Args:
            train_params: training parameters (stored, not read here).
            eval_params: evaluation settings; ``device`` and ``cell_size`` are
                read during evaluation.
        """
        self.train_params = train_params
        self.eval_params = eval_params
        with open('paths.json', 'r') as f:
            paths = json.load(f)
        self.data_path = paths['geo_feature']
        self.country_mask = tifffile.imread(os.path.join(paths['masks'], 'USA_MASK.tif')) == 1
        self.raster_names = ['ABOVE_GROUND_CARBON', 'ELEVATION', 'LEAF_AREA_INDEX', 'NON_TREE_VEGITATED', 'NOT_VEGITATED', 'POPULATION_DENSITY', 'SNOW_COVER', 'SOIL_MOISTURE', 'TREE_COVER']
        self.raster_names_log_transform = ['POPULATION_DENSITY']

    @staticmethod
    def _normalize_in_place(raster, valid_mask, log_transform=False):
        """Min-max scale the pixels of ``raster`` selected by ``valid_mask`` to [0, 1].

        If ``log_transform`` is set, ``log1p(x - min)`` is applied first. A
        constant raster is left at 0 instead of being divided by zero. Pixels
        outside ``valid_mask`` are untouched. Raises ``ValueError`` (from numpy)
        if ``valid_mask`` selects no pixels.
        """
        if log_transform:
            raster[valid_mask] = np.log1p(raster[valid_mask] - raster[valid_mask].min())
        raster[valid_mask] -= raster[valid_mask].min()
        value_range = raster[valid_mask].max()
        if value_range > 0:
            raster[valid_mask] /= value_range
        return raster

    def load_raster(self, raster_name, log_transform=False):
        """Read ``<raster_name>.tif`` and normalise it inside the country mask.

        Returns:
            ``(raster, valid_mask)``: a float32 raster, and a boolean mask of
            pixels that are not NaN and lie inside the country mask.
        """
        raster = tifffile.imread(os.path.join(self.data_path, raster_name + '.tif')).astype(np.float32)
        valid_mask = ~np.isnan(raster) & self.country_mask
        self._normalize_in_place(raster, valid_mask, log_transform)
        return raster, valid_mask

    def get_split_labels(self, raster, split_ids, split_of_interest):
        """Return the ``raster`` values where ``split_ids == split_of_interest``, in row-major order."""
        inds_y, inds_x = np.where(split_ids == split_of_interest)
        return raster[inds_y, inds_x]

    def get_split_feats(self, model, enc, split_ids, split_of_interest):
        """Return the model's location features for the pixels in one split."""
        locs = utils.coord_grid(self.country_mask.shape, split_ids=split_ids, split_of_interest=split_of_interest)
        locs = torch.from_numpy(locs).to(self.eval_params['device'])
        locs_enc = enc.encode(locs)
        with torch.no_grad():
            feats = model(locs_enc, return_feats=True).cpu().numpy()
        return feats

    def run_evaluation(self, model, enc):
        """Fit and score a ridge probe per raster.

        Split 1 is used for training and split 2 for testing.

        Returns:
            dict with ``train_r2_<name>``, ``test_r2_<name>`` and
            ``alpha_<name>`` (the chosen ridge penalty) for each raster.
        """
        results = {}
        for raster_name in self.raster_names:
            do_log_transform = raster_name in self.raster_names_log_transform
            raster, valid_mask = self.load_raster(raster_name, do_log_transform)
            split_ids = utils.create_spatial_split(raster, valid_mask, cell_size=self.eval_params['cell_size'])
            feats_train = self.get_split_feats(model, enc, split_ids=split_ids, split_of_interest=1)
            feats_test = self.get_split_feats(model, enc, split_ids=split_ids, split_of_interest=2)
            labels_train = self.get_split_labels(raster, split_ids, 1)
            labels_test = self.get_split_labels(raster, split_ids, 2)

            # Fit the scaler on training features only, to avoid test leakage.
            scaler = MinMaxScaler()
            feats_train_scaled = scaler.fit_transform(feats_train)
            feats_test_scaled = scaler.transform(feats_test)

            # `normalize` was removed from RidgeCV in scikit-learn 1.2. The old
            # normalize=False is the default behaviour, so it is simply omitted.
            clf = RidgeCV(alphas=(0.1, 1.0, 10.0), cv=10, fit_intercept=True, scoring='r2').fit(feats_train_scaled, labels_train)
            train_score = clf.score(feats_train_scaled, labels_train)
            test_score = clf.score(feats_test_scaled, labels_test)

            results[f'train_r2_{raster_name}'] = float(train_score)
            results[f'test_r2_{raster_name}'] = float(test_score)
            results[f'alpha_{raster_name}'] = float(clf.alpha_)
        return results

    def report(self, results):
        """Print each test R^2, then their mean."""
        report_fields = [x for x in results if 'test_r2' in x]
        for field in report_fields:
            print(f'{field}: {results[field]}')
        print(np.mean([results[field] for field in report_fields]))
|
|
|
def launch_eval_run(overrides):
    """Restore a trained checkpoint and run one evaluation on it.

    Args:
        overrides: overrides forwarded to ``setup.get_default_params_eval``.
            The resulting ``eval_type`` selects the evaluator.

    Returns:
        The results dict produced by the chosen evaluator's ``run_evaluation``.

    Raises:
        NotImplementedError: if ``eval_type`` is not one of 'snt', 'iucn',
            'geo_prior' or 'geo_feature'.
    """
    eval_params = setup.get_default_params_eval(overrides)

    # Rebuild the network from the checkpoint and put it in inference mode.
    eval_params['model_path'] = os.path.join(
        eval_params['exp_base'], eval_params['experiment_name'], eval_params['ckp_name'])
    checkpoint = torch.load(eval_params['model_path'], map_location='cpu')
    train_cfg = checkpoint['params']
    model = models.get_model(train_cfg)
    model.load_state_dict(checkpoint['state_dict'], strict=True)
    model = model.to(eval_params['device'])
    model.eval()

    # Environmental input encodings need the environment raster on the eval device.
    needs_env_raster = train_cfg['input_enc'] in ['env', 'sin_cos_env']
    raster = datasets.load_env().to(eval_params['device']) if needs_env_raster else None
    enc = utils.CoordEncoder(train_cfg['input_enc'], raster=raster)

    evaluator_by_type = {
        'snt': EvaluatorSNT,
        'iucn': EvaluatorIUCN,
        'geo_prior': EvaluatorGeoPrior,
        'geo_feature': EvaluatorGeoFeature,
    }

    started_at = time.time()
    eval_type = eval_params['eval_type']
    if eval_type not in evaluator_by_type:
        raise NotImplementedError('Eval type not implemented.')
    if eval_type == 'snt':
        # The SNT benchmark always uses the fixed 50/50 test split with this seed.
        eval_params['split'] = 'test'
        eval_params['val_frac'] = 0.50
        eval_params['split_seed'] = 7499
    evaluator = evaluator_by_type[eval_type](train_cfg, eval_params)
    results = evaluator.run_evaluation(model, enc)
    evaluator.report(results)

    print(f'evaluation completed in {np.around((time.time()-started_at)/60, 1)} min')
    return results
|
|