File size: 2,121 Bytes
8390f90 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
r""" CHM 4D kernel (psi, iso, and full) generator """
import torch
from .geometry import Geometry
class KernelGenerator:
    """Generator for CHM 4D kernel parameter groupings.

    Supported kernel types (``ktype``):
        * ``'full'`` -- :meth:`generate` returns ``None`` (no grouping is built).
        * ``'iso'``  -- positions are keyed by the source/target offset distance.
        * ``'psi'``  -- positions are keyed by the larger and smaller of the
          source/target distances to the kernel center, plus the offset distance.
    Any other ``ktype`` makes :meth:`build_key` raise ``NotImplementedError``.
    """

    def __init__(self, ksz, ktype):
        """
        Args:
            ksz: side length of each of the four kernel dimensions.
            ktype: kernel type; ``'psi'``, ``'iso'`` or ``'full'``.
        """
        self.ksz = ksz
        # Iterable of (src_i, src_j, trg_i, trg_j) tuples (see the unpacking in
        # generate_chm_kernel); presumably every 4D kernel position -- confirm in
        # Geometry.init_idx4d.
        self.idx4d = Geometry.init_idx4d(ksz)
        # Zero-initialised ksz^4 tensor; not read anywhere in this class
        # (presumably consumed by callers).
        self.kernel = torch.zeros((ksz, ksz, ksz, ksz))
        # Center cell of a ksz x ksz 2D window (the exact middle when ksz is odd).
        self.center = (ksz // 2, ksz // 2)
        self.ktype = ktype

    @staticmethod
    def _side(value, pivot):
        """Return -1, 0 or 1 as ``value`` is below, equal to, or above ``pivot``."""
        if value < pivot:
            return -1
        if value > pivot:
            return 1
        return 0

    def quadrant(self, crd):
        """Locate ``crd`` relative to the kernel center.

        Args:
            crd: 2-element coordinate; ``crd[0]`` is compared with ``center[0]``
                and ``crd[1]`` with ``center[1]``.

        Returns:
            ``(horz_quad, vert_quad)``, each -1 (before the center), 0 (on the
            center) or 1 (after the center).
        """
        # Strict comparisons on both sides, so only the exact center maps to 0.
        horz_quad = self._side(crd[0], self.center[0])
        vert_quad = self._side(crd[1], self.center[1])
        return horz_quad, vert_quad

    def generate(self):
        """Return the parameter grouping for this kernel type.

        Returns:
            ``None`` when ``ktype == 'full'``; otherwise the dict produced by
            :meth:`generate_chm_kernel`.
        """
        if self.ktype == 'full':
            return None
        return self.generate_chm_kernel()

    def generate_chm_kernel(self):
        """Group 4D kernel positions that share the same key.

        For each ``(src_i, src_j, trg_i, trg_j)`` in ``self.idx4d``, a key is
        built by :meth:`build_key`. The position's flattened index (from
        ``Geometry.get_coord1d``) is appended to that key's list. Presumably
        all positions under one key share a single kernel parameter -- confirm
        with callers.

        Returns:
            dict mapping key string -> list of flattened kernel indices, in the
            order the positions appear in ``self.idx4d``.

        Raises:
            NotImplementedError: from :meth:`build_key` when ``ktype`` is not
                ``'iso'`` or ``'psi'`` (only reached if ``self.idx4d`` is non-empty).
        """
        param_dict = {}
        for src_i, src_j, trg_i, trg_j in self.idx4d:
            src_crd = (src_i, src_j)
            trg_crd = (trg_i, trg_j)
            # The distance metric is whatever Geometry.get_distance implements
            # (not visible here).
            d_tail = Geometry.get_distance(src_crd, self.center)
            d_head = Geometry.get_distance(trg_crd, self.center)
            d_off = Geometry.get_distance(src_crd, trg_crd)
            # Note the (j, i) order: src_j feeds horz_quad and src_i feeds vert_quad.
            horz_quad, vert_quad = self.quadrant((src_j, src_i))
            key = self.build_key(horz_quad, vert_quad, d_head, d_tail, src_crd, trg_crd, d_off)
            coord1d = Geometry.get_coord1d((src_i, src_j, trg_i, trg_j), self.ksz)
            param_dict.setdefault(key, []).append(coord1d)
        return param_dict

    def build_key(self, horz_quad, vert_quad, d_head, d_tail, src_crd, trg_crd, d_off):
        """Build the grouping key for one 4D kernel position.

        ``horz_quad``, ``vert_quad``, ``src_crd`` and ``trg_crd`` are accepted
        but not used by the currently supported kernel types. Distances are
        formatted with ``%d``, which truncates non-integer numeric values.

        Returns:
            ``'iso'``: ``'<d_off>'``;
            ``'psi'``: ``'<max(d_head, d_tail)>_<min(d_head, d_tail)>_<d_off>'``.

        Raises:
            NotImplementedError: for any other ``ktype`` (including ``'full'``).
        """
        if self.ktype == 'iso':
            return '%d' % d_off
        if self.ktype == 'psi':
            return '%d_%d_%d' % (max(d_head, d_tail), min(d_head, d_tail), d_off)
        # NotImplementedError subclasses Exception, so existing
        # ``except Exception`` handlers still catch it.
        raise NotImplementedError('not implemented. ktype=%r' % (self.ktype,))
|