import os
import sys
import time
import shutil
import datetime
import socket
import contextlib
from collections import defaultdict, deque

import torch
import torch.nn.functional as torch_F
import torch.distributed as dist


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {}".format(name, str(meter)))
        return self.delimiter.join(loss_str)

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string, meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string, meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))
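# Illustrative usage sketch (added for documentation, not part of the original
# module): shows how MetricLogger and SmoothedValue are typically driven from a
# training loop. The iterable and loss values below are hypothetical stand-ins
# for a real DataLoader and model output.
def _demo_metric_logger():
    logger = MetricLogger(delimiter="  ")
    # a meter with window_size=1 simply reports the latest value (e.g. the lr)
    logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
    dummy_batches = range(10)  # stand-in for a DataLoader
    for batch in logger.log_every(dummy_batches, print_freq=5, header="demo"):
        # in practice the loss comes from a forward/backward pass
        logger.update(loss=float(batch) * 0.1, lr=1e-4)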
"{:.4f}".format((chamfer[0]+chamfer[1])/2)) if depth_metrics is not None: for k, v in depth_metrics.items(): message += "{}:{}, ".format(k, "{:.4f}".format(v)) message = message[:-2] print(message) def update_timer(opt, timer, ep, it_per_ep): momentum = 0.99 timer.elapsed = time.time()-timer.start timer.it = timer.it_end-timer.it_start # compute speed with moving average timer.it_mean = timer.it_mean*momentum+timer.it*(1-momentum) if timer.it_mean is not None else timer.it timer.arrival = timer.it_mean*it_per_ep*(opt.max_epoch-ep) # move tensors to device in-place def move_to_device(X, device): if isinstance(X, dict): for k, v in X.items(): X[k] = move_to_device(v, device) elif isinstance(X, list): for i, e in enumerate(X): X[i] = move_to_device(e, device) elif isinstance(X, tuple) and hasattr(X, "_fields"): # collections.namedtuple dd = X._asdict() dd = move_to_device(dd, device) return type(X)(**dd) elif isinstance(X, torch.Tensor): return X.to(device=device, non_blocking=True) return X # detach tensors def detach_tensors(X): if isinstance(X, dict): for k, v in X.items(): X[k] = detach_tensors(v) elif isinstance(X, list): for i, e in enumerate(X): X[i] = detach_tensors(e) elif isinstance(X, tuple) and hasattr(X, "_fields"): # collections.namedtuple dd = X._asdict() dd = detach_tensors(dd) return type(X)(**dd) elif isinstance(X, torch.Tensor): return X.detach() return X # this recursion seems to only work for the outer loop when dict_type is not dict def to_dict(D, dict_type=dict): D = dict_type(D) for k, v in D.items(): if isinstance(v, dict): D[k] = to_dict(v, dict_type) return D def get_child_state_dict(state_dict, key): out_dict = {} for k, v in state_dict.items(): if k.startswith("module."): param_name = k[7:] else: param_name = k if param_name.startswith("{}.".format(key)): out_dict[".".join(param_name.split(".")[1:])] = v return out_dict def resume_checkpoint(opt, model, best): load_name = "{0}/best.ckpt".format(opt.output_path) if best else "{0}/latest.ckpt".format(opt.output_path) checkpoint = torch.load(load_name, map_location=torch.device(opt.device)) model.graph.module.load_state_dict(checkpoint["graph"], strict=True) # load the training stats for key in model.__dict__: if key.split("_")[0] in ["optim", "sched", "scaler"] and key in checkpoint: if opt.device == 0: print("restoring {}...".format(key)) getattr(model, key).load_state_dict(checkpoint[key]) # also need to record ep, it, best_val if we are returning ep, it = checkpoint["epoch"], checkpoint["iter"] best_val, best_ep = checkpoint["best_val"], checkpoint["best_ep"] if "best_ep" in checkpoint else 0 print("resuming from epoch {0} (iteration {1})".format(ep, it)) return ep, it, best_val, best_ep def load_checkpoint(opt, model, load_name): # load_name as to be given checkpoint = torch.load(load_name, map_location=torch.device(opt.device)) # load individual (possibly partial) children modules for name, child in model.graph.module.named_children(): child_state_dict = get_child_state_dict(checkpoint["graph"], name) if child_state_dict: if opt.device == 0: print("restoring {}...".format(name)) child.load_state_dict(child_state_dict, strict=True) else: if opt.device == 0: print("skipping {}...".format(name)) return None, None, None, None def restore_checkpoint(opt, model, load_name=None, resume=False, best=False, evaluate=False): # we cannot load and resume at the same time assert not (load_name is not None and resume) # when resuming we want everything to be the same if resume: ep, it, best_val, best_ep = 
def restore_checkpoint(opt, model, load_name=None, resume=False, best=False, evaluate=False):
    # we cannot load and resume at the same time
    assert not (load_name is not None and resume)
    if resume:
        # when resuming we want everything to be the same
        ep, it, best_val, best_ep = resume_checkpoint(opt, model, best)
    else:
        # loading is more flexible, as we can load only parts of the model
        ep, it, best_val, best_ep = load_checkpoint(opt, model, load_name)
    return ep, it, best_val, best_ep


def save_checkpoint(opt, model, ep, it, best_val, best_ep, latest=False, best=False, children=None):
    os.makedirs("{0}/checkpoint".format(opt.output_path), exist_ok=True)
    if isinstance(model.graph, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
        graph = model.graph.module
    else:
        graph = model.graph
    if children is not None:
        graph_state_dict = {k: v for k, v in graph.state_dict().items() if k.startswith(children)}
    else:
        graph_state_dict = graph.state_dict()
    checkpoint = dict(
        epoch=ep,
        iter=it,
        best_val=best_val,
        best_ep=best_ep,
        graph=graph_state_dict,
    )
    for key in model.__dict__:
        if key.split("_")[0] in ["optim", "sched", "scaler"]:
            checkpoint.update({key: getattr(model, key).state_dict()})
    torch.save(checkpoint, "{0}/latest.ckpt".format(opt.output_path))
    if best:
        shutil.copy("{0}/latest.ckpt".format(opt.output_path),
                    "{0}/best.ckpt".format(opt.output_path))
    if not latest:
        shutil.copy("{0}/latest.ckpt".format(opt.output_path),
                    "{0}/checkpoint/ep{1}.ckpt".format(opt.output_path, ep))


def check_socket_open(hostname, port):
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    is_open = False
    try:
        s.bind((hostname, port))
    except socket.error:
        is_open = True
    finally:
        s.close()
    return is_open


def get_layer_dims(layers):
    # return a list of tuples (k_in, k_out)
    return list(zip(layers[:-1], layers[1:]))


@contextlib.contextmanager
def suppress(stdout=False, stderr=False):
    with open(os.devnull, "w") as devnull:
        if stdout:
            old_stdout, sys.stdout = sys.stdout, devnull
        if stderr:
            old_stderr, sys.stderr = sys.stderr, devnull
        try:
            yield
        finally:
            if stdout:
                sys.stdout = old_stdout
            if stderr:
                sys.stderr = old_stderr


def toggle_grad(model, requires_grad):
    for p in model.parameters():
        p.requires_grad_(requires_grad)


def compute_grad2(d_outs, x_in):
    d_outs = [d_outs] if not isinstance(d_outs, list) else d_outs
    reg = 0
    for d_out in d_outs:
        batch_size = x_in.size(0)
        grad_dout = torch.autograd.grad(
            outputs=d_out.sum(), inputs=x_in,
            create_graph=True, retain_graph=True, only_inputs=True
        )[0]
        grad_dout2 = grad_dout.pow(2)
        assert grad_dout2.size() == x_in.size()
        reg += grad_dout2.view(batch_size, -1).sum(1)
    return reg / len(d_outs)


def interpolate_depth(depth_input, mask_input, size, bg_depth=20):
    assert len(depth_input.shape) == len(mask_input.shape) == 4
    mask = (mask_input > 0.5).float()
    depth_valid = depth_input * mask
    depth_valid = torch_F.interpolate(depth_valid, size, mode='bilinear', align_corners=False)
    mask = torch_F.interpolate(mask, size, mode='bilinear', align_corners=False)
    depth_out = depth_valid / (mask + 1.e-6)
    mask_binary = (mask > 0.5).float()
    depth_out = depth_out * mask_binary + bg_depth * (1 - mask_binary)
    return depth_out, mask_binary


def interpolate_coordmap(coord_map, mask_input, size, bg_coord=0):
    assert len(coord_map.shape) == len(mask_input.shape) == 4
    mask = (mask_input > 0.5).float()
    coord_valid = coord_map * mask
    coord_valid = torch_F.interpolate(coord_valid, size, mode='bilinear', align_corners=False)
    mask = torch_F.interpolate(mask, size, mode='bilinear', align_corners=False)
    coord_out = coord_valid / (mask + 1.e-6)
    mask_binary = (mask > 0.5).float()
    coord_out = coord_out * mask_binary + bg_coord * (1 - mask_binary)
    return coord_out, mask_binary
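# Illustrative sketch (added for documentation, not part of the original code):
# interpolate_depth() resizes a depth map while ignoring masked-out pixels, so
# foreground depth does not bleed into the background during bilinear
# interpolation. The tensor shapes below are arbitrary example values.
def _demo_interpolate_depth():
    depth = torch.rand(1, 1, 64, 64) * 5.0           # [B, 1, H, W] depth map
    mask = (torch.rand(1, 1, 64, 64) > 0.3).float()  # [B, 1, H, W] foreground mask
    depth_small, mask_small = interpolate_depth(depth, mask, size=(32, 32), bg_depth=20)
    # depth_small is the mask-aware resized depth; background pixels are set to bg_depth
    return depth_small, mask_small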
def cleanup():
    dist.destroy_process_group()


def is_port_in_use(port):
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(('localhost', port)) == 0


def setup(rank, world_size, port_no):
    full_address = 'tcp://127.0.0.1:' + str(port_no)
    dist.init_process_group("nccl", init_method=full_address, rank=rank, world_size=world_size)


def print_grad(grad, prefix=''):
    print("{} --- Grad Abs Mean, Grad Max, Grad Min: {:.5f} | {:.5f} | {:.5f}".format(
        prefix, grad.abs().mean().item(), grad.max().item(), grad.min().item()))


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class EasyDict(dict):
    def __init__(self, d=None, **kwargs):
        if d is None:
            d = {}
        else:
            d = dict(d)
        if kwargs:
            d.update(**kwargs)
        for k, v in d.items():
            setattr(self, k, v)
        # class attributes
        for k in self.__class__.__dict__.keys():
            if not (k.startswith('__') and k.endswith('__')) and k not in ('update', 'pop'):
                setattr(self, k, getattr(self, k))

    def __setattr__(self, name, value):
        if isinstance(value, (list, tuple)):
            value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
        elif isinstance(value, dict) and not isinstance(value, self.__class__):
            value = self.__class__(value)
        super(EasyDict, self).__setattr__(name, value)
        super(EasyDict, self).__setitem__(name, value)

    __setitem__ = __setattr__

    def update(self, e=None, **f):
        d = e or dict()
        d.update(f)
        for k in d:
            setattr(self, k, d[k])

    def pop(self, k, d=None):
        delattr(self, k)
        return super(EasyDict, self).pop(k, d)
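# Illustrative sketch (added for documentation, not part of the original code):
# EasyDict exposes dict entries as attributes and recursively wraps nested
# dicts, which is how a configuration object like `opt` is typically built.
# The option names and values below are made up for the example.
def _demo_easydict():
    opt = EasyDict({"output_path": "/tmp/run", "optim": {"lr": 1e-4, "weight_decay": 0.0}})
    opt.max_epoch = 10    # attribute assignment also updates the dict entry
    lr = opt.optim.lr     # nested dicts are wrapped as EasyDict
    opt.update(device=0)  # update() routes through __setattr__ as well
    return opt, lr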