import enum
import os
import time
import warnings
from typing import List, Tuple, Type

import torch
import thop

from ...common.others import get_cur_time_str


class ModelSaveMethod(enum.Enum):
    """How :func:`save_model` serializes a model.

    - WEIGHT: save model by `torch.save(model.state_dict(), ...)`
    - FULL: save model by `torch.save(model, ...)`
    - JIT: convert model to JIT format and save it by `torch.jit.save(jit_model, ...)`
    """
    WEIGHT = 0
    FULL = 1
    JIT = 2


def save_model(model: torch.nn.Module,
               model_file_path: str,
               save_method: ModelSaveMethod,
               model_input_size: Tuple[int] = None):
    """Save a PyTorch model.

    Args:
        model (torch.nn.Module): A PyTorch model. It is switched to eval mode
            before saving.
        model_file_path (str): Target model file path.
        save_method (ModelSaveMethod): The method to save model.
        model_input_size (Tuple[int], optional): \
        This is required if :attr:`save_method` is :attr:`ModelSaveMethod.JIT`. \
        Typically be `(1, 3, 32, 32)` or `(1, 3, 224, 224)`. \
        Defaults to None.
    """
    model.eval()

    if save_method == ModelSaveMethod.WEIGHT:
        torch.save(model.state_dict(), model_file_path)

    elif save_method == ModelSaveMethod.FULL:
        # Pickling a whole module often emits harmless serialization warnings;
        # suppress them only within this call.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            torch.save(model, model_file_path)

    elif save_method == ModelSaveMethod.JIT:
        assert model_input_size is not None
        dummy_input = torch.ones(model_input_size, device=get_model_device(model))
        # check_trace=False: skip the trace re-run check, which can fail
        # spuriously for models with dropout or data-dependent behavior.
        new_model = torch.jit.trace(model, dummy_input, check_trace=False)
        torch.jit.save(new_model, model_file_path)


def get_model_size(model: torch.nn.Module, return_MB=False):
    """Get size of a PyTorch model (default in Byte).

    The model weights are dumped to a temporary file whose size is measured;
    the PID and a timestamp make the file name collision-safe when several
    processes run concurrently.

    Args:
        model (torch.nn.Module): A PyTorch model.
        return_MB (bool, optional): Return result in MB (/= 1024**2). Defaults to False.

    Returns:
        int: Model size (a float if :attr:`return_MB` is True).
    """
    pid = os.getpid()
    tmp_model_file_path = './tmp-get-model-size-{}-{}.model'.format(pid, get_cur_time_str())
    save_model(model, tmp_model_file_path, ModelSaveMethod.WEIGHT)

    # try/finally guarantees the temporary file is removed even if the
    # size query fails.
    try:
        model_size = os.path.getsize(tmp_model_file_path)
    finally:
        os.remove(tmp_model_file_path)

    if return_MB:
        model_size /= 1024**2

    return model_size


def get_model_device(model: torch.nn.Module):
    """Get device of a PyTorch model.

    Assumes all parameters reside on one device and that the model has at
    least one parameter.

    Args:
        model (torch.nn.Module): A PyTorch model.

    Returns:
        torch.device: The device of :attr:`model` (e.g. 'cpu' or 'cuda:0').
    """
    # next(...) avoids materializing the whole parameter list.
    return next(model.parameters()).device


