Spaces:
Runtime error
Runtime error
import termcolor,os,shutil,torch | |
from easydict import EasyDict as edict | |
from collections import OrderedDict | |
import math | |
import numpy as np | |
from torch.nn import init | |
def get_time(sec):
    """Split a duration in seconds into whole (days, hours, minutes, seconds).

    Days are unbounded; hours wrap at 24, minutes and seconds at 60.
    Each component is truncated to an int.
    """
    seconds_per_day = 24 * 60 * 60
    days = int(sec // seconds_per_day)
    hours = int(sec // (60 * 60) % 24)
    minutes = int(sec // 60 % 60)
    seconds = int(sec % 60)
    return days, hours, minutes, seconds
# Wrap a message in ANSI colour codes via termcolor.  Any keyword argument
# whose value is exactly True (e.g. bold=True, underline=True) is forwarded
# to termcolor as a text attribute; every other keyword is ignored.
def _true_attrs(flags):
    """Return the names of the keyword flags whose value is exactly ``True``."""
    return [name for name, enabled in flags.items() if enabled is True]


def red(message, **kwargs):
    return termcolor.colored(str(message), color="red", attrs=_true_attrs(kwargs))


def green(message, **kwargs):
    return termcolor.colored(str(message), color="green", attrs=_true_attrs(kwargs))


def blue(message, **kwargs):
    return termcolor.colored(str(message), color="blue", attrs=_true_attrs(kwargs))


def cyan(message, **kwargs):
    return termcolor.colored(str(message), color="cyan", attrs=_true_attrs(kwargs))


def yellow(message, **kwargs):
    return termcolor.colored(str(message), color="yellow", attrs=_true_attrs(kwargs))


def magenta(message, **kwargs):
    return termcolor.colored(str(message), color="magenta", attrs=_true_attrs(kwargs))


def grey(message, **kwargs):
    return termcolor.colored(str(message), color="grey", attrs=_true_attrs(kwargs))
def openreadtxt(file_name, encoding=None):
    """Read a text file and return its lines with line terminators removed.

    Args:
        file_name: path of the file to read.
        encoding: text encoding passed to ``open``. ``None`` (the default)
            keeps the platform default encoding, as before.

    Returns:
        list[str]: the file's lines, split with ``str.splitlines``.
    """
    # The ``with`` block closes the handle; the previous version leaked it.
    with open(file_name, 'r', encoding=encoding) as file:
        return file.read().splitlines()
def to_dict(D, dict_type=dict):
    """Return a copy of mapping ``D`` converted to ``dict_type``, recursively.

    Every value that is a ``dict`` (or a ``dict`` subclass) is converted as
    well. Values inside lists or other containers are left untouched.
    """
    converted = dict_type(D)
    nested_keys = [key for key, value in converted.items() if isinstance(value, dict)]
    for key in nested_keys:
        converted[key] = to_dict(converted[key], dict_type)
    return converted
class Log:
    """Prints colour-coded status messages to stdout."""

    def __init__(self):
        pass

    def process(self, pid):
        """Print the given process id in bold grey."""
        print(grey(f"Process ID: {pid}", bold=True))

    def title(self, message):
        """Print a bold, underlined yellow heading."""
        print(yellow(message, bold=True, underline=True))

    def info(self, message):
        """Print an informational message in bold magenta."""
        print(magenta(message, bold=True))

    def options(self, opt, level=0):
        """Print the (possibly nested) options mapping as a sorted bullet tree.

        Each nesting level is indented by one extra space. Nested dicts/edicts
        are printed as a header line followed by their own entries.
        """
        indent = " " * level
        for key, value in sorted(opt.items()):
            label = indent + cyan("* ") + green(key) + ":"
            if not isinstance(value, (dict, edict)):
                print(label, yellow(value))
                continue
            print(label)
            self.options(value, level + 1)

    def loss_train(self, opt, ep, lr, loss, timer):
        """Print one training-progress line: epoch, lr, loss, elapsed time and ETA.

        Prints nothing when ``opt.max_epoch`` is falsy. ``timer.elapsed`` and
        ``timer.arrival`` are converted with ``get_time`` (seconds).
        """
        if not opt.max_epoch:
            return
        clock = "{0}-{1:02d}:{2:02d}:{3:02d}"
        pieces = [
            grey("[train] ", bold=True),
            "epoch {}/{}".format(cyan(ep, bold=True), opt.max_epoch),
            ", lr:{}".format(yellow("{:.2e}".format(lr), bold=True)),
            ", loss:{}".format(red("{:.3e}".format(loss), bold=True)),
            ", time:{}".format(blue(clock.format(*get_time(timer.elapsed)), bold=True)),
            " (ETA:{})".format(blue(clock.format(*get_time(timer.arrival)))),
        ]
        print("".join(pieces))

    def loss_val(self, opt, loss):
        """Print a validation-loss line. ``opt`` is accepted but not used."""
        header = grey("[val] ", bold=True)
        body = "loss:{}".format(red("{:.3e}".format(loss), bold=True))
        print(header + body)


# Module-level logger instance shared by importers of this module.
log = Log()
def save_checkpoint(model, ep, latest=False, children=None, output_path=None):
    """Save the generator/discriminator weights of ``model`` to disk.

    The checkpoint is a dict ``{"epoch": ep, "netG": ..., "netD": ...}``
    holding the ``state_dict()`` of ``model.netG`` and ``model.netD``. It is
    always written to ``<output_path>/model.pth``. Unless ``latest`` is True,
    it is also copied to ``<output_path>/model/<ep>.pth``.

    Args:
        model: object exposing ``netG`` and ``netD`` with ``state_dict()``.
        ep: epoch value stored in the checkpoint and used as the archive
            file name.
        latest: if True, only ``model.pth`` is (over)written.
        children: unused; kept for backward compatibility.
        output_path: directory to write into (required).

    Raises:
        ValueError: if ``output_path`` is None. Previously this silently
            wrote into a relative directory literally named ``None``.
    """
    if output_path is None:
        raise ValueError("save_checkpoint requires an output_path")
    model_dir = os.path.join(output_path, "model")
    os.makedirs(model_dir, exist_ok=True)
    checkpoint = dict(
        epoch=ep,
        netG=model.netG.state_dict(),
        netD=model.netD.state_dict(),
    )
    latest_path = os.path.join(output_path, "model.pth")
    torch.save(checkpoint, latest_path)
    if not latest:
        # Keep a per-epoch archive copy next to the rolling "latest" file.
        shutil.copy(latest_path, os.path.join(model_dir, "{}.pth".format(ep)))
def filt_ckpt_keys(ckpt, item_name, model_name):
    """Extract one sub-model's entries from a checkpoint and strip their prefix.

    Looks up ``ckpt[item_name]``, a mapping of dotted names to values. Keeps
    only the entries whose first dotted component equals ``model_name``. That
    component is removed, and so is a directly following ``module`` component
    (presumably left by a DataParallel-style wrapper; confirm against the
    saving code).

    NOTE: this module defines ``filt_ckpt_keys`` twice; the later definition
    is the one bound at import time.

    Args:
        ckpt: checkpoint mapping.
        item_name: key in ``ckpt`` that holds the state mapping.
        model_name: leading key component selecting the sub-model.

    Returns:
        OrderedDict with the stripped keys, in the original order.

    Raises:
        AssertionError: if ``item_name`` is not in ``ckpt``.
    """
    if item_name not in ckpt:
        # Raised explicitly rather than via ``assert`` so the check survives
        # ``python -O``; the exception type is unchanged for existing callers.
        raise AssertionError("Cannot find [%s] in the checkpoints." % item_name)
    d_filt = OrderedDict()
    for key, value in ckpt[item_name].items():
        parts = key.split('.')
        # Skip other sub-models, and keys equal to the bare prefix (they have
        # no parameter name after it and would otherwise raise IndexError).
        if parts[0] != model_name or len(parts) < 2:
            continue
        start = 2 if parts[1] == 'module' else 1
        d_filt['.'.join(parts[start:])] = value
    return d_filt
def requires_grad(model, flag=True):
    """Set ``requires_grad = flag`` on every parameter returned by ``model.parameters()``."""
    params = model.parameters()
    for tensor in params:
        setattr(tensor, "requires_grad", flag)
def filt_ckpt_keys(ckpt, item_name, model_name):
    """Extract one sub-model's entries from a checkpoint and strip their prefix.

    Looks up ``ckpt[item_name]``, a mapping of dotted names to values. Keeps
    only the entries whose first dotted component equals ``model_name``. That
    component is removed, and so is a directly following ``module`` component
    (presumably left by a DataParallel-style wrapper; confirm against the
    saving code).

    Args:
        ckpt: checkpoint mapping.
        item_name: key in ``ckpt`` that holds the state mapping.
        model_name: leading key component selecting the sub-model.

    Returns:
        OrderedDict with the stripped keys, in the original order.

    Raises:
        AssertionError: if ``item_name`` is not in ``ckpt``.
    """
    if item_name not in ckpt:
        # Raised explicitly rather than via ``assert`` so the check survives
        # ``python -O``; the exception type is unchanged for existing callers.
        raise AssertionError("Cannot find [%s] in the checkpoints." % item_name)
    d_filt = OrderedDict()
    for key, value in ckpt[item_name].items():
        parts = key.split('.')
        # Skip other sub-models, and keys equal to the bare prefix (they have
        # no parameter name after it and would otherwise raise IndexError).
        if parts[0] != model_name or len(parts) < 2:
            continue
        start = 2 if parts[1] == 'module' else 1
        d_filt['.'.join(parts[start:])] = value
    return d_filt
def get_ray_pano(batch_img):
    """Return per-pixel direction vectors for an equirectangular panorama grid.

    Only the spatial size is read from ``batch_img``. Its ``.size()`` must
    unpack into exactly four dimensions, with the last two taken as (H, W),
    e.g. a torch tensor shaped (N, C, H, W).

    For row i and column j:
        latitude  = (1 - 2*i/H) * pi/2
        longitude = 2*pi * (0.5 - j/W)

    Returns:
        numpy.ndarray of shape (3, H, W) stacking
        (cos(lat)*cos(lon), sin(lat), -cos(lat)*sin(lon)).
    """
    _, _, height, width = batch_img.size()
    # With the default 'xy' indexing: cols[i, j] == j and rows[i, j] == i.
    cols, rows = np.meshgrid(np.arange(width), np.arange(height))
    latitude = (1 - 2 * rows / height) * np.pi / 2
    longitude = 2 * math.pi * (0.5 - cols / width)
    cos_lat = np.cos(latitude)
    return np.stack(
        (cos_lat * np.cos(longitude), np.sin(latitude), -cos_lat * np.sin(longitude)),
        axis=0,
    )
def init_weights(net, init_type='kaiming', init_gain=0.02):
    """Initialize network weights in place.

    Conv*/Linear layers (any class whose name contains "Conv" or "Linear")
    get their weight drawn with the chosen method and their bias zeroed.
    BatchNorm2d layers get their weight drawn from N(1, init_gain) and their
    bias zeroed; layers built with ``affine=False`` (no weight/bias) are
    skipped.

    Parameters:
        net (network)     -- network to be initialized (``net.apply`` is used)
        init_type (str)   -- normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal;
                             ignored by kaiming.

    The pix2pix and CycleGAN papers used 'normal'; xavier and kaiming may work
    better for some applications.

    Raises:
        NotImplementedError: if ``init_type`` is not one of the names above.
            This is now checked up front, even when ``net`` has no Conv/Linear
            layer.
    """
    if init_type not in ('normal', 'xavier', 'kaiming', 'orthogonal'):
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)

    def init_func(m):  # applied to every submodule by net.apply
        classname = m.__class__.__name__
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if weight is not None and ('Conv' in classname or 'Linear' in classname):
            if init_type == 'normal':
                init.normal_(weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(weight.data, a=0, mode='fan_in')
            else:  # 'orthogonal' (validated above)
                init.orthogonal_(weight.data, gain=init_gain)
            if bias is not None:
                init.constant_(bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm weight is a vector, so only a normal draw applies.
            # With affine=False both weight and bias are None and are skipped.
            if weight is not None:
                init.normal_(weight.data, 1.0, init_gain)
            if bias is not None:
                init.constant_(bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)
if __name__ == '__main__':
    # Smoke run: ray directions for a dummy batch of 200x100 panoramas.
    dummy_batch = torch.zeros([2, 3, 200, 100])
    coords = get_ray_pano(dummy_batch)