#    Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import numpy as np
import torch
from torch import nn


def _normalize_axes(inp, axes):
    """
    Convert axes into a sorted array of unique, non-negative axis indices.

    Negative axes are resolved against the rank of inp *before* de-duplication,
    so that e.g. ``-1`` and ``inp.dim() - 1`` count as the same axis and so that
    reducing axes one at a time in descending order is valid.

    :param inp: object exposing ``.shape`` (e.g. a torch tensor)
    :param axes: int or iterable of ints, negative values count from the end
    :return: 1-D numpy int array, ascending, without duplicates
    :raises IndexError: if an axis lies outside ``[-rank, rank - 1]``
    """
    axes = np.unique(axes).astype(int)
    # A 0-d tensor accepts dims -1 and 0 in torch, so treat it as rank 1 here.
    ndim = max(len(inp.shape), 1)
    out_of_range = (axes < -ndim) | (axes >= ndim)
    if out_of_range.any():
        raise IndexError(
            "Dimension out of range (expected to be in range of [%d, %d], but got %d)"
            % (-ndim, ndim - 1, int(axes[out_of_range][0]))
        )
    return np.unique(np.where(axes < 0, axes + ndim, axes))


def sum_tensor(inp, axes, keepdim=False):
    """
    Sum inp over the given axes, one axis at a time.

    :param inp: tensor to reduce
    :param axes: int or iterable of ints; duplicates are ignored and negative
        values index from the last dimension
    :param keepdim: if True the reduced axes are kept with size 1, otherwise
        they are removed
    :return: the reduced tensor (inp itself if axes is empty)
    :raises IndexError: if an axis is out of range for inp
    """
    axes = _normalize_axes(inp, axes)
    if keepdim:
        for ax in axes:
            inp = inp.sum(int(ax), keepdim=True)
    else:
        # Go from the highest axis down so that removing one axis does not
        # shift the position of the axes that still have to be reduced.
        for ax in axes[::-1]:
            inp = inp.sum(int(ax))
    return inp


def mean_tensor(inp, axes, keepdim=False):
    """
    Average inp over the given axes, one axis at a time.

    :param inp: tensor to reduce
    :param axes: int or iterable of ints; duplicates are ignored and negative
        values index from the last dimension
    :param keepdim: if True the reduced axes are kept with size 1, otherwise
        they are removed
    :return: the reduced tensor (inp itself if axes is empty)
    :raises IndexError: if an axis is out of range for inp
    """
    axes = _normalize_axes(inp, axes)
    if keepdim:
        for ax in axes:
            inp = inp.mean(int(ax), keepdim=True)
    else:
        # Go from the highest axis down so that removing one axis does not
        # shift the position of the axes that still have to be reduced.
        for ax in axes[::-1]:
            inp = inp.mean(int(ax))
    return inp


def flip(x, dim):
    """
    flips the tensor at dimension dim (mirroring!)

    :param x: tensor to mirror
    :param dim: dimension along which the order of the entries is reversed
    :return: new tensor with the same shape as x, reversed along dim
    """
    selector = [slice(None)] * x.dim()
    # Index dim with size-1, ..., 0; every other dimension is taken in full.
    selector[dim] = torch.arange(x.size(dim) - 1, -1, -1,
                                 dtype=torch.long, device=x.device)
    return x[tuple(selector)]