Spaces:
Running
Running
import os

import numpy as np
import pandas as pd
import torch
from skimage.metrics import mean_squared_error as compare_mse
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim
from torch.autograd import Variable

from model.Embedder import Embedder
from model.OneRestore import OneRestore
def load_embedder_ckpt(device, freeze_model=False, ckpt_name=None,
                       combine_type=['clear', 'low', 'haze', 'rain', 'snow',
                                     'low_haze', 'low_rain', 'low_snow', 'haze_rain',
                                     'haze_snow', 'low_haze_rain', 'low_haze_snow']):
    """Build an Embedder, optionally restoring its weights from a checkpoint.

    Args:
        device: Accepted for interface compatibility but not used; the model
            is placed on "cuda" when CUDA is available, otherwise on "cpu".
        freeze_model: When true, the model is passed to ``freeze`` before
            being returned.
        ckpt_name: Path of a checkpoint that holds the Embedder state dict
            directly. ``None`` returns a freshly initialised model.
        combine_type: Degradation-type names handed to ``Embedder``.

    Returns:
        The Embedder instance.
    """
    cuda_ok = torch.cuda.is_available()
    target = "cuda" if cuda_ok else "cpu"

    if ckpt_name is not None:
        # Without CUDA, tensors stored from a GPU run must be remapped to CPU.
        map_location = None if cuda_ok else torch.device('cpu')
        model_info = torch.load(ckpt_name, map_location=map_location)
        print('==> loading existing Embedder model:', ckpt_name)
        # A copy is handed over so the shared default list cannot be mutated.
        model = Embedder(list(combine_type))
        model.load_state_dict(model_info)
    else:
        print('==> Initialize Embedder model.')
        model = Embedder(list(combine_type))
    model.to(target)

    if freeze_model:
        freeze(model)

    return model
def load_restore_ckpt(device, freeze_model=False, ckpt_name=None):
    """Build a OneRestore model, optionally restoring weights from a checkpoint.

    Args:
        device: Accepted for interface compatibility but not used; the model
            is placed on "cuda" when CUDA is available, otherwise on "cpu".
        freeze_model: When true, the model is passed to ``freeze`` before
            being returned.
        ckpt_name: Path of a checkpoint that holds the OneRestore state dict
            directly. ``None`` returns a freshly initialised model.

    Returns:
        The OneRestore model. When no checkpoint is given it is wrapped in
        ``torch.nn.DataParallel``.
    """
    cuda_ok = torch.cuda.is_available()
    target = "cuda" if cuda_ok else "cpu"

    if ckpt_name is not None:
        # Without CUDA, tensors stored from a GPU run must be remapped to CPU.
        map_location = None if cuda_ok else torch.device('cpu')
        model_info = torch.load(ckpt_name, map_location=map_location)
        print('==> loading existing OneRestore model:', ckpt_name)
        model = OneRestore().to(target)
        model.load_state_dict(model_info)
    else:
        print('==> Initialize OneRestore model.')
        model = OneRestore().to(target)
        # NOTE(review): the DataParallel wrap is applied only to a freshly
        # initialised model, mirroring load_restore_ckpt_with_optim -- confirm
        # this matches the intended behaviour for the checkpoint branch.
        model = torch.nn.DataParallel(model).to(target)

    if freeze_model:
        freeze(model)
    total = sum(param.nelement() for param in model.parameters())
    print("Number of OneRestore parameter: %.2fM" % (total/1e6))

    return model
def load_restore_ckpt_with_optim(device, local_rank=None, freeze_model=False, ckpt_name=None, lr=None):
    """Build a OneRestore model together with its Adam optimizer.

    Args:
        device: Accepted for interface compatibility but not used; the model
            is placed on "cuda" when CUDA is available, otherwise on "cpu".
        local_rank: When not ``None`` the model is wrapped in
            ``DistributedDataParallel`` bound to this rank.
        freeze_model: When true, the model is passed to ``freeze`` before
            being returned.
        ckpt_name: Path of a training checkpoint with the keys
            ``'state_dict'``, ``'optimizer'`` and ``'epoch'``. ``None`` starts
            from a freshly initialised model.
        lr: Learning rate for a fresh optimizer. It is not passed to the
            optimizer when resuming; the optimizer state is loaded from the
            checkpoint instead.

    Returns:
        Tuple ``(model, optimizer, cur_epoch)``; ``cur_epoch`` is the stored
        epoch when resuming and 0 otherwise.
    """
    cuda_ok = torch.cuda.is_available()
    target = "cuda" if cuda_ok else "cpu"
    distributed = local_rank is not None

    if ckpt_name is not None:
        # Without CUDA, tensors stored from a GPU run must be remapped to CPU.
        map_location = None if cuda_ok else torch.device('cpu')
        model_info = torch.load(ckpt_name, map_location=map_location)

        print('==> loading existing OneRestore model:', ckpt_name)
        model = OneRestore().to(target)

        if distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank,
                find_unused_parameters=True)
            model.load_state_dict(model_info['state_dict'])
        else:
            # The bare model has no wrapper, so drop a leading 'module.' from
            # each key. Only the prefix is removed: a blanket replace would
            # also corrupt names that merely contain 'module.'.
            prefix = 'module.'
            weights_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in model_info['state_dict'].items()
            }
            model.load_state_dict(weights_dict)

        optimizer = torch.optim.Adam(model.parameters())
        optimizer.load_state_dict(model_info['optimizer'])
        cur_epoch = model_info['epoch']
    else:
        print('==> Initialize OneRestore model.')
        model = OneRestore().to(target)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        if distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank,
                find_unused_parameters=True)
        else:
            model = torch.nn.DataParallel(model)
        cur_epoch = 0

    if freeze_model:
        freeze(model)
    total = sum(param.nelement() for param in model.parameters())
    print("Number of OneRestore parameter: %.2fM" % (total/1e6))

    return model, optimizer, cur_epoch
