__all__ = ['restrict']

from .api import grid_push
from .utils import make_list, meshgrid_ij
from . import backend, jitfields
import torch


def restrict(image, factor=None, shape=None, anchor='c',
             interpolation=1, reduce_sum=False, **kwargs):
    """Restrict an image by a factor or to a specific shape.

    Notes
    -----
    .. At least one of `factor` and `shape` must be specified.
    .. If `anchor in ('centers', 'edges')`, exactly one of `factor` or
       `shape` must be specified.
    .. If `anchor in ('first', 'last')`, `factor` must be provided even
       if `shape` is specified.
    .. Because of rounding, it is in general not guaranteed that
       `resize(restrict(x, f), f)` returns a tensor with the same shape as `x`.

        edges          centers          first           last
    e - + - + - e   + - + - + - +   + - + - + - +   + - + - + - +
    | . | . | . |   | c | . | c |   | f | . | . |   | . | . | . |
    + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +
    | . | . | . |   | . | . | . |   | . | . | . |   | . | . | . |
    + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +
    | . | . | . |   | c | . | c |   | . | . | . |   | . | . | l |
    e _ + _ + _ e   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +
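
    For a dimension with input length `ni`, output length `no` and
    restriction factor `f`, the mapping from input voxel index `i` to
    output coordinate `o` implemented below is (summarised from the
    code, for reference only):

    * centers : o = i * (no - 1) / (ni - 1)
    * edges   : o = (i + 0.5) * no / ni - 0.5
    * first   : o = i / f
    * last    : o = i / f + (no - 1) - (ni - 1) / f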

    Parameters
    ----------
    image : (batch, channel, *inshape) tensor
        Image to restrict
    factor : float or list[float], optional
        Restriction factor
        * > 1 : smaller image <-> larger voxels
        * < 1 : larger image <-> smaller voxels
    shape : (ndim,) list[int], optional
        Output shape
    anchor : {'centers', 'edges', 'first', 'last'} or list, default='centers'
        * In cases 'c' and 'e', the volume shape is divided by the
          restriction factor (and possibly truncated), and two anchor
          points are used to determine the voxel size.
        * In cases 'f' and 'l', a single anchor point is used so that
          the voxel size is exactly multiplied by the restriction factor.
          With an integer factor, this corresponds to the adjoint of
          subslicing the volume (e.g., `vol[::f, ::f, ::f]`).
        * A list of anchors (one per dimension) can also be provided.
    interpolation : int or sequence[int], default=1
        Interpolation order.
    reduce_sum : bool, default=False
        If True, do not normalize by the number of values accumulated
        in each output voxel (i.e., return the plain sum).

    Returns
    -------
    restricted : (batch, channel, *shape) tensor
        Restricted image
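
    Examples
    --------
    A minimal usage sketch; the output shapes follow the rules above
    (voxel values depend on the chosen options and backend):

    >>> import torch
    >>> img = torch.rand(2, 1, 64, 64)            # (batch, channel, H, W)
    >>> restrict(img, factor=[2, 2]).shape
    torch.Size([2, 1, 32, 32])
    >>> restrict(img, shape=[16, 16]).shape
    torch.Size([2, 1, 16, 16])
    >>> restrict(img, factor=[4, 4], anchor='first').shape
    torch.Size([2, 1, 16, 16])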

    """
    if backend.jitfields and jitfields.available:
        return jitfields.restrict(image, factor, shape, anchor,
                                  interpolation, reduce_sum, **kwargs)

    factor = make_list(factor) if factor else []
    shape = make_list(shape) if shape else []
    anchor = make_list(anchor)
    nb_dim = max(len(factor), len(shape), len(anchor)) or (image.dim() - 2)
    anchor = [a[0].lower() for a in make_list(anchor, nb_dim)]
    bck = dict(dtype=image.dtype, device=image.device)

    # compute output shape
    inshape = image.shape[-nb_dim:]
    if factor:
        factor = make_list(factor, nb_dim)
    elif not shape:
        raise ValueError('One of `factor` or `shape` must be provided')
    if shape:
        shape = make_list(shape, nb_dim)
    else:
        shape = [int(i/f) for i, f in zip(inshape, factor)]

    if not factor:
        factor = [i/o for o, i in zip(shape, inshape)]

    # compute transformation grid
    lin = []
    fullscale = 1
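    # For each dimension, `lin` maps input voxel indices to (fractional)
    # output coordinates, and `fullscale` accumulates the normalization
    # factor (the approximate number of input voxels per output voxel)
    # used below to turn the pushed sum into a per-voxel average.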
    for anch, f, inshp, outshp in zip(anchor, factor, inshape, shape):
        if anch == 'c':    # centers
            lin.append(torch.linspace(0, outshp - 1, inshp, **bck))
            fullscale *= (inshp - 1) / (outshp - 1)
        elif anch == 'e':  # edges
            scale = outshp / inshp
            shift = 0.5 * (scale - 1)
            fullscale *= 1 / scale
            lin.append(torch.arange(0., inshp, **bck) * scale + shift)
        elif anch == 'f':  # first voxel
            # scale = 1/f
            # shift = 0
            fullscale *= f
            lin.append(torch.arange(0., inshp, **bck) / f)
        elif anch == 'l':  # last voxel
            # scale = 1/f
            shift = (outshp - 1) - (inshp - 1) / f
            fullscale *= f
            lin.append(torch.arange(0., inshp, **bck) / f + shift)
        else:
            raise ValueError('Unknown anchor {}'.format(anch))

    # scatter
    kwargs.setdefault('bound', 'nearest')
    kwargs.setdefault('extrapolate', True)
    kwargs.setdefault('interpolation', interpolation)
    kwargs.setdefault('prefilter', False)
    grid = torch.stack(meshgrid_ij(*lin), dim=-1)
    resized = grid_push(image, grid, shape, **kwargs)
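    # `grid_push` scatters (sums) input values into the output grid;
    # dividing by `fullscale` converts that sum into an approximate
    # per-voxel average, as documented for `reduce_sum`.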
    if not reduce_sum:
        resized /= fullscale

    return resized