#!/usr/bin/env python3
# -*- coding:utf-8 -*-
"""PyTorch helpers: distributed sync, timing, weight init, Conv+BN fusion, model stats."""
import time
from contextlib import contextmanager
from copy import deepcopy

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from yolov6.utils.events import LOGGER

try:
    import thop  # for FLOPs computation
except ImportError:
    thop = None


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """Context manager that lets local rank 0 run the ``with`` body first.

    Every process whose ``local_rank`` is neither -1 nor 0 blocks on a
    barrier before entering the body; rank 0 runs the body and only then
    joins the barrier, which releases the waiting processes.  With
    ``local_rank == -1`` no barrier is issued at all (presumably the
    non-distributed case -- confirm against callers).

    Args:
        local_rank: Local rank of the calling process.
    """
    if local_rank not in [-1, 0]:
        # Non-master ranks wait here until rank 0 reaches its barrier below.
        dist.barrier(device_ids=[local_rank])
    yield
    if local_rank == 0:
        # Rank 0 has finished the body; release the other ranks.
        dist.barrier(device_ids=[0])


def time_sync():
    """Return ``time.time()`` after synchronizing CUDA, when CUDA is available.

    Synchronizing first means pending CUDA kernels are finished before the
    timestamp is taken.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()


def initialize_weights(model):
    """Adjust BatchNorm hyper-parameters and activation in-place flags of ``model``.

    Only modules whose *exact* type matches are touched (subclasses are
    skipped, as in the original exact-type checks):

    * ``nn.Conv2d``: left untouched (default PyTorch initialization).
    * ``nn.BatchNorm2d``: ``eps`` set to 1e-3 and ``momentum`` to 0.03.
    * ``nn.Hardswish``, ``nn.LeakyReLU``, ``nn.ReLU``, ``nn.ReLU6``,
      ``nn.SiLU``: ``inplace`` set to ``True``.

    Args:
        model: Module whose sub-modules are updated in place.
    """
    for m in model.modules():
        t = type(m)
        if t is nn.Conv2d:
            # Intentionally a no-op: convolutions keep their default init.
            pass
        elif t is nn.BatchNorm2d:
            m.eps = 1e-3
            m.momentum = 0.03
        elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
            m.inplace = True


def fuse_conv_and_bn(conv, bn):
    """Fold a BatchNorm layer into the preceding convolution.

    Builds a new ``nn.Conv2d`` (with bias, ``requires_grad=False``) whose
    output equals ``bn(conv(x))`` when ``bn`` uses its running statistics.
    See https://tehnokv.com/posts/fusing-batchnorm-and-conv/

    Args:
        conv: The ``nn.Conv2d`` to fuse.
        bn: The batch-norm layer applied after ``conv``; its ``weight``,
            ``bias``, ``running_mean``, ``running_var`` and ``eps`` are read.

    Returns:
        A new ``nn.Conv2d`` on the same device as ``conv.weight``.
    """
    fusedconv = (
        nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            # dilation / padding_mode must be carried over, otherwise a
            # dilated or non-zero-padded conv changes behaviour once fused.
            dilation=conv.dilation,
            groups=conv.groups,
            bias=True,
            padding_mode=conv.padding_mode,
        )
        .requires_grad_(False)
        .to(conv.weight.device)
    )

    # Run the copies without autograd so the fused parameters stay detached
    # from the graph even when the source parameters require grad.
    with torch.no_grad():
        # Per-output-channel BN scale: gamma / sqrt(running_var + eps).
        scale = bn.weight.div(torch.sqrt(bn.eps + bn.running_var))

        # Broadcast multiply; equivalent to left-multiplying the flattened
        # filters by diag(scale) without building the C x C matrix.
        fusedconv.weight.copy_(conv.weight * scale.view(-1, 1, 1, 1))

        # A conv without bias contributes a zero bias.
        b_conv = (
            torch.zeros(conv.weight.size(0), device=conv.weight.device)
            if conv.bias is None
            else conv.bias
        )
        # BN shift: beta - gamma * running_mean / sqrt(running_var + eps).
        b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(
            torch.sqrt(bn.running_var + bn.eps)
        )
        fusedconv.bias.copy_(scale * b_conv + b_bn)

    return fusedconv


def fuse_model(model):
    """Fuse Conv+BN pairs in every ``Conv`` module of ``model`` in place.

    For each sub-module whose exact type is ``yolov6.layers.common.Conv`` and
    that still has a ``bn`` attribute, the conv is replaced by the fused
    conv, ``bn`` is deleted, and ``forward`` is rebound to ``forward_fuse``.

    Args:
        model: Module to fuse; it is modified in place.

    Returns:
        The same ``model`` object.
    """
    # Imported here rather than at module level, as in the original
    # (presumably to avoid a circular import -- confirm).
    from yolov6.layers.common import Conv

    for m in model.modules():
        if type(m) is Conv and hasattr(m, "bn"):
            m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
            delattr(m, "bn")  # remove batchnorm
            m.forward = m.forward_fuse  # update forward
    return model


def get_model_info(model, img_size=640):
    """Return a summary string with the model's parameter count and GFLOPs.

    The model is profiled with ``thop`` on a 32x32 zero image and the result
    is scaled to ``img_size``.  Code based on
    https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/utils/model_utils.py

    Args:
        model: Model to profile; a deep copy is profiled, not the model itself.
        img_size: Target input size, either a single int (square image) or a
            ``[height, width]`` list/tuple.

    Returns:
        A string of the form ``"Params: X.XXM, Gflops: Y.YY"``.

    Raises:
        ImportError: If ``thop`` is not installed.
    """
    from thop import profile

    stride = 32
    img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device)
    flops, params = profile(deepcopy(model), inputs=(img,), verbose=False)
    params /= 1e6
    flops /= 1e9
    # Accept tuples as well as lists; a tuple used to be wrapped into
    # [(h, w), (h, w)] and then fail on the multiplication below.
    if not isinstance(img_size, (list, tuple)):
        img_size = [img_size, img_size]
    # Scale from the stride x stride probe to the requested resolution; the
    # factor 2 presumably converts thop's multiply-accumulate count to FLOPs
    # -- confirm against the thop documentation.
    flops *= img_size[0] * img_size[1] / stride / stride * 2  # Gflops
    return f"Params: {params:.2f}M, Gflops: {flops:.2f}"