""" |
|
Misc functions. |
|
|
|
Mostly copy-pasted from torchvision references and other public repos such as DETR and DINO:
|
https://github.com/facebookresearch/detr/blob/master/util/misc.py |
|
https://github.com/facebookresearch/dino/blob/main/utils.py |
|
""" |
|
import argparse

import datetime
|
import logging |
|
import os |
|
import subprocess |
|
import sys |
|
import time |
|
from collections import defaultdict, deque |
|
|
|
import numpy as np |
|
import torch |
|
import torch.distributed as dist |
|
from torch import nn |
|
|
|
|
|
def get_logger(file_path_name): |
|
""" |
|
    Build a logger that writes both to a log file on disk and to the terminal.
|
""" |
|
logger = logging.getLogger() |
|
logger.setLevel("INFO") |
|
BASIC_FORMAT = "%(levelname)s:%(message)s" |
|
DATE_FORMAT = "" |
|
formatter = logging.Formatter(BASIC_FORMAT, DATE_FORMAT) |
|
chlr = logging.StreamHandler() |
|
chlr.setFormatter(formatter) |
|
chlr.setLevel("INFO") |
|
fhlr = logging.FileHandler(file_path_name) |
|
fhlr.setFormatter(formatter) |
|
logger.addHandler(chlr) |
|
logger.addHandler(fhlr) |
|
|
|
return logger |
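
# Usage sketch (illustrative; the log path below is a placeholder chosen by the caller):
#
#     logger = get_logger("train.log")
#     logger.info("starting training")  # goes to the terminal and is appended to train.log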
|
|
|
|
|
def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs): |
|
""" |
|
    Re-start from a checkpoint: load matching state dicts from `ckp_path` into the
    objects passed as keyword arguments and restore the scalars listed in `run_variables`.
|
""" |
|
if not os.path.isfile(ckp_path): |
|
return |
|
print("Found checkpoint at {}".format(ckp_path)) |
|
|
|
|
|
checkpoint = torch.load(ckp_path, map_location="cpu") |
|
|
|
|
|
|
|
for key, value in kwargs.items(): |
|
if key in checkpoint and value is not None: |
|
try: |
|
msg = value.load_state_dict(checkpoint[key], strict=False) |
|
print( |
|
"=> loaded '{}' from checkpoint '{}' with msg {}".format( |
|
key, ckp_path, msg |
|
) |
|
) |
|
except TypeError: |
|
try: |
|
msg = value.load_state_dict(checkpoint[key]) |
|
print("=> loaded '{}' from checkpoint: '{}'".format(key, ckp_path)) |
|
except ValueError: |
|
print( |
|
"=> failed to load '{}' from checkpoint: '{}'".format( |
|
key, ckp_path |
|
) |
|
) |
|
else: |
|
print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path)) |
|
|
|
|
|
if run_variables is not None: |
|
for var_name in run_variables: |
|
if var_name in checkpoint: |
|
run_variables[var_name] = checkpoint[var_name] |
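
# Usage sketch (illustrative; `model`, `optimizer`, and the checkpoint path are
# placeholders defined by the training script):
#
#     to_restore = {"epoch": 0}
#     restart_from_checkpoint(
#         "checkpoint.pth",
#         run_variables=to_restore,
#         model=model,
#         optimizer=optimizer,
#     )
#     start_epoch = to_restore["epoch"]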
|
|
|
|
|
def bool_flag(s): |
|
""" |
|
Parse boolean arguments from the command line. |
|
""" |
|
FALSY_STRINGS = {"off", "false", "0"} |
|
TRUTHY_STRINGS = {"on", "true", "1"} |
|
if s.lower() in FALSY_STRINGS: |
|
return False |
|
elif s.lower() in TRUTHY_STRINGS: |
|
return True |
|
else: |
|
raise argparse.ArgumentTypeError("invalid value for a boolean flag") |
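
# Usage sketch (illustrative; the parser and flag name are placeholders):
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--use_fp16", type=bool_flag, default=True)
#     args = parser.parse_args(["--use_fp16", "false"])  # args.use_fp16 is False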
|
|
|
|
|
def fix_random_seeds(seed=31): |
|
""" |
|
Fix random seeds. |
|
""" |
|
torch.manual_seed(seed) |
|
torch.cuda.manual_seed_all(seed) |
|
np.random.seed(seed) |
|
|
|
|
|
def has_batchnorms(model): |
|
""" |
|
    Check whether the model contains any batch-normalization layers.
|
""" |
|
bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) |
|
    for _, module in model.named_modules():
|
if isinstance(module, bn_types): |
|
return True |
|
return False |
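
# Usage sketch (illustrative): the typical use is converting BN layers before
# wrapping a model in DistributedDataParallel.
#
#     if has_batchnorms(model):
#         model = nn.SyncBatchNorm.convert_sync_batchnorm(model)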
|
|
|
|
|
class SmoothedValue(object): |
|
"""Track a series of values and provide access to smoothed values over a |
|
window or the global series average. |
|
""" |
|
|
|
def __init__(self, window_size=20, fmt=None): |
|
if fmt is None: |
|
fmt = "{median:.6f} ({global_avg:.6f})" |
|
self.deque = deque(maxlen=window_size) |
|
self.total = 0.0 |
|
self.count = 0 |
|
self.fmt = fmt |
|
|
|
def update(self, value, n=1): |
|
self.deque.append(value) |
|
self.count += n |
|
self.total += value * n |
|
|
|
def synchronize_between_processes(self): |
|
""" |
|
Warning: does not synchronize the deque! |
|
""" |
|
if not is_dist_avail_and_initialized(): |
|
return |
|
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") |
|
dist.barrier() |
|
dist.all_reduce(t) |
|
t = t.tolist() |
|
self.count = int(t[0]) |
|
self.total = t[1] |
|
|
|
@property |
|
def median(self): |
|
d = torch.tensor(list(self.deque)) |
|
return d.median().item() |
|
|
|
@property |
|
def avg(self): |
|
d = torch.tensor(list(self.deque), dtype=torch.float32) |
|
return d.mean().item() |
|
|
|
@property |
|
def global_avg(self): |
|
return self.total / self.count |
|
|
|
@property |
|
def max(self): |
|
return max(self.deque) |
|
|
|
@property |
|
def value(self): |
|
return self.deque[-1] |
|
|
|
def __str__(self): |
|
return self.fmt.format( |
|
median=self.median, |
|
avg=self.avg, |
|
global_avg=self.global_avg, |
|
max=self.max, |
|
value=self.value, |
|
) |
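
# Usage sketch (illustrative):
#
#     loss_meter = SmoothedValue(window_size=20, fmt="{median:.4f} ({global_avg:.4f})")
#     loss_meter.update(0.731)
#     loss_meter.update(0.642)
#     print(loss_meter)  # median over the window and the global average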
|
|
|
|
|
class MetricLogger(object): |
|
""" |
|
    Track a set of named metrics as SmoothedValue meters and print them at a regular interval.
|
""" |
|
|
|
def __init__(self, delimiter="\t"): |
|
self.meters = defaultdict(SmoothedValue) |
|
self.delimiter = delimiter |
|
|
|
def update(self, **kwargs): |
|
for k, v in kwargs.items(): |
|
if isinstance(v, torch.Tensor): |
|
v = v.item() |
|
assert isinstance(v, (float, int)) |
|
self.meters[k].update(v) |
|
|
|
def __getattr__(self, attr): |
|
if attr in self.meters: |
|
return self.meters[attr] |
|
if attr in self.__dict__: |
|
return self.__dict__[attr] |
|
raise AttributeError( |
|
"'{}' object has no attribute '{}'".format(type(self).__name__, attr) |
|
) |
|
|
|
def __str__(self): |
|
loss_str = [] |
|
for name, meter in self.meters.items(): |
|
loss_str.append("{}: {}".format(name, str(meter))) |
|
return self.delimiter.join(loss_str) |
|
|
|
def synchronize_between_processes(self): |
|
for meter in self.meters.values(): |
|
meter.synchronize_between_processes() |
|
|
|
def add_meter(self, name, meter): |
|
self.meters[name] = meter |
|
|
|
def log_every(self, iterable, print_freq, header=None): |
|
i = 0 |
|
if not header: |
|
header = "" |
|
start_time = time.time() |
|
end = time.time() |
|
iter_time = SmoothedValue(fmt="{avg:.6f}") |
|
data_time = SmoothedValue(fmt="{avg:.6f}") |
|
space_fmt = ":" + str(len(str(len(iterable)))) + "d" |
|
if torch.cuda.is_available(): |
|
log_msg = self.delimiter.join( |
|
[ |
|
header, |
|
"[{0" + space_fmt + "}/{1}]", |
|
"eta: {eta}", |
|
"{meters}", |
|
"time: {time}", |
|
"data: {data}", |
|
"max mem: {memory:.0f}", |
|
] |
|
) |
|
else: |
|
log_msg = self.delimiter.join( |
|
[ |
|
header, |
|
"[{0" + space_fmt + "}/{1}]", |
|
"eta: {eta}", |
|
"{meters}", |
|
"time: {time}", |
|
"data: {data}", |
|
] |
|
) |
|
MB = 1024.0 * 1024.0 |
|
for obj in iterable: |
|
data_time.update(time.time() - end) |
|
yield obj |
|
iter_time.update(time.time() - end) |
|
if i % print_freq == 0 or i == len(iterable) - 1: |
|
eta_seconds = iter_time.global_avg * (len(iterable) - i) |
|
eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) |
|
if torch.cuda.is_available(): |
|
print( |
|
log_msg.format( |
|
i, |
|
len(iterable), |
|
eta=eta_string, |
|
meters=str(self), |
|
time=str(iter_time), |
|
data=str(data_time), |
|
memory=torch.cuda.max_memory_allocated() / MB, |
|
) |
|
) |
|
else: |
|
print( |
|
log_msg.format( |
|
i, |
|
len(iterable), |
|
eta=eta_string, |
|
meters=str(self), |
|
time=str(iter_time), |
|
data=str(data_time), |
|
) |
|
) |
|
i += 1 |
|
end = time.time() |
|
total_time = time.time() - start_time |
|
total_time_str = str(datetime.timedelta(seconds=int(total_time))) |
|
print( |
|
"{} Total time: {} ({:.6f} s / it)".format( |
|
header, total_time_str, total_time / len(iterable) |
|
) |
|
) |
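
# Usage sketch (illustrative; `data_loader`, `model`, and `criterion` are placeholders
# defined by the training script):
#
#     metric_logger = MetricLogger(delimiter="  ")
#     for images, targets in metric_logger.log_every(data_loader, 10, header="Epoch: [0]"):
#         loss = criterion(model(images), targets)
#         metric_logger.update(loss=loss.item())
#     metric_logger.synchronize_between_processes()
#     print("Averaged stats:", metric_logger)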
|
|
|
|
|
def get_sha(): |
|
cwd = os.path.dirname(os.path.abspath(__file__)) |
|
|
|
def _run(command): |
|
return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() |
|
|
|
sha = "N/A" |
|
diff = "clean" |
|
branch = "N/A" |
|
try: |
|
sha = _run(["git", "rev-parse", "HEAD"]) |
|
subprocess.check_output(["git", "diff"], cwd=cwd) |
|
diff = _run(["git", "diff-index", "HEAD"]) |
|
diff = "has uncommited changes" if diff else "clean" |
|
branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) |
|
except Exception: |
|
pass |
|
message = f"sha: {sha}, status: {diff}, branch: {branch}" |
|
return message |
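
# Usage sketch (illustrative): typically printed once at startup for reproducibility.
#
#     print("git:\n  {}\n".format(get_sha()))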
|
|
|
|
|
def is_dist_avail_and_initialized(): |
|
""" |
|
    Check whether torch.distributed is available and the default process group is initialized.
|
""" |
|
if not dist.is_available(): |
|
return False |
|
if not dist.is_initialized(): |
|
return False |
|
return True |
|
|
|
|
|
def get_world_size(): |
|
""" |
|
    Get the number of distributed processes (1 if not running distributed).
|
""" |
|
if not is_dist_avail_and_initialized(): |
|
return 1 |
|
return dist.get_world_size() |
|
|
|
|
|
def get_rank(): |
|
""" |
|
    Get the rank of the current process (0 if not running distributed).
|
""" |
|
if not is_dist_avail_and_initialized(): |
|
return 0 |
|
return dist.get_rank() |
|
|
|
|
|
def is_main_process(): |
|
""" |
|
    Check whether the current process is the main (rank 0) process.
|
""" |
|
return get_rank() == 0 |
|
|
|
|
|
def save_on_master(*args, **kwargs): |
|
""" |
|
    Save a checkpoint from the main process only.
|
""" |
|
if is_main_process(): |
|
torch.save(*args, **kwargs) |
|
|
|
|
|
def setup_for_distributed(is_master): |
|
""" |
|
    Disable printing when not in the master process; pass force=True to print anyway.
|
""" |
|
import builtins as __builtin__ |
|
|
|
builtin_print = __builtin__.print |
|
|
|
def print(*args, **kwargs): |
|
force = kwargs.pop("force", False) |
|
if is_master or force: |
|
builtin_print(*args, **kwargs) |
|
|
|
__builtin__.print = print |
|
|
|
|
|
def init_distributed_ddpjob(args=None): |
|
""" |
|
    Initialize a DDP job (e.g. launched with torchrun) and return (world_size, rank).
|
""" |
|
if dist.is_available() and dist.is_initialized(): |
|
return dist.get_world_size(), dist.get_rank() |
|
|
|
try: |
|
os.environ["MASTER_PORT"] = "40101" |
|
        torch.distributed.init_process_group(backend="nccl")

        world_size = dist.get_world_size()

        rank = dist.get_rank()

    except Exception:

        world_size, rank = 1, 0

        print("distributed training not available")

    if args is not None:

        args.world_size, args.rank = world_size, rank

        args.gpu = rank
|
return world_size, rank |
|
|
|
|
|
def init_distributed_mode(args): |
|
""" |
|
    Initialize distributed mode for a standard launch (torchrun / torch.distributed.launch, SLURM, or a single local GPU).
|
""" |
|
|
|
if "RANK" in os.environ and "WORLD_SIZE" in os.environ: |
|
args.rank = int(os.environ["RANK"]) |
|
args.world_size = int(os.environ["WORLD_SIZE"]) |
|
args.gpu = int(os.environ.get("LOCAL_RANK", 0)) |
|
print( |
|
"args.rank", |
|
args.rank, |
|
"args.world_size", |
|
args.world_size, |
|
"args.gpu", |
|
args.gpu, |
|
) |
|
print("get_rank()", get_rank()) |
|
|
|
elif "SLURM_PROCID" in os.environ: |
|
args.rank = int(os.environ["SLURM_PROCID"]) |
|
args.gpu = args.rank % torch.cuda.device_count() |
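        # note: in this branch args.world_size is expected to be set by the
        # launcher (e.g. a submitit/SLURM helper) before init_process_group is called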
|
|
|
|
|
elif torch.cuda.is_available(): |
|
print("Will run the code on one GPU.") |
|
args.rank, args.gpu, args.world_size = 0, 0, 1 |
|
os.environ["MASTER_ADDR"] = "127.0.0.1" |
|
os.environ["MASTER_PORT"] = "2950" |
|
else: |
|
print("Does not support training without GPU.") |
|
sys.exit(1) |
|
|
|
os.environ["MASTER_PORT"] = "6542" |
|
|
|
dist.init_process_group( |
|
backend="nccl", |
|
init_method=args.dist_url, |
|
world_size=args.world_size, |
|
rank=args.rank, |
|
) |
|
|
|
torch.cuda.set_device(args.gpu) |
|
print( |
|
"| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True |
|
) |
|
dist.barrier() |
|
setup_for_distributed(args.rank == 0) |
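
# Usage sketch (illustrative; `args` must carry a `dist_url` attribute, e.g. "env://"):
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--dist_url", default="env://", type=str)
#     args = parser.parse_args()
#     init_distributed_mode(args)  # sets args.rank and args.gpu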
|
|
|
|
|
def accuracy(output, target, topk=(1,)): |
|
""" |
|
Computes the accuracy over the k top predictions for the specified values of k |
|
""" |
|
maxk = max(topk) |
|
batch_size = target.size(0) |
|
_, pred = output.topk(maxk, 1, True, True) |
|
pred = pred.t() |
|
correct = pred.eq(target.reshape(1, -1).expand_as(pred)) |
|
return [correct[:k].reshape(-1).float().sum(0) * 100.0 / batch_size for k in topk] |
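
# Usage sketch (illustrative; `output` is a (batch, num_classes) logits tensor and
# `target` a (batch,) tensor of class indices):
#
#     acc1, acc5 = accuracy(output, target, topk=(1, 5))
#     print(f"top-1 {acc1.item():.2f}%  top-5 {acc5.item():.2f}%")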
|
|
|
|
|
def multi_scale(samples, model): |
|
""" |
|
    Compute multi-scale features: average the model's features over several input scales and L2-normalize the result.
|
""" |
|
v = None |
|
for s in [1, 1 / 2 ** (1 / 2), 1 / 2]: |
|
if s == 1: |
|
inp = samples.clone() |
|
else: |
|
inp = nn.functional.interpolate( |
|
samples, scale_factor=s, mode="bilinear", align_corners=False |
|
) |
|
feats = model.forward_knn(inp).clone() |
|
if v is None: |
|
v = feats |
|
else: |
|
v += feats |
|
v /= 3 |
|
v /= v.norm() |
|
return v |
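
# Usage sketch (illustrative; assumes `model` exposes the forward_knn() method used
# above and `samples` is a batch of images on the right device):
#
#     with torch.no_grad():
#         feats = multi_scale(samples, model)  # averaged over scales, L2-normalized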
|
|
|
|
|
class AllGather(torch.autograd.Function): |
|
""" |
|
    All-gather a tensor across processes; the backward pass returns the gradient slice belonging to the local shard.
|
""" |
|
|
|
@staticmethod |
|
def forward(ctx, x): |
|
if ( |
|
dist.is_available() |
|
and dist.is_initialized() |
|
and (dist.get_world_size() > 1) |
|
): |
|
outputs = [torch.zeros_like(x) for _ in range(dist.get_world_size())] |
|
dist.all_gather(outputs, x) |
|
return torch.cat(outputs, 0) |
|
return x |
|
|
|
@staticmethod |
|
def backward(ctx, grads): |
|
if ( |
|
dist.is_available() |
|
and dist.is_initialized() |
|
and (dist.get_world_size() > 1) |
|
): |
|
s = (grads.shape[0] // dist.get_world_size()) * dist.get_rank() |
|
e = (grads.shape[0] // dist.get_world_size()) * (dist.get_rank() + 1) |
|
grads = grads.contiguous() |
|
dist.all_reduce(grads) |
|
return grads[s:e] |
|
return grads |
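
# Usage sketch (illustrative): gather per-process embeddings into one big batch while
# keeping gradients flowing back to the local shard (unlike concat_all_gather below).
#
#     all_embeddings = AllGather.apply(embeddings)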
|
|
|
|
|
class AllReduce(torch.autograd.Function): |
|
""" |
|
    Average a tensor across processes (all-reduce of x / world_size); gradients pass through unchanged.
|
""" |
|
|
|
@staticmethod |
|
def forward(ctx, x): |
|
if ( |
|
dist.is_available() |
|
and dist.is_initialized() |
|
and (dist.get_world_size() > 1) |
|
): |
|
x = x.contiguous() / dist.get_world_size() |
|
dist.all_reduce(x) |
|
return x |
|
|
|
@staticmethod |
|
def backward(ctx, grads): |
|
return grads |
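
# Usage sketch (illustrative): average a scalar or tensor across processes inside the
# forward pass, e.g. a loss term that should be synchronized.
#
#     synced_loss = AllReduce.apply(local_loss)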
|
|
|
|
|
def load_pretrained_weights( |
|
model, pretrained_weights, checkpoint_key, model_name, patch_size |
|
):
    """
    Load pretrained weights from `pretrained_weights` into `model`, selecting
    `checkpoint_key` if present and stripping common prefixes ("module.",
    "backbone.", "encoder."). `model_name` and `patch_size` are accepted for
    interface compatibility and are unused here.
    """
|
if os.path.isfile(pretrained_weights): |
|
state_dict = torch.load(pretrained_weights, map_location="cpu") |
|
if checkpoint_key is not None and checkpoint_key in state_dict: |
|
print(f"Take key {checkpoint_key} in provided checkpoint dict") |
|
state_dict = state_dict[checkpoint_key] |
|
|
|
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} |
|
|
|
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} |
|
|
|
state_dict = {k.replace("encoder.", ""): v for k, v in state_dict.items()} |
|
msg = model.load_state_dict(state_dict, strict=False) |
|
print( |
|
"Pretrained weights found at {} and loaded with msg: {}".format( |
|
pretrained_weights, msg |
|
) |
|
) |
|
else: |
|
print( |
|
"There is no reference weights available for this model => We use random weights." |
|
) |
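
# Usage sketch (illustrative; the checkpoint path and key are placeholders):
#
#     load_pretrained_weights(
#         model,
#         pretrained_weights="weights/checkpoint.pth",
#         checkpoint_key="teacher",
#         model_name="vit_base",
#         patch_size=16,
#     )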
|
|
|
|
|
@torch.no_grad() |
|
def concat_all_gather(tensor): |
|
""" |
|
Performs all_gather operation on the provided tensors. |
|
*** Warning ***: torch.distributed.all_gather has no gradient. |
|
""" |
|
tensors_gather = [ |
|
torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) |
|
] |
|
torch.distributed.all_gather(tensors_gather, tensor, async_op=False) |
|
|
|
output = torch.cat(tensors_gather, dim=0) |
|
return output |
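
# Usage sketch (illustrative): gather features from all processes, e.g. to build a
# memory bank; no gradients flow through this op (see AllGather above for a
# differentiable variant).
#
#     all_feats = concat_all_gather(feats)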
|
|