|
import torch |
|
import numpy as np |
|
import math |
|
import datetime |
|
|
|
class CoordEncoder:
    """Turn location tensors into model input features.

    The encoding is selected by ``input_enc``:

    * ``'sin_cos'``     -- output of ``encode_loc``.
    * ``'env'``         -- values sampled from ``raster`` by
      ``bilinear_interpolate``.
    * ``'sin_cos_env'`` -- the two above concatenated along dim 1.
    """

    def __init__(self, input_enc, raster=None):
        """Store the encoding name and the optional raster to sample from."""
        self.input_enc = input_enc
        self.raster = raster

    def encode(self, locs, normalize=True):
        """Return the features for ``locs`` under the configured encoding.

        Args:
            locs: locations to encode. When ``normalize`` is true they are
                first passed through ``normalize_coords``, which rescales
                the caller's object in place.
            normalize: whether to run ``normalize_coords`` first.

        Raises:
            NotImplementedError: if ``self.input_enc`` is not one of the
                supported encodings. Normalisation, when requested, has
                already been applied by the time this is raised.
        """
        if normalize:
            locs = normalize_coords(locs)

        mode = self.input_enc
        if mode == 'sin_cos':
            return encode_loc(locs)
        if mode == 'env':
            return bilinear_interpolate(locs, self.raster)
        if mode == 'sin_cos_env':
            positional = encode_loc(locs)
            environmental = bilinear_interpolate(locs, self.raster)
            return torch.cat((positional, environmental), 1)
        raise NotImplementedError('Unknown input encoding.')
|
|
|
def normalize_coords(locs):
    """Rescale the first two columns of ``locs`` in place and return it.

    Column 0 is divided by 180 and column 1 by 90, so longitude in
    [-180, 180] and latitude in [-90, 90] (the column order used by the
    other helpers in this file) both end up in [-1, 1].

    NOTE: the input object itself is modified; the return value is that
    same object, not a copy.
    """
    for column, half_range in ((0, 180.0), (1, 90.0)):
        locs[:, column] /= half_range
    return locs
|
|
|
def encode_loc(loc_ip, concat_dim=1):
    """Return ``sin(pi * loc_ip)`` and ``cos(pi * loc_ip)`` concatenated.

    Args:
        loc_ip: tensor of coordinates (presumably already scaled to
            [-1, 1] -- verify against the caller).
        concat_dim: dimension along which the sine block is followed by
            the cosine block.
    """
    angle = math.pi*loc_ip
    sin_part = torch.sin(angle)
    cos_part = torch.cos(angle)
    return torch.cat((sin_part, cos_part), concat_dim)
|
|
|
def bilinear_interpolate(loc_ip, data, remove_nans_raster=True):
    """Sample ``data`` at fractional grid positions by bilinear interpolation.

    Args:
        loc_ip: N x 2 tensor of (x, y) positions, expected in [-1, 1].
            x = -1 maps to column 0 and x = +1 to the last column; y is
            flipped, so y = +1 maps to row 0 and y = -1 to the last row.
        data: rank-3 tensor indexed as ``data[row, col, channel]``.
        remove_nans_raster: when true, NaN entries of ``data`` are
            overwritten with 0.0 *in place* before sampling.

    Returns:
        N x C tensor of interpolated values.

    NOTE: only the "+1" neighbour indices are clamped to the grid; the
    base indices are not, so positions outside [-1, 1] are not guarded.
    """
    assert data is not None

    last_col = data.shape[1]-1
    last_row = data.shape[0]-1

    # map [-1, 1] -> [0, 1] and flip the vertical axis
    pos = (loc_ip.clone() + 1) / 2.0
    pos[:, 1] = 1 - pos[:, 1]

    assert not torch.any(torch.isnan(pos))

    if remove_nans_raster:
        data[torch.isnan(data)] = 0.0

    # scale to fractional pixel coordinates
    pos[:, 0] *= last_col
    pos[:, 1] *= last_row

    pos_floor = torch.floor(pos)
    base = pos_floor.long()
    x0 = base[:, 0]
    y0 = base[:, 1]
    # neighbouring cell, held at the last row / column at the border
    x1 = torch.clamp(x0 + 1, max=last_col)
    y1 = torch.clamp(y0 + 1, max=last_row)

    frac = pos - pos_floor
    dx = frac[:, 0].unsqueeze(1)
    dy = frac[:, 1].unsqueeze(1)

    top_left = data[y0, x0, :]*(1-dx)*(1-dy)
    top_right = data[y0, x1, :]*dx*(1-dy)
    bottom_left = data[y1, x0, :]*(1-dx)*dy
    bottom_right = data[y1, x1, :]*dx*dy

    return top_left + top_right + bottom_left + bottom_right