def get_model_latency(model: torch.nn.Module, model_input_size: Tuple[int], sample_num: int,
                      device: str, warmup_sample_num: int, return_detail=False):
    """Get the latency (inference time) of a PyTorch model.

    Reference: https://deci.ai/resources/blog/measure-inference-time-deep-neural-networks/

    Args:
        model (torch.nn.Module): A PyTorch model.
        model_input_size (Tuple[int]): Typically be `(1, 3, 32, 32)` or `(1, 3, 224, 224)`.
            A ready-made input tensor may also be passed instead of a size.
        sample_num (int): How many inputs which size is :attr:`model_input_size` will be tested
            and compute the average latency as result.
        device (str): Typically be 'cpu' or 'cuda'.
        warmup_sample_num (int): Let model perform some dummy inference to warm up the test
            environment to avoid measurement loss.
        return_detail (bool, optional): Beside the average latency, return all result measured.
            Defaults to False.

    Returns:
        Union[float, Tuple[float, List[float]]]: The average latency (in seconds, and all
        latency data) of :attr:`model`.
    """
    if isinstance(model_input_size, tuple):
        dummy_input = torch.rand(model_input_size).to(device)
    else:
        # Caller passed an actual input tensor rather than a size.
        dummy_input = model_input_size

    model = model.to(device)
    model.eval()

    # Warm up so that lazy initialization (e.g. cuDNN autotuning) is not
    # counted in the measurement.
    with torch.no_grad():
        for _ in range(warmup_sample_num):
            model(dummy_input)

    infer_time_list = []

    if 'cuda' in str(device):
        # CUDA kernels launch asynchronously: time them with CUDA events and
        # synchronize before reading the elapsed time.
        with torch.no_grad():
            for _ in range(sample_num):
                s = torch.cuda.Event(enable_timing=True)
                e = torch.cuda.Event(enable_timing=True)
                s.record()
                model(dummy_input)
                e.record()
                torch.cuda.synchronize()
                # elapsed_time() is in milliseconds; convert to seconds.
                infer_time_list.append(s.elapsed_time(e) / 1000.)
    else:
        with torch.no_grad():
            for _ in range(sample_num):
                # perf_counter() is monotonic and high-resolution, unlike time().
                start = time.perf_counter()
                model(dummy_input)
                infer_time_list.append(time.perf_counter() - start)

    avg_infer_time = sum(infer_time_list) / sample_num

    if return_detail:
        return avg_infer_time, infer_time_list
    return avg_infer_time


def get_model_flops_and_params(model: torch.nn.Module, model_input_size: Tuple[int], return_M=False):
    """Get FLOPs and number of parameters of a PyTorch model.

    Args:
        model (torch.nn.Module): A PyTorch model.
        model_input_size (Tuple[int]): Typically be `(1, 3, 32, 32)` or `(1, 3, 224, 224)`.
        return_M (bool, optional): Return both results divided by 1e6 (i.e. in
            MFLOPs / millions of parameters). Defaults to False.

    Returns:
        Tuple[float, float]: FLOPs and number of parameters of :attr:`model`.
    """
    device = get_model_device(model)
    ops, param = thop.profile(model, (torch.ones(model_input_size).to(device),), verbose=False)
    # thop presumably reports multiply-accumulate ops (MACs); x2 converts to
    # FLOPs — TODO confirm against the installed thop version.
    ops = ops * 2

    if return_M:
        ops, param = ops / 1e6, param / 1e6
    return ops, param


def get_module(model: torch.nn.Module, module_name: str):
    """Get a module from a PyTorch model.

    Example:
        >>> from torchvision.models import resnet18
        >>> model = resnet18()
        >>> get_module(model, 'layer1.0')
        BasicBlock(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU(inplace=True)
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )

    Args:
        model (torch.nn.Module): A PyTorch model.
        module_name (str): Module name ('' returns :attr:`model` itself).

    Returns:
        torch.nn.Module: Corresponding module, or None if no such module exists.
    """
    for name, module in model.named_modules():
        if name == module_name:
            return module
    return None


def get_parameter(model: torch.nn.Module, param_name: str):
    """Get a parameter (or any attribute) from a PyTorch model by its dotted
    name, e.g. 'layer1.0.conv1.weight'.

    Args:
        model (torch.nn.Module): A PyTorch model.
        param_name (str): Dotted name of the parameter.

    Returns:
        The attribute looked up on the owning module.
    """
    owner_name = '.'.join(param_name.split('.')[0: -1])
    return getattr(get_module(model, owner_name), param_name.split('.')[-1])


def get_super_module(model: torch.nn.Module, module_name: str):
    """Get the super module of a module in a PyTorch model.

    Example:
        >>> from torchvision.models import resnet18
        >>> model = resnet18()
        >>> get_super_module(model, 'layer1.0.conv1')
        BasicBlock(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU(inplace=True)
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )

    Args:
        model (torch.nn.Module): A PyTorch model.
        module_name (str): Module name.

    Returns:
        torch.nn.Module: Super module of module :attr:`module_name`.
    """
    super_module_name = '.'.join(module_name.split('.')[0:-1])
    return get_module(model, super_module_name)


