# lino/src/models/utils/model_utils.py
import os
import datetime

import h5py
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.nn.parallel import DistributedDataParallel as DDP

def pca_show(feats):
    """Visualize a [B, C, H, W] feature map by projecting channels onto 3 PCA components."""
    from sklearn.decomposition import PCA

    feats = feats.cpu().detach()
    B, C, H, W = feats.shape
    n_components = 3
    # Resize to a fixed resolution so the PCA image has a consistent size.
    feats = F.interpolate(feats, size=(128, 128), mode='bilinear', align_corners=False)
    # Flatten the first sample to (H*W, C) so each pixel is one observation.
    features = feats[0].reshape(C, -1).permute(1, 0).numpy()
    pca = PCA(n_components=n_components)
    pca_features = pca.fit_transform(features)
    # Normalize to [0, 255] and reshape into an RGB image.
    pca_features = (pca_features - pca_features.min()) / (pca_features.max() - pca_features.min())
    pca_features = (pca_features * 255).reshape(128, 128, n_components).astype(np.uint8)
    plt.imshow(pca_features)
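
# Usage sketch (hypothetical, not part of the repo): pca_show expects a 4D
# feature tensor; random features illustrate the expected shape and call.
def _demo_pca_show():
    feats = torch.randn(1, 64, 32, 32)  # [B, C, H, W] dummy feature map
    pca_show(feats)
    plt.show()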

def loadmodel(model, filename, strict=True, remove_prefix=False):
    """Load a state dict from `filename`, optionally stripping the DataParallel 'module.' prefix."""
    if os.path.exists(filename):
        params = torch.load(filename)
        if remove_prefix:
            new_params = {k.replace('module.', ''): v for k, v in params.items()}
            model.load_state_dict(new_params, strict=strict)
            print('Prefix successfully removed')
        else:
            model.load_state_dict(params, strict=strict)
        print('Load %s' % filename)
    else:
        print('Model Not Found')
    return model

def loadoptimizer(optimizer, filename):
    if os.path.exists(filename):
        params = torch.load(filename)
        optimizer.load_state_dict(params)
        print('Load %s' % filename)
    else:
        print('Optimizer Not Found')
    return optimizer

def loadscheduler(scheduler, filename):
    if os.path.exists(filename):
        params = torch.load(filename)
        scheduler.load_state_dict(params)
        print('Load %s' % filename)
    else:
        print('Scheduler Not Found')
    return scheduler

def savemodel(model, filename):
    print('Save %s' % filename)
    torch.save(model.state_dict(), filename)

def saveoptimizer(optimizer, filename):
    print('Save %s' % filename)
    torch.save(optimizer.state_dict(), filename)

def savescheduler(scheduler, filename):
    print('Save %s' % filename)
    torch.save(scheduler.state_dict(), filename)
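
# Usage sketch (hypothetical): a save/load round trip with the helpers above.
# The model, optimizer, and /tmp paths are placeholders, not part of the repo.
def _demo_checkpoint_roundtrip():
    net = torch.nn.Linear(4, 2)
    opt = torch.optim.Adam(net.parameters(), lr=1e-4)
    savemodel(net, '/tmp/net.pth')
    saveoptimizer(opt, '/tmp/opt.pth')
    net = loadmodel(net, '/tmp/net.pth', strict=True)
    opt = loadoptimizer(opt, '/tmp/opt.pth')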

def optimizer_setup_Adam(net, lr=0.0001, init=True, step_size=3, stype='step'):
    print(f'Optimizer (Adam) lr={lr}')
    if init:
        net.init_weights()
    net = torch.nn.DataParallel(net)
    optim_params = [{'params': net.parameters(), 'lr': lr}]
    optimizer = torch.optim.Adam(optim_params, betas=(0.9, 0.999), weight_decay=0)
    if stype == 'cos':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 30, eta_min=0, last_epoch=-1)
        print('Cosine annealing learning rate scheduler')
    elif stype == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.8)
        print(f'Step LR scheduler, x0.8 decay every {step_size} epochs')
    else:
        raise ValueError(f'Unknown scheduler type: {stype}')
    return net, optimizer, scheduler

def optimizer_setup_SGD(net, lr=0.01, momentum=0.9, init=True):
    print(f'Optimizer (SGD with momentum) lr={lr}')
    if init:
        net.init_weights()
    net = torch.nn.DataParallel(net)
    optim_params = [{'params': net.parameters(), 'lr': lr}]
    return net, torch.optim.SGD(optim_params, momentum=momentum, weight_decay=1e-4, nesterov=True)

def optimizer_setup_AdamW(net, lr=0.001, eps=1.0e-8, step_size=20, init=True, stype='step', use_data_parallel=False):
    print(f'Optimizer (AdamW) lr={lr}')
    if init:
        net.init_weights()
    if use_data_parallel:
        net = torch.nn.DataParallel(net)
    optim_params = [{'params': net.parameters(), 'lr': lr}]
    optimizer = torch.optim.AdamW(optim_params, betas=(0.9, 0.999), eps=eps, weight_decay=0.01)
    if stype == 'cos':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 30, eta_min=0, last_epoch=-1)
        print('Cosine annealing learning rate scheduler')
    elif stype == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.8)
        print(f'Step LR scheduler, x0.8 decay every {step_size} epochs')
    else:
        raise ValueError(f'Unknown scheduler type: {stype}')
    return net, optimizer, scheduler
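
# Usage sketch (hypothetical): setting up AdamW with a step scheduler.
# init=False because a plain nn.Module has no init_weights() method.
def _demo_optimizer_setup():
    net = torch.nn.Linear(4, 2)
    net, optimizer, scheduler = optimizer_setup_AdamW(net, lr=1e-3, init=False, stype='step')
    for epoch in range(3):
        optimizer.step()   # one (dummy) update per epoch
        scheduler.step()   # decay the learning rate on schedule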

def mode_change(net, Training):
    """Toggle requires_grad and train/eval mode together."""
    if Training:
        for param in net.parameters():
            param.requires_grad = True
        net.train()
    else:
        for param in net.parameters():
            param.requires_grad = False
        net.eval()

def get_n_params(model):
    """Count the total number of parameters in a model."""
    total = 0
    for p in model.parameters():
        n = 1
        for s in p.size():
            n *= s
        total += n
    return total

def loadCheckpoint(path, model, cuda=True):
    if cuda:
        checkpoint = torch.load(path)
    else:
        # Remap GPU tensors onto the CPU when CUDA is unavailable.
        checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])

def saveCheckpoint(save_path, epoch=-1, model=None, optimizer=None, records=None, args=None):
    state = {'state_dict': model.state_dict(), 'model': args.model}
    records = {'epoch': epoch, 'optimizer': optimizer.state_dict(),
               'records': records, 'args': args}
    torch.save(state, os.path.join(save_path, 'checkp_%d.pth.tar' % epoch))
    torch.save(records, os.path.join(save_path, 'checkp_%d_rec.pth.tar' % epoch))

def masking(img, mask):
    # img  [B, C, H, W]
    # mask [B, 1, H, W], values in [0, 1]
    img_masked = img * mask.expand((-1, img.shape[1], -1, -1))
    return img_masked
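
# Usage sketch (hypothetical): zero out pixels outside a binary mask.
def _demo_masking():
    img = torch.randn(2, 3, 8, 8)
    mask = (torch.rand(2, 1, 8, 8) > 0.5).float()
    masked = masking(img, mask)  # background pixels become 0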

def print_model_parameters(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum(np.prod(p.size()) for p in model_parameters)
    print('# parameters: %d' % params)

def angular_error(x1, x2, mask=None):  # x1, x2: unit-normal tensors [B, 3, H, W]
    # Clamp the dot product into (-1, 1) so acos stays finite; torch.clamp also
    # avoids the device mismatch that comparing against a CPU tensor would cause.
    if mask is not None:
        dot = torch.sum(x1 * x2 * mask, dim=1, keepdim=True)
        dot = torch.clamp(dot, -1.0 + 1.0e-12, 1.0 - 1.0e-12)
        emap = torch.abs(180 * torch.acos(dot) / np.pi) * mask
        mae = torch.sum(emap) / torch.sum(mask)
        return mae, emap
    dot = torch.sum(x1 * x2, dim=1, keepdim=True)
    dot = torch.clamp(dot, -1.0 + 1.0e-12, 1.0 - 1.0e-12)
    error = torch.abs(180 * torch.acos(dot) / np.pi)
    return error
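
# Usage sketch (hypothetical): mean angular error between two normal maps.
def _demo_angular_error():
    n1 = F.normalize(torch.randn(1, 3, 16, 16), dim=1)  # unit normals
    n2 = F.normalize(torch.randn(1, 3, 16, 16), dim=1)
    mask = torch.ones(1, 1, 16, 16)
    mae, emap = angular_error(n1, n2, mask)
    print('MAE (degrees): %.2f' % mae.item())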

def write_errors(filepath, error, trainid, numimg, objname=''):
    """Append a timestamped error record to `filepath`."""
    dt_now = datetime.datetime.now()
    print(filepath)
    if len(objname) > 0:
        with open(filepath, 'a') as f:
            f.write('%s %03d %s %02d %.2f\n' % (dt_now, numimg, objname, trainid, error))
    else:
        with open(filepath, 'a') as f:
            f.write('%s %03d %02d %.2f\n' % (dt_now, numimg, trainid, error))

def save_nparray_as_hdf5(a, filename):
    # Stray `self` removed: this is a module-level function, not a method.
    with h5py.File(filename, 'w') as h5f:
        h5f.create_dataset('dataset_1', data=a)

def freeze_params(net):
    for param in net.parameters():
        param.requires_grad = False

def unfreeze_params(net):
    for param in net.parameters():
        param.requires_grad = True

def make_index_list(maxNumImages, numImageList):
    """Build a flat 0/1 mask marking which of `maxNumImages` slots each entry fills."""
    index = np.zeros(len(numImageList) * maxNumImages, np.int32)
    for k, num in enumerate(numImageList):
        index[maxNumImages * k : maxNumImages * k + num] = 1
    return index
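
# Usage sketch (hypothetical): with 4 slots per object, entries holding 2 and
# 3 images yield the mask [1 1 0 0 | 1 1 1 0].
def _demo_make_index_list():
    print(make_index_list(4, [2, 3]))  # -> [1 1 0 0 1 1 1 0]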