def load_embedder_ckpt_with_optim(device, args, combine_type=['clear', 'low', 'haze', 'rain', 'snow',
                                  'low_haze', 'low_rain', 'low_snow', 'haze_rain', 'haze_snow',
                                  'low_haze_rain', 'low_haze_snow']):
    """Seed the RNGs and build an Embedder with its Adam optimizer.

    Args:
        device: Returned unchanged as the last tuple element; the model itself
            is placed on "cuda" when CUDA is available, otherwise on "cpu".
        args: Namespace providing ``seed``, ``pre_weight``, ``check_dir`` and
            ``lr``. A ``seed`` of -1 is replaced in place by a random seed.
        combine_type: Degradation-type names handed to ``Embedder``.

    Returns:
        Tuple ``(embedder, optimizer, cur_epoch, device)``. ``cur_epoch`` is 1
        for a fresh run and ``stored epoch + 1`` when resuming.

    Raises:
        Exception: Whatever error occurred while reading or applying the
            checkpoint is re-raised after the error message is printed.
    """
    print('Init embedder')

    # seed
    if args.seed == -1:
        args.seed = np.random.randint(1, 10000)
    seed = args.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    print('Training embedder seed:', seed)

    # embedder model
    cuda_ok = torch.cuda.is_available()
    # A copy is handed over so the shared default list cannot be mutated.
    embedder = Embedder(list(combine_type)).to("cuda" if cuda_ok else "cpu")

    if args.pre_weight == '':
        optimizer = torch.optim.Adam(embedder.parameters(), lr=args.lr)
        cur_epoch = 1
    else:
        ckpt_path = f'{args.check_dir}/{args.pre_weight}'
        # Load exactly once, remapping GPU tensors when CUDA is unavailable.
        map_location = None if cuda_ok else torch.device('cpu')
        try:
            embedder_info = torch.load(ckpt_path, map_location=map_location)
            embedder.load_state_dict(embedder_info['state_dict'])
            optimizer = torch.optim.Adam(embedder.parameters(), lr=args.lr)
            optimizer.load_state_dict(embedder_info['optimizer'])
            cur_epoch = embedder_info['epoch'] + 1
        except Exception:
            # Swallowing the error would only defer the failure to the return
            # statement (unbound optimizer/cur_epoch), so surface the cause.
            print('Pre-trained model loading error!')
            raise

    return embedder, optimizer, cur_epoch, device
