|
|
|
""" |
|
PyTorch utils |
|
""" |
|
|
|
import datetime |
|
import logging |
|
import math |
|
import os |
|
import platform |
|
import subprocess |
|
import time |
|
from contextlib import contextmanager |
|
from copy import deepcopy |
|
from pathlib import Path |
|
|
|
import torch |
|
import torch.distributed as dist |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torchvision |
|
|
|
from utils.general import LOGGER |
|
|
|
try: |
|
import thop |
|
except ImportError: |
|
thop = None |
|
|
|
|
|
@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Context manager that lets the local master (rank 0) run the wrapped block first.

    Processes whose ``local_rank`` is neither -1 (non-distributed) nor 0 block on a
    barrier before entering the body; rank 0 runs the body and then hits the matching
    barrier, releasing the waiting processes. With ``local_rank == -1`` no barrier is used.
    """
    must_wait = local_rank not in (-1, 0)
    if must_wait:
        # Followers park here until rank 0 reaches its barrier after the body.
        dist.barrier(device_ids=[local_rank])
    yield
    if local_rank == 0:
        # Rank 0 is done with the body: release the followers waiting above.
        dist.barrier(device_ids=[0])
|
|
|
|
|
def date_modified(path=__file__):
    """Return the local-time modification date of *path* as 'YYYY-M-D' (month/day not zero-padded)."""
    mtime = Path(path).stat().st_mtime
    stamp = datetime.datetime.fromtimestamp(mtime)
    return '-'.join(str(part) for part in (stamp.year, stamp.month, stamp.day))
|
|
|
|
|
def git_describe(path=Path(__file__).parent):
    """Return ``git describe --tags --long --always`` output for the repository containing *path*.

    Args:
        path: directory passed to ``git -C``; defaults to this file's directory.

    Returns:
        The describe string (e.g. a tag/commit identifier) with the trailing newline removed,
        or '' when *path* is not inside a git repository, does not exist, or git is not installed.
    """
    # Argument list (no shell): paths with spaces or shell metacharacters are passed verbatim
    # instead of being split or interpreted by /bin/sh.
    cmd = ['git', '-C', str(path), 'describe', '--tags', '--long', '--always']
    try:
        # stderr is discarded so git warnings cannot leak into the returned version string.
        out = subprocess.check_output(cmd, stderr=subprocess.DEVNULL)
    except (subprocess.CalledProcessError, OSError):  # not a git repo / git binary missing
        return ''
    return out.decode().rstrip('\n')
|
|
|
|
|
def select_device(device='', batch_size=None):
    """Choose the torch device to run on and log a short environment summary.

    Args:
        device: 'cpu', '' (use CUDA if available), or GPU ids such as '0' or '0,1,2,3';
            a 'cuda:' prefix is stripped. Sets the CUDA_VISIBLE_DEVICES environment
            variable as a side effect ('-1' for CPU, the id list otherwise).
        batch_size: if given and more than one GPU is selected, it must be divisible by
            the GPU count (AssertionError otherwise).

    Returns:
        torch.device('cuda:0') when CUDA is used, else torch.device('cpu').
    """
    summary = f'YOLOv5 π {git_describe() or date_modified()} torch {torch.__version__} '
    device = str(device).strip().lower().replace('cuda:', '')
    want_cpu = device == 'cpu'

    if want_cpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # hide GPUs so torch.cuda.is_available() is False
    elif device:
        os.environ['CUDA_VISIBLE_DEVICES'] = device
        assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested'

    use_cuda = not want_cpu and torch.cuda.is_available()
    if not use_cuda:
        summary += 'CPU\n'
    else:
        devices = device.split(',') if device else '0'
        n = len(devices)
        if n > 1 and batch_size:
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        indent = ' ' * (len(summary) + 1)  # aligns continuation lines under the first device
        for idx, dev_id in enumerate(devices):
            # Query by enumeration index: CUDA_VISIBLE_DEVICES renumbers visible GPUs from 0.
            props = torch.cuda.get_device_properties(idx)
            lead = indent if idx else ''
            summary += f"{lead}CUDA:{dev_id} ({props.name}, {props.total_memory / 1024 ** 2:.0f}MiB)\n"

    if platform.system() == 'Windows':
        summary = summary.encode().decode('ascii', 'ignore')  # drop non-ASCII characters on Windows consoles
    LOGGER.info(summary)
    return torch.device('cuda:0' if use_cuda else 'cpu')
|
|
|
|
|
def time_sync():
    """Return ``time.time()``, first waiting for queued CUDA work to finish when CUDA is available.

    Synchronizing makes timestamps taken around asynchronous GPU ops reflect the real work.
    """
    if not torch.cuda.is_available():
        return time.time()
    torch.cuda.synchronize()
    return time.time()
|
|
|
|
|
def profile(input, ops, n=10, device=None):
    """Benchmark parameters, GFLOPs, memory and forward/backward time of ops on inputs.

    Example:
        input = torch.randn(16, 3, 640, 640)
        m1 = lambda x: x * torch.sigmoid(x)
        m2 = nn.SiLU()
        profile(input, [m1, m2], n=100)  # average over 100 iterations

    Args:
        input: a tensor or list of tensors. Each is moved to ``device`` and gets
            ``requires_grad = True`` so the backward pass can be timed.
        ops: a callable (e.g. nn.Module or lambda) or list of them, applied to every input.
            Modules are moved to ``device`` and cast to half when the input is float16.
        n: number of timed iterations averaged into the forward/backward milliseconds.
        device: torch device; when falsy, ``select_device()`` is used.

    Returns:
        One entry per (input, op) pair: ``[params, GFLOPs, mem_GB, forward_ms, backward_ms,
        input_shape, output_shape]``, or ``None`` when that op raised (the error is printed).
        backward_ms is NaN when the backward pass failed.
    """
    results = []
    logging.basicConfig(format="%(message)s", level=logging.INFO)
    device = device or select_device()
    print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
          f"{'input':>24s}{'output':>24s}")

    for x in input if isinstance(input, list) else [input]:
        x = x.to(device)
        x.requires_grad = True
        for m in ops if isinstance(ops, list) else [ops]:
            m = m.to(device) if hasattr(m, 'to') else m
            m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
            tf, tb, t = 0, 0, [0, 0, 0]  # forward ms, backward ms, timestamps
            try:
                # thop reports MACs; *2 converts to FLOPs. thop is None when not installed,
                # which raises AttributeError here and falls back to 0.
                flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2
            except Exception:
                flops = 0

            try:
                for _ in range(n):
                    t[0] = time_sync()
                    y = m(x)
                    t[1] = time_sync()
                    try:
                        (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
                        t[2] = time_sync()
                    except Exception:  # backward not possible for this op/output
                        t[2] = float('nan')
                    tf += (t[1] - t[0]) * 1000 / n  # ms per op forward
                    tb += (t[2] - t[1]) * 1000 / n  # ms per op backward
                mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
                s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
                s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
                p = sum(param.numel() for param in m.parameters()) if isinstance(m, nn.Module) else 0
                print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
                results.append([p, flops, mem, tf, tb, s_in, s_out])
            except Exception as e:
                print(e)
                results.append(None)
            torch.cuda.empty_cache()
    return results
|
|
|
|
|
def is_parallel(model):
    """Return True if *model* is exactly a DataParallel or DistributedDataParallel instance.

    The check is on the exact type, so subclasses of these wrappers return False.
    """
    wrapper_types = (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
    return any(type(model) is wrapper for wrapper in wrapper_types)
|
|
|
|
|
def de_parallel(model):
    """Return ``model.module`` when *model* is a DP/DDP wrapper (per ``is_parallel``), else *model* itself."""
    if is_parallel(model):
        return model.module
    return model
|
|
|
|
|
def intersect_dicts(da, db, exclude=()):
    """Return the items of *da* whose key is in *db* with a matching ``.shape``.

    Keys containing any substring in *exclude* are dropped. Values are taken from *da*,
    and *da*'s key order is kept.
    """
    matched = {}
    for key, value in da.items():
        if key not in db:
            continue
        if any(token in key for token in exclude):
            continue
        if value.shape != db[key].shape:
            continue
        matched[key] = value
    return matched
|
|
|
|
|
def initialize_weights(model):
    """Adjust module settings in place for every module of *model* (exact-type matches only).

    - nn.BatchNorm2d: eps=1e-3, momentum=0.03.
    - nn.Hardswish / LeakyReLU / ReLU / ReLU6 / SiLU: inplace=True.
    - nn.Conv2d and all other modules are left unchanged.
    """
    inplace_activations = (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)
    for module in model.modules():
        kind = type(module)
        if kind is nn.BatchNorm2d:
            module.eps = 1e-3
            module.momentum = 0.03
        elif kind in inplace_activations:
            module.inplace = True
|
|
|
|
|
def find_modules(model, mclass=nn.Conv2d):
    """Return the indices of entries in ``model.module_list`` that are instances of *mclass*."""
    indices = []
    for idx, module in enumerate(model.module_list):
        if isinstance(module, mclass):
            indices.append(idx)
    return indices
|
|
|
|
|
def sparsity(model):
    """Return the global fraction of parameter elements of *model* that are exactly zero.

    Returns:
        A 0-dim tensor (zero count divided by total element count). A model with no
        parameter elements returns ``torch.tensor(0.0)`` instead of raising ZeroDivisionError.
    """
    params = list(model.parameters())  # iterated twice below
    total = sum(p.numel() for p in params)
    if total == 0:
        return torch.tensor(0.0)
    zeros = sum((p == 0).sum() for p in params)
    return zeros / total
|
|
|
|
|
def prune(model, amount=0.3):
    """Permanently L1-unstructured-prune the ``weight`` of every nn.Conv2d in *model*.

    *amount* is forwarded to ``torch.nn.utils.prune.l1_unstructured``. Global sparsity is
    printed afterwards.
    """
    import torch.nn.utils.prune as torch_prune  # aliased so it does not shadow this function

    print('Pruning model... ', end='')
    convs = [module for module in model.modules() if isinstance(module, nn.Conv2d)]
    for conv in convs:
        torch_prune.l1_unstructured(conv, name='weight', amount=amount)  # mask smallest-|w| entries
        torch_prune.remove(conv, 'weight')  # bake the mask into the weight tensor
    print(' %.3g global sparsity' % sparsity(model))
|
|
|
|
|
def fuse_conv_and_bn(conv, bn):
    """Fold BatchNorm2d *bn* into the preceding Conv2d *conv*; return one equivalent Conv2d.

    The fold uses bn's running statistics, so the result matches ``bn(conv(x))`` with *bn*
    in eval mode. The returned conv has ``requires_grad`` disabled and lives on
    ``conv.weight.device``. Geometry (kernel, stride, padding, dilation, groups,
    padding_mode) is copied from *conv*.
    """
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          dilation=conv.dilation,  # without this a dilated conv changes output size/values
                          groups=conv.groups,
                          padding_mode=conv.padding_mode,
                          bias=True).requires_grad_(False).to(conv.weight.device)

    with torch.no_grad():  # pure parameter arithmetic; keep no autograd graph on the fused params
        # Weights: W_fused = diag(gamma / sqrt(var + eps)) @ W_conv, with W_conv flattened per output channel.
        w_conv = conv.weight.clone().view(conv.out_channels, -1)
        w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
        fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

        # Bias: b_fused = diag(scale) @ b_conv + (beta - gamma * mean / sqrt(var + eps)).
        b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
        b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
        fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fusedconv
|
|
|
|
|
def model_info(model, verbose=False, img_size=640):
    """Log a one-line model summary: layer count, parameters, trainable parameters, GFLOPs.

    Args:
        model: nn.Module. GFLOPs are only reported when thop is installed and the model has
            a ``yaml`` dict (read for input channels 'ch') and at least one parameter.
            ``model.stride``, when present, sets the probe image size (minimum 32).
        verbose: also print one row per named parameter (grad flag, count, shape, mean, std).
        img_size: int or [h, w] (list or tuple) used to scale the probe GFLOPs to that size.
    """
    n_p = sum(x.numel() for x in model.parameters())  # number of parameters
    n_g = sum(x.numel() for x in model.parameters() if x.requires_grad)  # number of trainable parameters
    if verbose:
        print(f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
        for i, (name, p) in enumerate(model.named_parameters()):
            name = name.replace('module_list.', '')
            print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
                  (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))

    try:
        # Profile a small stride x stride image, then scale FLOPs linearly with image area.
        stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
        img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device)
        # Module-level `thop` (None when not installed -> AttributeError, handled below).
        # thop reports MACs; *2 converts to FLOPs.
        flops = thop.profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2
        img_size = img_size if isinstance(img_size, (list, tuple)) else [img_size, img_size]
        fs = ', %.1f GFLOPs' % (flops * img_size[0] / stride * img_size[1] / stride)
    except Exception:  # thop missing, no .yaml, parameter-less model, profiling failure
        fs = ''

    LOGGER.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
|
|
|
|
|
def load_classifier(name='resnet101', n=2):
    """Load pretrained torchvision model *name* and reset its ``fc`` head to *n* outputs.

    The new ``fc`` weight (n x in_features) and bias (n) are zero-initialized trainable
    Parameters. Assumes the architecture exposes its classifier as ``model.fc``.
    An unknown *name* raises KeyError.
    """
    builder = torchvision.models.__dict__[name]
    model = builder(pretrained=True)

    head = model.fc
    in_features = head.weight.shape[1]
    head.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
    head.weight = nn.Parameter(torch.zeros(n, in_features), requires_grad=True)
    head.out_features = n
    return model
|
|
|
|
|
def scale_img(img, ratio=1.0, same_shape=False, gs=32):
    """Bilinearly rescale a batch of images by *ratio*, then pad bottom/right with 0.447.

    Assumes a 4-D (N, C, H, W) tensor (H, W read from ``img.shape[2:]``).

    - ``ratio == 1.0``: *img* is returned unchanged (same object).
    - ``same_shape=True``: padded back to the original H x W.
    - Otherwise: padded up to the next multiple of *gs* of the scaled size.
    """
    if ratio == 1.0:
        return img

    h, w = img.shape[2:]
    new_size = (int(h * ratio), int(w * ratio))
    resized = F.interpolate(img, size=new_size, mode='bilinear', align_corners=False)
    if not same_shape:
        # Round the scaled size up to a gs-multiple.
        h = math.ceil(h * ratio / gs) * gs
        w = math.ceil(w * ratio / gs) * gs
    pad = [0, w - new_size[1], 0, h - new_size[0]]  # (left, right, top, bottom)
    return F.pad(resized, pad, value=0.447)
|
|
|
|
|
def copy_attr(a, b, include=(), exclude=()):
    """Copy instance attributes from *b* onto *a* via setattr.

    Skips names starting with '_' and names in *exclude*. When *include* is non-empty,
    only names listed in it are copied.
    """
    for key, value in b.__dict__.items():
        not_included = len(include) > 0 and key not in include
        if not_included or key.startswith('_') or key in exclude:
            continue
        setattr(a, key, value)
|
|
|
|
|
class EarlyStopping:
    """Signal a stop once fitness has not improved for ``patience`` epochs.

    A fitness equal to the best so far counts as an improvement. ``patience`` of 0 (or any
    falsy value) disables stopping by using an infinite patience.
    """

    def __init__(self, patience=30):
        self.best_fitness = 0.0  # highest fitness observed so far
        self.best_epoch = 0  # epoch at which best_fitness was observed
        self.patience = patience or float('inf')
        self.possible_stop = False  # set when epochs-since-best >= patience - 1

    def __call__(self, epoch, fitness):
        """Record *fitness* for *epoch*; return True when training should stop."""
        if fitness >= self.best_fitness:  # ties also refresh best_epoch
            self.best_fitness = fitness
            self.best_epoch = epoch
        stale_epochs = epoch - self.best_epoch
        # Presumably consulted by the caller to anticipate a stop on the next epoch.
        self.possible_stop = stale_epochs >= (self.patience - 1)
        should_stop = stale_epochs >= self.patience
        if should_stop:
            LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
                        f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
                        f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
                        f'i.e. `python train.py --patience 300` or use `--patience 0` to disable EarlyStopping.')
        return should_stop
|
|
|
|
|
class ModelEMA:
    """Exponential moving average of a model's state_dict (after rwightman/pytorch-image-models).

    Keeps a frozen, eval-mode deep copy of the (unwrapped) model. On each update, every
    floating-point entry of its state_dict (parameters and buffers) is blended toward the
    live model's. Non-floating-point entries keep the values they had at construction.
    Similar in spirit to
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage.
    This class is sensitive to where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        source = model.module if is_parallel(model) else model  # unwrap DP/DDP
        self.ema = deepcopy(source).eval()
        self.updates = updates  # count of update() calls, drives the decay ramp
        # Effective decay ramps from ~0 toward `decay` as updates grow (time constant 2000),
        # so early averages track the live model closely.
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
        for param in self.ema.parameters():
            param.requires_grad_(False)

    def update(self, model):
        """Blend the EMA copy toward *model*: ema = d * ema + (1 - d) * model (floating entries only)."""
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            live = model.module if is_parallel(model) else model
            live_state = live.state_dict()
            for key, ema_value in self.ema.state_dict().items():
                if not ema_value.dtype.is_floating_point:
                    continue
                # In-place ops: state_dict tensors share storage with self.ema's params/buffers.
                ema_value *= d
                ema_value += (1 - d) * live_state[key].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        """Copy public attributes of *model* onto the EMA copy (see copy_attr filtering rules)."""
        copy_attr(self.ema, model, include, exclude)
|
|