def set_module(model: torch.nn.Module, module_name: str, module: torch.nn.Module):
    """Set module in a PyTorch model.

    Example:
        >>> from torchvision.models import resnet18
        >>> model = resnet18()
        >>> set_module(model, 'layer1.0', torch.nn.Conv2d(64, 64, 3))
        >>> model
        ResNet(
            (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
            (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU(inplace=True)
            (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
            (layer1): Sequential(
            --> (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                (1): BasicBlock(
                    ...
                )
                ...
            )
            ...
        )

    Args:
        model (torch.nn.Module): A PyTorch model.
        module_name (str): Module name.
        module (torch.nn.Module): Target module which will be set into :attr:`model`.
    """
    super_module = get_super_module(model, module_name)
    setattr(super_module, module_name.split('.')[-1], module)


def get_ith_layer(model: torch.nn.Module, i: int):
    """Get i-th layer in a PyTorch model.

    Only leaf modules (modules without children) are counted as "layers".

    Example:
        >>> from torchvision.models import vgg16
        >>> model = vgg16()
        >>> get_ith_layer(model, 5)
        Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    Args:
        model (torch.nn.Module): A PyTorch model.
        i (int): Index of target layer.

    Returns:
        torch.nn.Module: i-th layer in :attr:`model`, or None if out of range.
    """
    j = 0
    for module in model.modules():
        # Skip containers: only leaf layers are indexed.
        if len(list(module.children())) > 0:
            continue
        if j == i:
            return module
        j += 1
    return None


def get_ith_layer_name(model: torch.nn.Module, i: int):
    """Get the name of i-th layer in a PyTorch model.

    Only leaf modules (modules without children) are counted as "layers".

    Example:
        >>> from torchvision.models import vgg16
        >>> model = vgg16()
        >>> get_ith_layer_name(model, 5)
        'features.5'

    Args:
        model (torch.nn.Module): A PyTorch model.
        i (int): Index of target layer.

    Returns:
        str: The name of i-th layer in :attr:`model`, or None if out of range.
    """
    j = 0
    for name, module in model.named_modules():
        # Skip containers: only leaf layers are indexed.
        if len(list(module.children())) > 0:
            continue
        if j == i:
            return name
        j += 1
    return None


def set_ith_layer(model: torch.nn.Module, i: int, layer: torch.nn.Module):
    """Set i-th layer in a PyTorch model.

    Only leaf modules (modules without children) are counted as "layers".

    Example:
        >>> from torchvision.models import vgg16
        >>> model = vgg16()
        >>> model
        VGG(
            (features): Sequential(
                (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): ReLU(inplace=True)
                (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                ...
            )
            ...
        )
        >>> set_ith_layer(model, 2, torch.nn.Conv2d(64, 128, 3))
        VGG(
            (features): Sequential(
                (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): ReLU(inplace=True)
            --> (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                ...
            )
            ...
        )

    Args:
        model (torch.nn.Module): A PyTorch model.
        i (int): Index of target layer.
        layer (torch.nn.Module): The layer which will be set into :attr:`model`.
    """
    j = 0
    for name, module in model.named_modules():
        # Skip containers: only leaf layers are indexed.
        if len(list(module.children())) > 0:
            continue
        if j == i:
            set_module(model, name, layer)
            return
        j += 1


def get_all_specific_type_layers_name(model: torch.nn.Module, types: Tuple[Type[torch.nn.Module]]):
    """Get names of all layers which are given types in a PyTorch model. (e.g. `Conv2d`, `Linear`)

    Example:
        >>> from torchvision.models import vgg16
        >>> model = vgg16()
        >>> get_all_specific_type_layers_name(model, (torch.nn.Conv2d,))
        ['features.0', 'features.2', 'features.5', ...]

    Args:
        model (torch.nn.Module): A PyTorch model.
        types (Tuple[Type[torch.nn.Module]]): Target types, e.g. `(torch.nn.Conv2d, torch.nn.Linear)`

    Returns:
        List[str]: Names of all layers which are given types.
    """
    return [name for name, m in model.named_modules() if isinstance(m, types)]


