yolov6 / yolov6 /utils /torch_utils.py
yourusername's picture
:beers: cheers
2a27594
raw history blame
No virus
3.38 kB
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import time
from contextlib import contextmanager
from copy import deepcopy
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from yolov6.utils.events import LOGGER
try:
import thop # for FLOPs computation
except ImportError:
thop = None
@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Context manager that lets each local master (local_rank 0) run the body first.

    Processes whose ``local_rank`` is neither -1 nor 0 block on a barrier before
    entering the body; the process with ``local_rank == 0`` runs the body and
    then joins the barrier, which releases the waiting processes.
    ``local_rank == -1`` never touches the barrier.

    When ``torch.distributed`` is unavailable or no process group has been
    initialised, the barriers are skipped so the context manager is a no-op
    instead of raising.

    Args:
        local_rank: rank of this process on its node, or -1 for non-distributed runs.
    """
    initialized = dist.is_available() and dist.is_initialized()
    # `device_ids` is only passed for the NCCL backend; other backends get a
    # plain barrier (see the torch.distributed.barrier documentation).
    use_device_ids = initialized and dist.get_backend() == "nccl"

    def _barrier():
        if use_device_ids:
            dist.barrier(device_ids=[local_rank])
        else:
            dist.barrier()

    if initialized and local_rank not in (-1, 0):
        _barrier()
    yield
    if initialized and local_rank == 0:
        _barrier()
def time_sync():
    """Return ``time.time()`` after synchronising CUDA, when CUDA is available.

    With CUDA present, ``torch.cuda.synchronize()`` is called first so that
    queued kernels have finished before the timestamp is taken; without CUDA
    the timestamp is returned immediately.
    """
    cuda = torch.cuda
    if cuda.is_available():
        cuda.synchronize()
    timestamp = time.time()
    return timestamp
def initialize_weights(model):
    """Apply the project's default layer settings to every sub-module of ``model``.

    Matching is on the exact type of each module (subclasses are not matched):
      * ``nn.BatchNorm2d``  -> ``eps = 1e-3`` and ``momentum = 0.03``
      * ``nn.Hardswish`` / ``nn.LeakyReLU`` / ``nn.ReLU`` / ``nn.ReLU6`` /
        ``nn.SiLU``        -> ``inplace = True``
    ``nn.Conv2d`` and every other module type are left untouched.
    The model is modified in place; nothing is returned.
    """
    inplace_activations = (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)
    for module in model.modules():
        kind = type(module)
        if kind is nn.BatchNorm2d:
            module.eps, module.momentum = 1e-3, 0.03
        elif kind in inplace_activations:
            module.inplace = True
def fuse_conv_and_bn(conv, bn):
    """Fold a batch-norm layer into the convolution that feeds it.

    Derivation: https://tehnokv.com/posts/fusing-batchnorm-and-conv/

    With ``scale = bn.weight / sqrt(bn.running_var + bn.eps)`` the fused layer uses
        weight = conv.weight * scale        (per output channel)
        bias   = scale * (conv.bias - bn.running_mean) + bn.bias
    where a missing ``conv.bias`` is treated as zero.

    Args:
        conv: the ``nn.Conv2d`` whose output is normalised by ``bn``.
        bn: the ``nn.BatchNorm2d`` to fold in; its ``weight``, ``bias``,
            ``running_mean``, ``running_var`` and ``eps`` are read.

    Returns:
        A new ``nn.Conv2d`` with bias enabled and gradients disabled, placed on
        the device of ``conv.weight``. It copies the kernel size, stride,
        padding, dilation, groups and padding mode of ``conv``.
    """
    fusedconv = (
        nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            # dilation / padding_mode must be carried over, otherwise the fused
            # layer silently computes a different convolution.
            dilation=conv.dilation,
            groups=conv.groups,
            bias=True,
            padding_mode=conv.padding_mode,
        )
        .requires_grad_(False)
        .to(conv.weight.device)
    )

    # The fused parameters are constants: keep autograd from recording the
    # in-place copies so they stay plain non-trainable leaves.
    with torch.no_grad():
        # per-output-channel batch-norm scale
        scale = bn.weight.div(torch.sqrt(bn.eps + bn.running_var))

        # prepare filters: scale every output-channel filter
        fusedconv.weight.copy_(conv.weight * scale.view(-1, 1, 1, 1))

        # prepare spatial bias
        if conv.bias is None:
            b_conv = torch.zeros(
                conv.weight.size(0),
                device=conv.weight.device,
                dtype=conv.weight.dtype,
            )
        else:
            b_conv = conv.bias
        fusedconv.bias.copy_(scale * (b_conv - bn.running_mean) + bn.bias)

    return fusedconv
def fuse_model(model):
    """Fuse conv + batch-norm pairs inside every ``Conv`` layer of ``model``.

    Each sub-module whose exact type is ``yolov6.layers.common.Conv`` and that
    still has a ``bn`` attribute gets its ``conv`` replaced by the result of
    ``fuse_conv_and_bn``, its ``bn`` attribute removed, and its ``forward``
    rebound to ``forward_fuse``. The model is modified in place and returned.
    """
    from yolov6.layers.common import Conv

    for layer in model.modules():
        needs_fusing = type(layer) is Conv and hasattr(layer, "bn")
        if not needs_fusing:
            continue
        fused = fuse_conv_and_bn(layer.conv, layer.bn)
        layer.conv = fused  # swap in the fused convolution
        del layer.bn  # batch-norm is now folded into the conv
        layer.forward = layer.forward_fuse  # skip the removed bn at inference
    return model
def get_model_info(model, img_size=640):
    """Get model Params and GFlops.

    Code based on https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/utils/model_utils.py

    The model is profiled with ``thop`` on a single 3 x 32 x 32 zero image and
    the measured cost is scaled to the requested resolution.

    Args:
        model: module to profile; its first parameter determines the device of
            the probe image, so it must have at least one parameter.
        img_size: target resolution, either a single int (square image) or a
            list/tuple whose first two items are the two image dimensions.

    Returns:
        A string formatted as ``"Params: {:.2f}M, Gflops: {:.2f}"``.

    Raises:
        ImportError: if ``thop`` is not installed.
    """
    from thop import profile

    stride = 32
    img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device)
    # Profile a deep copy so the caller's model object is left untouched.
    flops, params = profile(deepcopy(model), inputs=(img,), verbose=False)
    params /= 1e6
    flops /= 1e9
    # Accept tuples as well as lists; a bare number means a square image.
    if isinstance(img_size, (list, tuple)):
        height, width = img_size[0], img_size[1]
    else:
        height = width = img_size
    # Scale from the stride x stride probe to the full image; the factor 2
    # presumably converts thop's multiply-accumulate count to FLOPs -- confirm.
    flops *= height * width / stride / stride * 2  # Gflops
    info = "Params: {:.2f}M, Gflops: {:.2f}".format(params, flops)
    return info