# Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.

import numpy as np
import torch


class BaseModule(torch.nn.Module):
    """
    Base ``torch.nn.Module`` adding a trainable-parameter counter and a
    helper that moves input tensors onto the module's device.
    """

    def __init__(self):
        super(BaseModule, self).__init__()

    @property
    def nparams(self):
        """
        Returns number of trainable parameters of the module.

        Only parameters with ``requires_grad=True`` are counted. The count is
        taken from ``Tensor.numel()``, so no parameter data is copied off its
        device, and dtypes NumPy cannot represent (e.g. ``bfloat16``) work.

        Returns:
            int: total number of trainable scalar elements.
        """
        return sum(param.numel() for param in self.parameters() if param.requires_grad)

    def relocate_input(self, x: list):
        """
        Relocates provided tensors to the same device set for the module.

        The target device is that of the module's first parameter or, if the
        module has no parameters, its first buffer. If it has neither, there
        is no device to relocate to and ``x`` is returned unchanged.

        Args:
            x: mutable, indexable sequence of inputs. Tensor entries on a
                different device are replaced in place; non-tensor entries
                are left untouched.

        Returns:
            The same ``x`` object that was passed in.
        """
        # ``next(..., None)`` avoids leaking a bare StopIteration out of this
        # method when the module has no parameters (or buffers).
        reference = next(self.parameters(), None)
        if reference is None:
            reference = next(self.buffers(), None)
        if reference is None:
            return x
        device = reference.device

        for i, item in enumerate(x):
            if isinstance(item, torch.Tensor) and item.device != device:
                x[i] = item.to(device)
        return x