Spaces:
Runtime error
Runtime error
File size: 2,606 Bytes
51f6859 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
# Copyright (c) OpenMMLab. All rights reserved.
from torch.autograd import Function
from torch.nn import functional as F
class SigmoidGeometricMean(Function):
    """Forward and backward function of geometric mean of two sigmoid
    functions.

    This implementation with analytical gradient function substitutes
    the autograd function of (x.sigmoid() * y.sigmoid()).sqrt(). The
    original implementation incurs NaN during gradient backpropagation
    if both x and y are very small values: the sigmoids underflow to 0,
    autograd's backward of ``sqrt`` at 0 yields inf, and inf multiplied by
    the zero sigmoid derivative gives NaN.
    """

    @staticmethod
    def forward(ctx, x, y):
        """Return ``sqrt(sigmoid(x) * sigmoid(y))`` element-wise.

        The two sigmoids and the result are saved for the analytic backward.
        """
        x_sigmoid = x.sigmoid()
        y_sigmoid = y.sigmoid()
        z = (x_sigmoid * y_sigmoid).sqrt()
        ctx.save_for_backward(x_sigmoid, y_sigmoid, z)
        return z

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients w.r.t. ``x`` and ``y``.

        A gradient is ``None`` when autograd reports that the corresponding
        input does not need one.
        """
        x_sigmoid, y_sigmoid, z = ctx.saved_tensors
        # log z = (log sigmoid(x) + log sigmoid(y)) / 2 and
        # d log sigmoid(x) / dx = 1 - sigmoid(x), hence
        # dz/dx = z * (1 - sigmoid(x)) / 2 (and symmetrically for y).
        # This expression stays finite when z == 0, unlike sqrt's backward.
        half_grad_z = grad_output * z / 2
        grad_x = grad_y = None
        if ctx.needs_input_grad[0]:
            grad_x = half_grad_z * (1 - x_sigmoid)
        if ctx.needs_input_grad[1]:
            grad_y = half_grad_z * (1 - y_sigmoid)
        return grad_x, grad_y


# Functional entry point: sigmoid_geometric_mean(x, y) -> Tensor.
sigmoid_geometric_mean = SigmoidGeometricMean.apply
def interpolate_as(source, target, mode='bilinear', align_corners=False):
    """Interpolate the `source` to the shape of the `target`.

    The `source` must be a Tensor, but the `target` can be a Tensor or a
    np.ndarray with the shape (..., target_h, target_w). Only the last two
    dimensions of `target` are read.

    Args:
        source (Tensor): A 3D/4D Tensor with the shape (N, H, W) or
            (N, C, H, W).
        target (Tensor | np.ndarray): The interpolation target with the shape
            (..., target_h, target_w).
        mode (str): Algorithm used for interpolation. The options are the
            same as those in F.interpolate(). Default: ``'bilinear'``.
        align_corners (bool): The same as the argument in F.interpolate().
            For modes other than ``'linear'``, ``'bilinear'``, ``'bicubic'``
            and ``'trilinear'``, a falsy value is passed on as ``None``,
            because F.interpolate() rejects any non-None value for modes
            such as ``'nearest'`` or ``'area'``. Default: False.

    Returns:
        Tensor: The interpolated source Tensor, with the same number of
        dimensions as `source`. No interpolation is performed when the
        spatial sizes already match.
    """
    assert len(target.shape) >= 2, \
        'target must have at least 2 dimensions (..., target_h, target_w)'

    # F.interpolate() raises ValueError if align_corners is not None for
    # non-linear modes (e.g. 'nearest', 'area'), so the default False must
    # not be forwarded there. An explicit True is still passed through so
    # that torch keeps reporting that misuse.
    if not align_corners and mode not in ('linear', 'bilinear', 'bicubic',
                                          'trilinear'):
        align_corners = None

    def _interpolate_as(source, target, mode='bilinear', align_corners=False):
        """Interpolate the `source` (4D) to the shape of the `target`."""
        target_h, target_w = target.shape[-2:]
        source_h, source_w = source.shape[-2:]
        if target_h != source_h or target_w != source_w:
            source = F.interpolate(
                source,
                size=(target_h, target_w),
                mode=mode,
                align_corners=align_corners)
        return source

    if len(source.shape) == 3:
        # 2D interpolation modes such as 'bilinear' require an (N, C, H, W)
        # input, so add a singleton channel dim and drop it afterwards.
        source = source[:, None, :, :]
        source = _interpolate_as(source, target, mode, align_corners)
        return source[:, 0, :, :]
    return _interpolate_as(source, target, mode, align_corners)
|