|
|
|
def rand_samples(batch_size, device, rand_type='uniform'):
    """Draw ``batch_size`` random (lon, lat) pairs, each coordinate in [-1, 1].

    Args:
        batch_size: number of locations to draw.
        device: device the returned tensor is moved to.
        rand_type: ``'uniform'`` draws both coordinates independently and
            uniformly; ``'spherical'`` draws latitude through an arccos
            transform (intended to be uniform over the surface of a
            sphere rather than over the lon/lat rectangle).

    Returns:
        ``batch_size`` x 2 tensor with longitude in column 0 and latitude
        in column 1.

    Raises:
        NotImplementedError: if ``rand_type`` is not a supported option.
            Previously this case fell through to the ``return`` with the
            result unbound and surfaced as an ``UnboundLocalError``.
    """
    # Validate up front so a bad option fails clearly and draws nothing.
    if rand_type not in ('spherical', 'uniform'):
        raise NotImplementedError('Unknown rand_type: {}'.format(rand_type))

    # Draw first, then move to `device`; kept exactly as before so seeded
    # runs keep producing the same values.
    rand_loc = torch.rand(batch_size, 2).to(device)

    if rand_type == 'uniform':
        return rand_loc*2.0 - 1.0

    # 'spherical'
    theta1 = 2.0*math.pi*rand_loc[:, 0]
    theta2 = torch.acos(2.0*rand_loc[:, 1] - 1.0)
    lat = 1.0 - 2.0*theta2/math.pi
    lon = (theta1/math.pi) - 1.0
    return torch.cat((lon.unsqueeze(1), lat.unsqueeze(1)), 1)
|
|
|
def get_time_stamp():
    """Return the current local time as ``'YYYY-MM-DD-HH-MM-SS'``.

    Fractional seconds are dropped.
    """
    now = datetime.datetime.now()
    return now.strftime('%Y-%m-%d-%H-%M-%S')
|
|
|
def coord_grid(grid_size, split_ids=None, split_of_interest=None):
    """Build (lon, lat) coordinates for every cell of a regular global grid.

    Longitude runs from -180 to 180 across the columns and latitude from
    90 down to -90 across the rows.

    Args:
        grid_size: sequence whose first two entries are (rows, cols).
        split_ids: optional 2-D array of per-cell labels.
        split_of_interest: label to select from ``split_ids``.

    Returns:
        float32 array of (lon, lat) pairs: all rows*cols cells in
        row-major order when either optional argument is ``None``,
        otherwise only the cells where ``split_ids == split_of_interest``.
    """
    n_rows, n_cols = grid_size[0], grid_size[1]
    lon_axis = np.linspace(-180, 180, n_cols)
    lat_axis = np.linspace(90, -90, n_rows)
    lon_grid, lat_grid = np.meshgrid(lon_axis, lat_axis)
    feats = np.stack((lon_grid, lat_grid), axis=-1).astype(np.float32)

    if split_ids is not None and split_of_interest is not None:
        # keep only the cells that carry the requested label
        rows, cols = np.where(split_ids == split_of_interest)
        return feats[rows, cols, :]

    # no selection requested: flatten the whole grid
    return feats.reshape(n_rows*n_cols, 2)
|
|
|
def create_spatial_split(raster, mask, train_amt=1.0, cell_size=25):
    """Label grid cells in a checkerboard pattern of square blocks.

    Cells are labelled 1 or 2 in alternating ``cell_size`` x ``cell_size``
    blocks (the top-left block is 2), then multiplied by ``mask`` so that
    masked-out cells become 0. Judging by the naming, label 1 is the
    training split -- confirm against the caller.

    Args:
        raster: array used only for its first two dimensions (rows, cols).
        mask: array multiplied element-wise into the label grid.
        train_amt: fraction of label-1 cells to keep. When below 1.0, a
            random subset of ``int(n * (1 - train_amt))`` label-1 cells
            is reset to 0 using ``np.random.choice``.
        cell_size: side length of each checkerboard block, in cells.

    Returns:
        2-D float array of labels in {0, 1, 2}.
    """
    n_rows, n_cols = raster.shape[0], raster.shape[1]
    split_ids = np.ones((n_rows, n_cols))

    for band, row0 in enumerate(np.arange(0, n_rows, cell_size)):
        # even bands start at column 0, odd bands are shifted by one block
        col_start = 0 if band % 2 == 0 else cell_size
        for col0 in np.arange(col_start, n_cols, cell_size*2):
            split_ids[row0:row0+cell_size, col0:col0+cell_size] = 2

    split_ids = split_ids*mask

    if train_amt < 1.0:
        # randomly drop the surplus label-1 cells
        rows, cols = np.where(split_ids == 1)
        n_drop = int(len(rows)*(1.0-train_amt))
        drop = np.random.choice(len(rows), n_drop, replace=False)
        split_ids[rows[drop], cols[drop]] = 0

    return split_ids
|
|