class LayerActivation:
    """Collect the input and output of a middle module of a PyTorch model during inference.

    Layer is a wide concept in this class. A module (e.g. ResBlock in ResNet) can be also
    regarded as a "layer".

    Example:
        >>> from torchvision.models import vgg16
        >>> model = vgg16()
        >>> # collect the input and output of 5th layer in VGG16
        >>> layer_activation = LayerActivation(get_ith_layer(model, 5), True, 'cuda')
        >>> model(torch.rand((1, 3, 224, 224)))
        >>> layer_activation.input
        tensor([[...]])
        >>> layer_activation.output
        tensor([[...]])
        >>> layer_activation.remove()
    """

    def __init__(self, layer: torch.nn.Module, detach: bool, device: str):
        """Register forward hook on corresponding layer.

        Args:
            layer (torch.nn.Module): Target layer.
            detach (bool): Detach the collected tensors from the autograd graph.
            device (str): Where the collected data is located.
        """
        self.hook = layer.register_forward_hook(self._hook_fn)
        self.detach = detach
        self.device = device
        self.input: torch.Tensor = None
        self.output: torch.Tensor = None
        self.layer = layer

    def __str__(self):
        return '- ' + str(self.layer)

    def _hook_fn(self, module, input, output):
        # TODO: input or output may be a tuple; only the first element is kept.
        if isinstance(input, tuple):
            self.input = input[0].to(self.device)
        else:
            self.input = input.to(self.device)

        if isinstance(output, tuple):
            self.output = output[0].to(self.device)
        else:
            self.output = output.to(self.device)

        if self.detach:
            self.input = self.input.detach()
            self.output = self.output.detach()

    def remove(self):
        """Remove the hook in the model to avoid performance effect.
        Use this after using the collected data.
        """
        self.hook.remove()


class LayerActivation2:
    """Collect the input and output of a middle module of a PyTorch model during inference.

    Unlike :class:`LayerActivation`, the raw hook arguments are stored without
    any device transfer or detaching.

    Layer is a wide concept in this class. A module (e.g. ResBlock in ResNet) can be also
    regarded as a "layer".

    Example:
        >>> from torchvision.models import vgg16
        >>> model = vgg16()
        >>> # collect the input and output of 5th layer in VGG16
        >>> layer_activation = LayerActivation2(get_ith_layer(model, 5))
        >>> model(torch.rand((1, 3, 224, 224)))
        >>> layer_activation.input
        tensor([[...]])
        >>> layer_activation.output
        tensor([[...]])
        >>> layer_activation.remove()
    """

    def __init__(self, layer: torch.nn.Module):
        """Register forward hook on corresponding layer.

        Args:
            layer (torch.nn.Module): Target layer.
        """
        assert layer is not None
        self.hook = layer.register_forward_hook(self._hook_fn)
        self.input: torch.Tensor = None
        self.output: torch.Tensor = None
        self.layer = layer

    def __str__(self):
        return '- ' + str(self.layer)

    def _hook_fn(self, module, input, output):
        # Keep the raw hook arguments (input is a tuple of tensors).
        self.input = input
        self.output = output

    def remove(self):
        """Remove the hook in the model to avoid performance effect.
        Use this after using the collected data.
        """
        self.hook.remove()


class LayerActivation3:
    """Collect the input and output of a middle module of a PyTorch model during inference.

    Stores the raw hook arguments; :attr:`detach` and :attr:`device` are kept
    for interface compatibility with :class:`LayerActivation` but currently unused.

    Layer is a wide concept in this class. A module (e.g. ResBlock in ResNet) can be also
    regarded as a "layer".

    Example:
        >>> from torchvision.models import vgg16
        >>> model = vgg16()
        >>> # collect the input and output of 5th layer in VGG16
        >>> layer_activation = LayerActivation3(get_ith_layer(model, 5), True, 'cuda')
        >>> model(torch.rand((1, 3, 224, 224)))
        >>> layer_activation.input
        tensor([[...]])
        >>> layer_activation.output
        tensor([[...]])
        >>> layer_activation.remove()
    """

    def __init__(self, layer: torch.nn.Module, detach: bool, device: str):
        """Register forward hook on corresponding layer.

        Args:
            layer (torch.nn.Module): Target layer.
            detach (bool): Kept for interface compatibility; currently unused.
            device (str): Where the collected data is located (currently unused).
        """
        self.hook = layer.register_forward_hook(self._hook_fn)
        self.detach = detach
        self.device = device
        self.input: torch.Tensor = None
        self.output: torch.Tensor = None
        self.layer = layer

    def __str__(self):
        return '- ' + str(self.layer)

    def _hook_fn(self, module, input, output):
        # TODO: input or output may be a tuple
        self.input = input
        self.output = output
        # if self.detach:
        #     self.input = self.input.detach()
        #     self.output = self.output.detach()

    def remove(self):
        """Remove the hook in the model to avoid performance effect.
        Use this after using the collected data.
        """
        self.hook.remove()


