Spaces:
Running
Running
import math | |
import torch | |
def compute_same_pad(kernel_size, stride):
    """Return per-dimension padding amounts for a convolution.

    Each entry is ``((k - 1) * s + 1) // 2`` for the matching kernel size
    ``k`` and stride ``s``. For ``s == 1`` and odd ``k`` this equals
    ``(k - 1) // 2``, the symmetric padding that keeps the spatial size
    unchanged.

    Args:
        kernel_size: int or iterable of ints.
        stride: int or iterable of ints. If exactly one of ``kernel_size``
            and ``stride`` is an int, it is repeated to match the length of
            the other argument.

    Returns:
        list[int]: one padding value per dimension (a one-element list when
        both arguments are ints).

    Raises:
        AssertionError: if ``kernel_size`` and ``stride`` are iterables of
            different lengths.
    """
    kernel_is_int = isinstance(kernel_size, int)
    stride_is_int = isinstance(stride, int)
    if kernel_is_int and stride_is_int:
        kernel_size, stride = [kernel_size], [stride]
    elif kernel_is_int:
        # Broadcast a scalar kernel size across every stride dimension.
        stride = list(stride)
        kernel_size = [kernel_size] * len(stride)
    elif stride_is_int:
        # Broadcast a scalar stride across every kernel dimension.
        kernel_size = list(kernel_size)
        stride = [stride] * len(kernel_size)
    else:
        kernel_size, stride = list(kernel_size), list(stride)

    if len(kernel_size) != len(stride):
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``; AssertionError is kept for existing callers.
        raise AssertionError(
            "Pass kernel size and stride both as int, or both as equal length iterable"
        )
    return [((k - 1) * s + 1) // 2 for k, s in zip(kernel_size, stride)]
def uniform_binning_correction(x, n_bits=8):
    """Dequantize ``x`` by adding uniform noise of one quantization bin.

    Returns ``x + u`` where ``u ~ U(0, 1 / 2**n_bits)`` elementwise. The input
    tensor is NOT modified.

    Args:
        x: 4-D tensor of shape (N, C, H, W).
        n_bits: number of bits per value; the bin width is ``1 / 2**n_bits``.

    Returns:
        x: a new tensor, ``x + U(0, 1 / 2**n_bits)`` elementwise.
        objective: tensor of shape (N,) on ``x.device`` (default float dtype),
            every entry equal to ``-log(2**n_bits) * C * H * W``.
    """
    b, c, h, w = x.size()
    n_bins = 2**n_bits
    chw = c * h * w
    # Out-of-place add: the old ``x +=`` mutated the caller's tensor (noise
    # accumulated across repeated calls) and failed on leaves requiring grad.
    noise = torch.empty_like(x).uniform_(0, 1.0 / n_bins)
    x = x + noise
    objective = -math.log(n_bins) * chw * torch.ones(b, device=x.device)
    return x, objective
def split_feature(tensor, type="split"):
    """Split ``tensor`` into two parts along dimension 1.

    Args:
        tensor: tensor with at least 2 dimensions; split along dim 1.
        type: ``"split"`` returns (channel 0, channels 1..end);
            ``"cross"`` returns (even-indexed channels, odd-indexed channels).

    Returns:
        tuple of two tensor views.

    Raises:
        ValueError: if ``type`` is neither ``"split"`` nor ``"cross"``.
    """
    if type == "split":
        # NOTE(review): upstream Glow splits at C // 2; this version splits
        # off only the first channel (identical when C == 2). Confirm intended.
        return tensor[:, :1, ...], tensor[:, 1:, ...]
    if type == "cross":
        return tensor[:, 0::2, ...], tensor[:, 1::2, ...]
    # Previously fell through and returned None, which only failed later at
    # the caller's unpacking.
    raise ValueError(f"type must be 'split' or 'cross', got {type!r}")