def freeze_text_embedder(m):
    """Put module m in eval mode and freeze its text-embedding parameters.

    Only the parameters named ``embedder.weight``, ``mlp.0.weight`` and
    ``mlp.0.bias`` are affected: each is printed, excluded from gradient
    computation and has its stored gradient cleared. All other parameters
    keep their ``requires_grad`` setting.
    """
    frozen_names = {'embedder.weight', 'mlp.0.weight', 'mlp.0.bias'}
    m.eval()
    for name, para in m.named_parameters():
        if name in frozen_names:
            print(name)
            para.requires_grad = False
            para.grad = None
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the last value, the weighted sum, the count and the average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` observations."""
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. an empty batch) would otherwise divide 0/0.
        self.avg = self.sum / self.count if self.count else 0
def data_process(data, args, device):
    """Split a batch into a positive, one random input and the negatives.

    ``data`` is unpacked as ``(b, n, c, w, h)``. For every sample, slot 0
    along the second axis is the positive, one slot drawn uniformly from
    ``1..n-1`` becomes the input, and the remaining ``n - 2`` slots become the
    negatives in their original order.

    Args:
        data: 5-D tensor of size ``(b, n, c, w, h)``.
        args: Namespace whose ``degr_type`` sequence is indexed by the drawn
            slot to label the input.
        device: Accepted for interface compatibility but not used; outputs go
            to "cuda" when CUDA is available, otherwise to "cpu".

    Returns:
        Tuple ``(pos_data, [inp_data, inp_class], neg_data)`` where
        ``inp_class`` is the list of ``degr_type`` labels of the drawn slots.
    """
    combine_type = args.degr_type
    b, n, c, w, h = data.size()
    target = "cuda" if torch.cuda.is_available() else "cpu"

    pos_data = data[:, 0, :, :, :]
    inp_data = torch.zeros((b, c, w, h))
    neg_data = torch.zeros((b, n - 2, c, w, h))
    inp_class = []

    index = np.random.randint(1, n, b)
    for i in range(b):
        chosen = int(index[i])
        inp_class.append(combine_type[chosen])
        inp_data[i] = data[i, chosen]
        # Every slot that is neither the positive (0) nor the drawn input.
        remaining = (j for j in range(1, n) if j != chosen)
        for slot, j in enumerate(remaining):
            neg_data[i, slot] = data[i, j]

    return pos_data.to(target), [inp_data.to(target), inp_class], neg_data.to(target)
def print_args(argspar):
    """Print every attribute of ``argspar`` as a tab-indented ``name: value`` line."""
    print("\nParameter Print")
    for name, value in vars(argspar).items():
        print('\t{}: {}'.format(name, value))
    print('\n')
def adjust_learning_rate(optimizer, epoch, lr_update_freq):
    """Halve the learning rate of every parameter group on schedule.

    The rate is halved when ``epoch`` is a non-zero multiple of
    ``lr_update_freq``; otherwise the optimizer is left untouched.

    Returns:
        The same optimizer object that was passed in.
    """
    if epoch and epoch % lr_update_freq == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] / 2
    return optimizer
def tensor_metric(img, imclean, model, data_range=1):
    """Average an image-quality metric over a batch.

    Both tensors are moved to the CPU, converted to float32 arrays and
    transposed with ``(0, 2, 3, 1)`` so the second axis becomes the last one
    before each image pair is scored.

    Args:
        img: 4-D tensor holding the images to evaluate.
        imclean: 4-D reference tensor of the same size.
        model: Metric name: ``'PSNR'``, ``'MSE'`` or ``'SSIM'``. Any other
            value prints ``'Model False!'`` for every image and contributes 0.
        data_range: Value range forwarded to the PSNR and SSIM computations.

    Returns:
        The metric summed over the batch divided by the batch size.
    """
    img_cpu = img.data.cpu().numpy().astype(np.float32).transpose(0, 2, 3, 1)
    imgclean = imclean.data.cpu().numpy().astype(np.float32).transpose(0, 2, 3, 1)

    total = 0
    for clean, restored in zip(imgclean, img_cpu):
        if model == 'PSNR':
            total += compare_psnr(clean, restored, data_range=data_range)
        elif model == 'MSE':
            total += compare_mse(clean, restored)
        elif model == 'SSIM':
            # Newer scikit-image releases take `channel_axis`; older ones only
            # know `multichannel`. Try the current keyword first and fall
            # back so both generations of the library work.
            try:
                total += compare_ssim(clean, restored, data_range=data_range,
                                      channel_axis=-1)
            except TypeError:
                total += compare_ssim(clean, restored, data_range=data_range,
                                      multichannel=True)
        else:
            print('Model False!')

    return total/img_cpu.shape[0]
def save_checkpoint(stateF, checkpoint, epoch, psnr_t1, ssim_t1, psnr_t2, ssim_t2, filename='model.tar'):
    """Save ``stateF`` under a file name that encodes the epoch and metrics.

    The target path is ``checkpoint`` concatenated (as a plain string, so it
    should end with a path separator when it names a directory) with
    ``OneRestore_model_<epoch>_<psnr_t1>_<ssim_t1>_<psnr_t2>_<ssim_t2>.tar``,
    the four metrics being written with four decimals.

    Args:
        stateF: Object handed to ``torch.save``.
        checkpoint: Path prefix of the output file.
        epoch: Epoch number embedded in the file name.
        psnr_t1, ssim_t1, psnr_t2, ssim_t2: Metric values embedded in the
            file name.
        filename: Accepted for interface compatibility but not used.
    """
    name = 'OneRestore_model_%d_%.4f_%.4f_%.4f_%.4f.tar' % (
        epoch, psnr_t1, ssim_t1, psnr_t2, ssim_t2)
    torch.save(stateF, checkpoint + name)
def load_excel(x):
    """Write ``x`` as a table to sheet ``PSNR-SSIM`` of ``./mertic_result.xlsx``.

    Despite its name this function writes a workbook; it does not read one.
    ``x`` is converted with ``pd.DataFrame`` and floats are formatted with
    five decimals.
    """
    data1 = pd.DataFrame(x)
    # The output name (including its "mertic" spelling) is kept as-is so
    # anything that already looks for this file keeps working.
    # The context manager closes the workbook even if writing fails.
    with pd.ExcelWriter('./mertic_result.xlsx') as writer:
        data1.to_excel(writer, sheet_name='PSNR-SSIM', float_format='%.5f')
def freeze(m):
    """Freezes module m.

    The module is switched to eval mode, and every parameter is excluded from
    gradient computation and has its stored gradient cleared.
    """
    m.eval()
    for param in m.parameters():
        param.requires_grad = False
        param.grad = None