class LayerActivationWrapper:
    """A wrapper of :attr:`LayerActivation` which has the same API, but broaden the concept
    "layer". Now a series of layers can be regarded as "hyper-layer" in this class.

    Example:
        >>> from torchvision.models import vgg16
        >>> model = vgg16()
        >>> # collect the input of 5th layer, and output of 7th layer in VGG16
        >>> # i.e. regard 5th~7th layer as a whole module,
        >>> # and collect the input and output of this module
        >>> layer_activation = LayerActivationWrapper([
        ...     LayerActivation(get_ith_layer(model, 5), True, 'cuda'),
        ...     LayerActivation(get_ith_layer(model, 6), True, 'cuda'),
        ...     LayerActivation(get_ith_layer(model, 7), True, 'cuda'),
        ... ])
        >>> model(torch.rand((1, 3, 224, 224)))
        >>> layer_activation.input
        tensor([[...]])
        >>> layer_activation.output
        tensor([[...]])
        >>> layer_activation.remove()
    """

    def __init__(self, las: List[LayerActivation]):
        """
        Args:
            las (List[LayerActivation]): The layer activations of a series of layers.
        """
        self.las = las

    def __str__(self):
        return '\n'.join([str(la) for la in self.las])

    @property
    def input(self):
        """Get the collected input data of first layer.

        Returns:
            torch.Tensor: Collected input data of first layer.
        """
        return self.las[0].input

    @property
    def output(self):
        """Get the collected output data of last layer.

        Returns:
            torch.Tensor: Collected output data of last layer.
        """
        return self.las[-1].output

    def remove(self):
        """Remove all hooks in the model to avoid performance effect.
        Use this after using the collected data.
        """
        for la in self.las:
            la.remove()


class TimeProfiler:
    """Measure the inference time of a single layer via pre/post forward hooks.

    (NOT VERIFIED. DON'T USE ME)
    """

    def __init__(self, layer: torch.nn.Module, device):
        self.before_infer_hook = layer.register_forward_pre_hook(self.before_hook_fn)
        self.after_infer_hook = layer.register_forward_hook(self.after_hook_fn)
        self.device = device
        # Last measured inference time in seconds; None until first forward.
        self.infer_time = None
        self._start_time = None

        if self.device != 'cpu':
            # CUDA events are needed because GPU kernels run asynchronously.
            self.s = torch.cuda.Event(enable_timing=True)
            self.e = torch.cuda.Event(enable_timing=True)

    def before_hook_fn(self, module, input):
        if self.device == 'cpu':
            self._start_time = time.time()
        else:
            self.s.record()

    def after_hook_fn(self, module, input, output):
        if self.device == 'cpu':
            self.infer_time = time.time() - self._start_time
        else:
            self.e.record()
            torch.cuda.synchronize()
            # elapsed_time() is in milliseconds; convert to seconds.
            self.infer_time = self.s.elapsed_time(self.e) / 1000.

    def remove(self):
        """Remove both hooks to avoid performance effect."""
        self.before_infer_hook.remove()
        self.after_infer_hook.remove()


class TimeProfilerWrapper:
    """Aggregate several :class:`TimeProfiler` instances, summing their times.

    (NOT VERIFIED. DON'T USE ME)
    """

    def __init__(self, tps: List[TimeProfiler]):
        self.tps = tps

    @property
    def infer_time(self):
        # Sum of the last measured time of every profiled layer.
        return sum([tp.infer_time for tp in self.tps])

    def remove(self):
        """Remove all hooks to avoid performance effect."""
        for tp in self.tps:
            tp.remove()