|
import gradio as gr |
|
import numpy as np |
|
import matplotlib |
|
matplotlib.use('Agg') |
|
import matplotlib.pyplot as plt |
|
import json |
|
import os |
|
import torch |
|
|
|
import utils |
|
import models |
|
import datasets |
|
|
|
|
|
def load_taxa_metadata(file_path):
    """Load a taxon-ID -> name mapping from a tab-separated text file.

    Each non-blank line must have the form ``<integer id>\\t<name>``. Only the
    first tab separates the two fields, so a name may itself contain tabs.
    Empty and whitespace-only lines are skipped.

    Args:
        file_path: Path to the metadata file (decoded as UTF-8).

    Returns:
        dict mapping ``int`` taxon IDs to name strings. If an ID occurs more
        than once, the last occurrence wins.

    Raises:
        ValueError: if a non-blank line has no tab, or its ID field is not an
            integer.
    """
    taxa = {}
    # Explicit UTF-8 so non-ASCII taxon names decode the same on every
    # platform; the context manager closes the file even if parsing raises.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f.read().split('\n'):
            if not line.strip():
                continue
            taxa_id, name = line.split('\t', 1)
            taxa[int(taxa_id)] = name
    return taxa
|
|
|
|
|
# Dropdown label -> checkpoint file for each model selectable in the UI.
_MODEL_PATHS = {
    'AN_FULL max 10': 'pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_10.pt',
    'AN_FULL max 100': 'pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_100.pt',
    'AN_FULL max 1000': 'pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_1000.pt',
    'Distilled env model': 'pretrained_models/model_an_full_input_enc_sin_cos_distilled_from_env.pt',
}


def _render_map(image, vmax):
    """Draw *image* as an axis-free 24x12-inch plasma heat map and return the Figure.

    NaN / masked cells are drawn transparent.
    """
    import copy

    # Work on a copy so set_bad() does not mutate matplotlib's globally
    # registered 'plasma' colormap.
    cmap = copy.copy(plt.cm.plasma)
    cmap.set_bad(color='none')
    # Pass the size per figure instead of mutating rcParams, so every figure
    # (including the error placeholder) gets the same size.
    fig = plt.figure(figsize=(24, 12))
    ax = fig.add_subplot(111)
    ax.imshow(image, vmin=0, vmax=vmax, cmap=cmap)
    ax.axis('off')
    fig.tight_layout()
    # Remove the figure from pyplot's registry so figures from repeated
    # requests are not kept alive forever; the returned Figure object can
    # still be drawn / saved with savefig.
    plt.close(fig)
    return fig


def _taxa_link_html(taxa_id, text):
    """Return an <h2> heading linking *text* to the iNaturalist page of *taxa_id*."""
    return (f'<h2><a href="https://www.inaturalist.org/taxa/{taxa_id}" '
            f'target="_blank">{text}</a></h2>')


