fffiloni's picture
Create model.py
c381dd3
raw
history blame
673 Bytes
import logging
import os
import time
import pickle
import torch
import torch.distributed as dist
from fvcore.nn import FlopCountAnalysis
from fvcore.nn import flop_count_table
from fvcore.nn import flop_count_str
logger = logging.getLogger(__name__)

# Normalization layer classes known to this module. Kept as a plain mutable
# list so that ``register_norm_module`` can append additional types later.
# NOTE(review): consumers are not visible here; presumably they check
# membership / ``isinstance`` against these classes — confirm at call sites.
NORM_MODULES = [
    # Batch normalization (1d/2d/3d) and its distributed variant.
    # NaiveSyncBatchNorm inherits from BatchNorm2d (presumably why it has no
    # separate entry).
    torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d,
    torch.nn.SyncBatchNorm,
    # Group normalization.
    torch.nn.GroupNorm,
    # Instance normalization (1d/2d/3d).
    torch.nn.InstanceNorm1d, torch.nn.InstanceNorm2d, torch.nn.InstanceNorm3d,
    # Layer normalization and local response normalization.
    torch.nn.LayerNorm,
    torch.nn.LocalResponseNorm,
]
def register_norm_module(cls):
    """Add ``cls`` to the module-level ``NORM_MODULES`` registry and return it.

    Because the class is returned unchanged, this can be applied as a class
    decorator (``@register_norm_module``). Registration is idempotent:
    registering a class that is already present leaves ``NORM_MODULES``
    unchanged, so the list never accumulates duplicates.

    Args:
        cls: The class to register.

    Returns:
        ``cls`` itself, unmodified.
    """
    # Guard against duplicates so repeated registration of the same class
    # (e.g. a module imported twice under different names) does not keep
    # growing the shared registry.
    if cls not in NORM_MODULES:
        NORM_MODULES.append(cls)
    return cls