######### import ###########
import os
from config import Config
opt = Config('training.yml')  # parse training.yml into a config object
import torch
print(torch.cuda.is_available())
gpus = ','.join([str(i) for i in opt.GPU])
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # 表示按照PCI_BUS_ID顺序从0开始排列GPU设备。environ是一个字符串所对应环境的映像对象,environ['HOME']就代表了当前这个用户的主目录
os.environ["CUDA_VISIBLE_DEVICES"] = gpus
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1," #设置当前使用的GPU设备为1,0号两个设备,名称依次为'/gpu:0'、'/gpu:1'。表示优先使用1号设备,然后使用0号设备
torch.backends.cudnn.benchmark = True  # let cuDNN auto-tune convolution algorithms for faster training with fixed input sizes
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import utilss
from torch.utils.data import DataLoader
import random
import time
import numpy as np
from data_RGB import get_training_data, get_validation_data
from Restormer import Restormer
import losses
from warmup_scheduler import GradualWarmupScheduler
from tqdm import tqdm
from pdb import set_trace as stx
import utils
print(torch.cuda.is_available())
######### Set Seeds ###########
random.seed(1234)
np.random.seed(1234)
torch.manual_seed(1234)  # seed the CPU RNG so results are reproducible
torch.cuda.manual_seed_all(1234)  # seed the RNG on every visible GPU
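# Note: seeding alone does not make GPU runs bit-exact; cuDNN can still pick
# non-deterministic kernels. Fully deterministic training would additionally need
# (at some speed cost):
#   torch.backends.cudnn.deterministic = True
#   torch.backends.cudnn.benchmark = False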
start_epoch = 1
mode = opt.MODEL.MODE # deraining
session = opt.MODEL.SESSION # MPRNet
result_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'results', session) # opt.TRAINING.SAVE_DIR= './checkpoints'
model_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'models', session)
utilss.dir_utils.mkdir(result_dir)
utilss.dir_utils.mkdir(model_dir)
train_dir = opt.TRAINING.TRAIN_DIR # TRAIN_DIR: './Datasets/train'
val_dir = opt.TRAINING.VAL_DIR # VAL_DIR: './Datasets/test/Rain5H'
# factor = 8
######### Model ###########
model_restoration = Restormer()
device_ids = [i for i in range(torch.cuda.device_count())]
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model_restoration = model_restoration.to(device)
print(torch.cuda.is_available())
if torch.cuda.device_count() > 1:
print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n")
model_restoration.cuda()  # move the model to the GPU; model.to(device) is preferred since it makes the target device explicit, especially with multiple GPUs
new_lr = opt.OPTIM.LR_INITIAL # LR_INITIAL: 2e-4
optimizer = optim.Adam(model_restoration.parameters(), lr=new_lr, betas=(0.9, 0.999), eps=1e-8)
######### Scheduler ###########
warmup_epochs = 3
scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.OPTIM.NUM_EPOCHS - warmup_epochs,
                                                        eta_min=opt.OPTIM.LR_MIN)  # cosine annealing of the learning rate after warmup
scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs,
                                   after_scheduler=scheduler_cosine)  # linear warmup, then hand over to the cosine schedule
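# Resulting schedule (sketch, assuming multiplier=1 and GradualWarmupScheduler's usual
# behaviour): the learning rate ramps up to LR_INITIAL over the first `warmup_epochs`
# epochs, then cosine-anneals down to LR_MIN over the remaining
# NUM_EPOCHS - warmup_epochs epochs, roughly
#   lr(t) = LR_MIN + 0.5 * (LR_INITIAL - LR_MIN) * (1 + cos(pi * t / (NUM_EPOCHS - warmup_epochs)))
# where t counts epochs after the warmup phase.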
######### Resume ###########
if opt.TRAINING.RESUME: # RESUME: False
path_chk_rest = utilss.get_last_path(model_dir, '_latest.pth')
utilss.load_checkpoint(model_restoration, path_chk_rest)
start_epoch = utilss.load_start_epoch(path_chk_rest) + 1
utilss.load_optim(optimizer, path_chk_rest)
for i in range(1, start_epoch):
scheduler.step()
new_lr = scheduler.get_lr()[0]
print('------------------------------------------------------------------------------')
print("==> Resuming Training with learning rate:", new_lr)
print('------------------------------------------------------------------------------')
if len(device_ids) > 1:  # with more than one GPU, run data-parallel training
    print("Multiple GPUs detected, running data-parallel training")
model_restoration = nn.DataParallel(model_restoration, device_ids=device_ids)
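    # Note: nn.DataParallel prefixes every key in state_dict() with 'module.', so the
    # checkpoints saved below inherit that prefix; loading them into an unwrapped model
    # later requires stripping the prefix (or wrapping the model again) first.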
######### Loss ###########
criterion_char = losses.CharbonnierLoss()
criterion_edge = losses.EdgeLoss()
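# The loss definitions live in losses.py; as a sketch of the usual formulation assumed
# here, CharbonnierLoss is a smooth L1 variant,
#   L_char(x, y) = mean(sqrt((x - y)^2 + eps^2)) with a small eps (e.g. 1e-3),
# and EdgeLoss applies the same Charbonnier penalty to Laplacian-filtered versions of
# prediction and target, which encourages sharp edges in the restored image.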
######### DataLoaders ###########
train_dataset = get_training_data(train_dir, {'patch_size':opt.TRAINING.TRAIN_PS}) # TRAIN_PS: 256
# print("train_dataset.shape: ",train_dataset.shape)
train_loader = DataLoader(dataset=train_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=True, num_workers=8,
drop_last=False, pin_memory=True) # BATCH_SIZE: 16
# print("train_loader.shape: ",train_loader.size)
val_dataset = get_validation_data(val_dir, {'patch_size':opt.TRAINING.VAL_PS}) # VAL_PS: 128
val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False, num_workers=4, drop_last=False,
pin_memory=True)
print('===> Start Epoch {} End Epoch {}'.format(start_epoch, opt.OPTIM.NUM_EPOCHS + 1)) # NUM_EPOCHS: 250
print('===> Loading datasets')
best_psnr = 0
best_epoch = 0
for epoch in range(start_epoch, opt.OPTIM.NUM_EPOCHS + 1):
epoch_start_time = time.time()
epoch_loss = 0
train_id = 1
model_restoration.train()
for i, data in enumerate(tqdm(train_loader), 0):
# zero_grad
for param in model_restoration.parameters():
param.grad = None
target = data[0].cuda()
input_ = data[1].cuda()
# target = data[0].to(device)
# input_ = data[1].to(device)
print("before in model,input_.shape: ",input_.shape)
if hasattr(torch.cuda, 'empty_cache'):
torch.cuda.empty_cache()
restored = model_restoration(input_)
        # Compute the Charbonnier and edge losses on the restored output
loss_char = criterion_char(restored, target)
loss_edge = criterion_edge(restored, target)
loss = (loss_char) + (0.05 * loss_edge)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
#### Evaluation ####
# VAL_AFTER_EVERY: 5
if epoch % opt.TRAINING.VAL_AFTER_EVERY == 0:
model_restoration.eval()
psnr_val_rgb = []
for ii, data_val in enumerate((val_loader), 0):
target = data_val[0].cuda()
input_ = data_val[1].cuda()
# target = data_val[0].to(device)
# input_ = data_val[1].to(device)
restored = 0
with torch.no_grad():
restored = model_restoration(input_)
for res, tar in zip(restored, target):
psnr_val_rgb.append(utilss.torchPSNR(res, tar))
psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()
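        # utilss.torchPSNR is assumed to compute the standard PSNR for images in [0, 1]:
        #   PSNR = 20 * log10(1 / RMSE), RMSE = sqrt(mean((res - tar)^2)),
        # and the per-image values are averaged to give the epoch's validation score.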
if psnr_val_rgb > best_psnr:
best_psnr = psnr_val_rgb
best_epoch = epoch
torch.save({'epoch': epoch,
'state_dict': model_restoration.state_dict(),
'optimizer': optimizer.state_dict()
}, os.path.join(model_dir, "model_best.pth"))
print("[epoch %d PSNR: %.4f --- best_epoch %d Best_PSNR %.4f]" % (epoch, psnr_val_rgb, best_epoch, best_psnr))
torch.save({'epoch': epoch,
'state_dict': model_restoration.state_dict(),
'optimizer': optimizer.state_dict()
}, os.path.join(model_dir, f"model_epoch_{epoch}.pth"))
scheduler.step()
print("------------------------------------------------------------------")
print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.8f}".format(epoch, time.time() - epoch_start_time,
epoch_loss, scheduler.get_lr()[0]))
print("------------------------------------------------------------------")
torch.save({'epoch': epoch,
'state_dict': model_restoration.state_dict(),
'optimizer': optimizer.state_dict()
}, os.path.join(model_dir, "model_latest.pth"))