michaelriedl's picture
Initial dump
002ca81
raw
history blame contribute delete
861 Bytes
import math
import torch
from math import log2
def exists(val):
    """Return True for any value except None (falsy values such as 0 or "" count as existing)."""
    if val is None:
        return False
    return True
def is_power_of_two(val):
    """Return True if *val* is an integral power of two.

    Integers are tested exactly with a bit trick, so huge values are not
    misjudged by float rounding, and zero / negative values return False
    instead of raising. Other numeric inputs (e.g. floats) keep the original
    ``log2(val).is_integer()`` semantics, so ``0.5`` still counts as 2**-1.
    """
    if isinstance(val, int):
        # A positive power of two has exactly one set bit.
        return val > 0 and (val & (val - 1)) == 0
    # Guard first: log2 raises ValueError for non-positive inputs.
    return val > 0 and log2(val).is_integer()
def default(val, d):
    """Return *val* unless it is None, in which case return the fallback *d*."""
    if exists(val):
        return val
    return d
def get_1d_dct(i, freq, L):
    """Evaluate the orthonormal DCT-II basis function of index *freq* at sample *i*.

    Computes ``cos(pi * freq * (i + 0.5) / L) / sqrt(L)``, multiplied by
    ``sqrt(2)`` for every non-DC frequency (``freq != 0``), over a signal of
    length *L*.
    """
    angle = math.pi * freq * (i + 0.5) / L
    # The DC term keeps the 1/sqrt(L) scale; the others use sqrt(2/L).
    scale = 1 if freq == 0 else math.sqrt(2)
    return math.cos(angle) / math.sqrt(L) * scale
def get_dct_weights(width, channel, fidx_u, fidx_v):
    """Build a fixed 2-D DCT filter bank with one frequency per channel group.

    Args:
        width: spatial size; the filters are ``width x width``.
        channel: number of output channels.
        fidx_u: frequency indices along dimension 2 (``x``).
        fidx_v: frequency indices along dimension 3 (``y``); must have the
            same length as *fidx_u*.

    Returns:
        A tensor of shape ``(1, channel, width, width)``. Channels
        ``[i * c_part, (i + 1) * c_part)`` hold the separable basis
        ``get_1d_dct(x, fidx_u[i], width) * get_1d_dct(y, fidx_v[i], width)``,
        where ``c_part = channel // len(fidx_u)``. Any remainder channels
        (``channel % len(fidx_u)``) are left as zeros.

    Raises:
        ValueError: if *fidx_u* and *fidx_v* differ in length. Previously,
            ``zip`` silently dropped frequencies or left channels zeroed.
        ZeroDivisionError: if no frequencies are given.
    """
    fidx_u = list(fidx_u)
    fidx_v = list(fidx_v)
    if len(fidx_u) != len(fidx_v):
        raise ValueError(
            f"fidx_u and fidx_v must have the same length, "
            f"got {len(fidx_u)} and {len(fidx_v)}"
        )

    dct_weights = torch.zeros(1, channel, width, width)
    c_part = channel // len(fidx_u)

    for i, (u_x, v_y) in enumerate(zip(fidx_u, fidx_v)):
        # The 2-D basis is separable: evaluate each 1-D basis once (O(width))
        # instead of re-evaluating both factors for every (x, y) pair.
        basis_x = torch.tensor(
            [get_1d_dct(x, u_x, width) for x in range(width)], dtype=torch.float64
        )
        basis_y = torch.tensor(
            [get_1d_dct(y, v_y, width) for y in range(width)], dtype=torch.float64
        )
        # Taking the outer product in float64 reproduces the original
        # per-element Python-float product. It is rounded to the weight
        # dtype only on assignment, and broadcasts over the channel group.
        dct_weights[:, i * c_part : (i + 1) * c_part] = (
            basis_x.unsqueeze(1) * basis_y.unsqueeze(0)
        )

    return dct_weights