|
import math
|
|
import torch
|
|
from math import log2
|
|
|
|
|
|
def exists(val):
    """Report whether *val* holds a value, i.e. is anything other than None.

    Falsy values such as 0, "" or False still count as existing.
    """
    if val is None:
        return False
    return True
|
|
|
|
|
|
def is_power_of_two(val):
    """Return True if *val* is an exact, positive power of two.

    Integers are checked exactly with a bit test. This avoids converting them
    to float: ``log2(2**60 - 1)`` rounds to ``60.0`` and would misreport the
    value as a power of two. Zero and negative integers return False rather
    than raising a math domain error.

    Other real numbers (floats, numpy/torch scalars) are checked with
    ``math.frexp``. A finite positive float is a power of two exactly when its
    mantissa is 0.5, so values such as 8.0 and 0.5 (2**-1) are accepted.
    Zero, negatives, inf and nan are rejected.
    """
    if isinstance(val, int):
        # A positive power of two has a single set bit.
        return val > 0 and (val & (val - 1)) == 0
    return math.frexp(val)[0] == 0.5
|
|
|
|
|
|
def default(val, d):
    """Return *val* when it exists (is not None); otherwise fall back to *d*."""
    if exists(val):
        return val
    return d
|
|
|
|
|
|
def get_1d_dct(i, freq, L):
    """Evaluate the orthonormal DCT-II basis of frequency *freq* at sample *i*.

    Computes cos(pi * freq * (i + 0.5) / L) / sqrt(L). For every non-zero
    frequency the result is also scaled by sqrt(2), so each basis vector over
    ``i in range(L)`` has unit norm.
    """
    angle = math.pi * freq * (i + 0.5) / L
    value = math.cos(angle) / math.sqrt(L)
    # The DC term (freq == 0) keeps the plain 1/sqrt(L) scale.
    if freq != 0:
        value = value * math.sqrt(2)
    return value
|
|
|
|
|
|
def get_dct_weights(width, channel, fidx_u, fidx_v):
    """Build a bank of 2-D DCT filters, one frequency pair per channel group.

    Args:
        width: Spatial size of each square filter.
        channel: Number of channels in the output tensor.
        fidx_u: Sized sequence of frequency indices for the first spatial axis.
        fidx_v: Sized sequence of frequency indices for the second spatial
            axis. It must have the same length as ``fidx_u``.

    Returns:
        A tensor of shape ``(1, channel, width, width)`` in torch's default
        dtype. With ``c_part = channel // len(fidx_u)``, channels
        ``[i * c_part, (i + 1) * c_part)`` hold the separable filter
        ``f[x, y] = get_1d_dct(x, fidx_u[i], width) * get_1d_dct(y, fidx_v[i], width)``.
        If ``channel`` is not a multiple of ``len(fidx_u)``, the trailing
        remainder channels are left as zeros.

    Raises:
        ValueError: if ``fidx_u`` and ``fidx_v`` differ in length. Unmatched
            frequencies would otherwise be dropped silently by ``zip``,
            leaving their channels zero-filled.
    """
    if len(fidx_u) != len(fidx_v):
        raise ValueError(
            "fidx_u and fidx_v must have the same length, "
            f"got {len(fidx_u)} and {len(fidx_v)}"
        )

    dct_weights = torch.zeros(1, channel, width, width)
    c_part = channel // len(fidx_u)

    for i, (u_x, v_y) in enumerate(zip(fidx_u, fidx_v)):
        # The 2-D basis is separable. Build each 1-D basis once and take the
        # outer product instead of writing width * width scalars one by one.
        # The product is formed in float64, matching the Python-float maths of
        # the 1-D helper, and rounded to the weight dtype once at the end.
        basis_x = torch.tensor(
            [get_1d_dct(x, u_x, width) for x in range(width)], dtype=torch.float64
        )
        basis_y = torch.tensor(
            [get_1d_dct(y, v_y, width) for y in range(width)], dtype=torch.float64
        )
        filt = (basis_x[:, None] * basis_y[None, :]).to(dct_weights.dtype)
        # Broadcast the (width, width) filter across every channel in the group.
        dct_weights[:, i * c_part : (i + 1) * c_part] = filt

    return dct_weights
|
|
|