def generate_prediction(taxa_id, selected_model, settings, threshold):
    """Run a SINR model for one taxon and render its predicted range map.

    Gradio callback for the "Run Model" button.

    Args:
        taxa_id: Taxon ID from the "Taxa ID" field; converted with ``int()``.
            Replaced by a random taxon from the model when "Random taxa" is
            selected.
        selected_model: Dropdown label; must be a key of ``_MODEL_PATHS``.
        settings: Selected checkbox labels ("Random taxa", "Disable ocean
            mask", "Threshold"). ``None`` is treated as no selection.
        threshold: Slider value. Used only when "Threshold" is selected, in
            which case predictions are binarised at this value if it is > 0.

    Returns:
        ``(html, figure, taxa_id)``: a heading linking to the taxon's
        iNaturalist page, the rendered matplotlib figure, and the taxon ID
        actually used (a plain ``int``; differs from the input when a random
        taxon was drawn).

    Raises:
        ValueError: if ``selected_model`` is not a known model label.
    """
    settings = settings or []

    # Fail with a clear message on an unrecognised label (e.g. a cleared
    # dropdown) instead of an unbound-variable error further down.
    model_path = _MODEL_PATHS.get(selected_model)
    if model_path is None:
        raise ValueError(f'Unknown model selection: {selected_model!r}')

    with open('paths.json', 'r') as f:
        paths = json.load(f)

    eval_params = {
        'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        'model_path': model_path,
        'taxa_id': int(taxa_id),
        'rand_taxa': 'Random taxa' in settings,
        'set_max_cmap_to_1': False,
        'disable_ocean_mask': 'Disable ocean mask' in settings,
        # -1.0 disables thresholding (see the `> 0` check below).
        'threshold': threshold if 'Threshold' in settings else -1.0,
    }

    # Load the checkpoint and rebuild the model described by its params.
    train_params = torch.load(eval_params['model_path'], map_location='cpu')
    model = models.get_model(train_params['params'])
    model.load_state_dict(train_params['state_dict'], strict=True)
    model = model.to(eval_params['device'])
    model.eval()

    # Only the env-based input encodings need the environmental raster.
    if train_params['params']['input_enc'] in ['env', 'sin_cos_env']:
        raster = datasets.load_env(norm=train_params['params']['env_norm'])
    else:
        raster = None
    enc = utils.CoordEncoder(train_params['params']['input_enc'], raster=raster)

    class_to_taxa = train_params['params']['class_to_taxa']
    if eval_params['rand_taxa']:
        print('Selecting random taxa')
        # np.random.choice returns a numpy integer; convert so the value sent
        # back to the "Taxa ID" field is a plain Python int.
        eval_params['taxa_id'] = int(np.random.choice(class_to_taxa))

    try:
        class_of_interest = class_to_taxa.index(eval_params['taxa_id'])
    except ValueError:
        print(f'Error: Taxa specified that is not in the model: {eval_params["taxa_id"]}')
        fig = _render_map(np.zeros((1, 1)), vmax=1.0)
        op_html = (_taxa_link_html(eval_params['taxa_id'], eval_params['taxa_id'])
                   + ' Error: specified taxa is not in the model.')
        return op_html, fig, eval_params['taxa_id']
    print(f'Loading taxa: {eval_params["taxa_id"]}')

    # Query every raster cell, or only the cells where the mask equals 1
    # when the ocean mask is enabled.
    mask = np.load(os.path.join(paths['masks'], 'ocean_mask.npy'))
    mask_inds = np.where(mask.reshape(-1) == 1)[0]
    locs = utils.coord_grid(mask.shape)
    if not eval_params['disable_ocean_mask']:
        locs = locs[mask_inds, :]
    locs = torch.from_numpy(locs)
    locs_enc = enc.encode(locs).to(eval_params['device'])

    with torch.no_grad():
        preds = model(locs_enc, return_feats=False, class_of_interest=class_of_interest).cpu().numpy()

    if eval_params['threshold'] > 0:
        print(f'Applying threshold of {eval_params["threshold"]} to the predictions.')
        preds[preds < eval_params['threshold']] = 0.0
        preds[preds >= eval_params['threshold']] = 1.0

    # Scatter predictions back onto the full grid; cells outside the mask
    # stay NaN and are masked so they render transparent.
    # NOTE(review): assumes preds is 1-D with one value per queried cell.
    if not eval_params['disable_ocean_mask']:
        op_im = np.full(mask.shape[0] * mask.shape[1], np.nan)
        op_im[mask_inds] = preds
    else:
        op_im = preds
    op_im = np.ma.masked_invalid(op_im.reshape((mask.shape[0], mask.shape[1])))

    vmax = 1.0 if eval_params['set_max_cmap_to_1'] else np.max(op_im)
    fig = _render_map(op_im, vmax)

    # Fall back to the numeric ID if the taxon is absent from the names file.
    taxa_name_str = taxa_names.get(eval_params['taxa_id'], str(eval_params['taxa_id']))
    op_html = _taxa_link_html(eval_params['taxa_id'], taxa_name_str) + ' (click for more info)'
    return op_html, fig, eval_params['taxa_id']
|
|
|
|
|
|
|
# Taxon ID -> display name, loaded once at startup; read by generate_prediction.
taxa_names = load_taxa_metadata('taxa_02_08_2023_names.txt')


with gr.Blocks(title="SINR Demo") as demo:
    top_text = ("Visualization code to explore species range predictions "
                "from Spatial Implicit Neural Representation (SINR) models from "
                "[our](https://arxiv.org/abs/2306.02564) ICML 2023 paper.")
    gr.Markdown("# SINR Visualization Demo")
    gr.Markdown(top_text)

    # Inputs: which taxon and which pretrained model to run.
    with gr.Row():
        selected_taxa = gr.Number(label="Taxa ID", value=130714)
        select_model = gr.Dropdown(["AN_FULL max 10", "AN_FULL max 100", "AN_FULL max 1000", "Distilled env model"],
                                   value="AN_FULL max 1000", label="Model")
    # Optional behaviour toggles; the slider only matters when "Threshold" is ticked.
    with gr.Row():
        settings = gr.CheckboxGroup(["Random taxa", "Disable ocean mask", "Threshold"], label="Settings")
        threshold = gr.Slider(0, 1, 0, label="Threshold")

    with gr.Row():
        submit_button = gr.Button("Run Model")

    # Outputs: taxon heading/link and the predicted range map.
    with gr.Row():
        output_text = gr.HTML(label="Species Name:")
    with gr.Row():
        output_image = gr.Plot(label="Predicted occupancy")

    end_text = ("**Note:** Extreme care should be taken before making any decisions "
                "based on the outputs of models presented here. "
                "The goal of this work is to demonstrate the promise of large-scale "
                "representation learning for species range estimation. "
                "Our models are trained on biased data and have not been calibrated "
                "or validated beyond the experiments illustrated in the paper.")
    gr.Markdown(end_text)

    # The taxa field is also an output so a randomly drawn taxon ID is shown.
    submit_button.click(
        fn=generate_prediction,
        inputs=[selected_taxa, select_model, settings, threshold],
        outputs=[output_text, output_image, selected_taxa],
    )


# Launch only when run as a script, so that importing this module (e.g. by
# Gradio's reload mode, which looks up `demo` itself) does not start a server.
if __name__ == "__main__":
    demo.launch()
|
|