Spaces:
Running
Running
import numpy as np | |
import logging | |
import os | |
def count_params(model):
    """Return the total parameter count of ``model``, expressed in millions.

    Sums ``numel()`` over every tensor yielded by ``model.parameters()`` and
    divides by one million, so the result is always a float.
    """
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total / 1e6
def color_map(dataset='pascal'):
    """Return a ``(256, 3)`` uint8 RGB palette indexed by class id.

    Args:
        dataset: ``'pascal'`` or ``'coco'`` select the bit-interleaved palette:
            for label ``i``, bits ``3k``, ``3k+1`` and ``3k+2`` of ``i`` become
            bit ``7-k`` of the red, green and blue channels respectively.
            ``'cityscapes'`` fills entries 0-18 from a fixed table and leaves
            the rest black. Any other value yields an all-zero palette.

    Returns:
        ``np.ndarray`` of shape ``(256, 3)`` and dtype ``uint8``.
    """
    palette = np.zeros((256, 3), dtype='uint8')

    if dataset in ('pascal', 'coco'):
        labels = np.arange(256, dtype=np.int64)
        # Accumulate in a wide integer array, then store once into uint8.
        channels = np.zeros((256, 3), dtype=np.int64)
        for k in range(8):
            for ch in range(3):
                channels[:, ch] |= ((labels >> (3 * k + ch)) & 1) << (7 - k)
        palette[:] = channels
    elif dataset == 'cityscapes':
        palette[:19] = np.array([
            [128, 64, 128],
            [244, 35, 232],
            [70, 70, 70],
            [102, 102, 156],
            [190, 153, 153],
            [153, 153, 153],
            [250, 170, 30],
            [220, 220, 0],
            [107, 142, 35],
            [152, 251, 152],
            [70, 130, 180],
            [220, 20, 60],
            [255, 0, 0],
            [0, 0, 142],
            [0, 0, 70],
            [0, 60, 100],
            [0, 80, 100],
            [0, 0, 230],
            [119, 11, 32],
        ])

    return palette
class AverageMeter(object):
    """Computes and stores the average and current value.

    The mode is selected by ``length``:

    * ``length > 0`` -- sliding window: ``avg`` is the mean of the last
      ``length`` values passed to :meth:`update`, kept in ``history``.
    * otherwise (default ``0``) -- cumulative mean over every value, weighted
      by ``num``; the running totals are kept in ``sum`` and ``count``.

    In both modes ``val`` is the most recent value and ``avg`` the current
    mean; both read as ``0.0`` until the first update.
    """

    def __init__(self, length=0):
        self.length = length
        self.reset()

    def reset(self):
        """Discard all accumulated values."""
        if self.length > 0:
            self.history = []
        else:
            self.count = 0
            self.sum = 0.0
        # Set in both modes: previously windowed mode left these undefined
        # (AttributeError before the first update) or stale after a reset.
        self.val = 0.0
        self.avg = 0.0

    def update(self, val, num=1):
        """Record a new value.

        Args:
            val: the new observation.
            num: weight given to ``val`` in the cumulative mean. Only
                ``num == 1`` is accepted in windowed mode.

        Raises:
            ValueError: if ``num != 1`` while ``length > 0``.
        """
        if self.length > 0:
            # Weighted updates are undefined for the window; reject explicitly
            # (an ``assert`` here would be stripped under ``python -O``).
            if num != 1:
                raise ValueError(
                    "AverageMeter with length > 0 only supports num == 1, "
                    "got num={!r}".format(num))
            self.history.append(val)
            # Keep only the newest ``length`` entries.
            del self.history[:-self.length]
            self.val = self.history[-1]
            self.avg = np.mean(self.history)
        else:
            self.val = val
            self.sum += val * num
            self.count += num
            # Avoid ZeroDivisionError when the total weight is still zero.
            self.avg = self.sum / self.count if self.count else 0.0
logs = set()  # (name, level) pairs already configured by init_log


def init_log(name, level=logging.INFO):
    """Configure and return the logger ``name`` with a console handler.

    Sets the logger and a new ``StreamHandler`` to ``level``. The handler uses
    the format ``[%(asctime)s][%(levelname)8s] %(message)s``. When the
    ``SLURM_PROCID`` environment variable is set, a filter is added that
    drops every record unless that value is ``0``.

    A repeat call with the same ``(name, level)`` adds nothing and returns the
    already-configured logger.

    NOTE: a repeat call with the same ``name`` but a different ``level`` is
    treated as new and attaches another handler, so messages may print twice.

    Args:
        name: logger name passed to ``logging.getLogger``.
        level: level applied to both the logger and its handler.

    Returns:
        The ``logging.Logger`` for ``name``.
    """
    logger = logging.getLogger(name)
    if (name, level) in logs:
        # Previously returned None here, breaking `logger = init_log(...)`
        # when called a second time.
        return logger
    logs.add((name, level))
    logger.setLevel(level)

    ch = logging.StreamHandler()
    ch.setLevel(level)
    ch.setFormatter(logging.Formatter("[%(asctime)s][%(levelname)8s] %(message)s"))

    if "SLURM_PROCID" in os.environ:
        rank = int(os.environ["SLURM_PROCID"])
        # Only rank 0 emits; every record is filtered out on other ranks.
        logger.addFilter(lambda record: rank == 0)

    logger.addHandler(ch)
    return logger