| | import os.path
|
| | import logging
|
| | import torch
|
| | import argparse
|
| | import json
|
| | import glob
|
| |
|
| | from pprint import pprint
|
| | from fvcore.nn import FlopCountAnalysis
|
| | from utils.model_summary import get_model_activation, get_model_flops
|
| | from utils import utils_logger
|
| | from utils import utils_image as util
|
| |
|
| |
|
def select_model(args, device):
    """Build the model selected by ``args.model_id`` and load its checkpoint.

    Args:
        args: parsed CLI namespace; only ``args.model_id`` is read here.
        device: torch.device the model (and checkpoint) should live on.

    Returns:
        (model, name, data_range, tile) where ``name`` labels the run,
        ``data_range`` is the pixel scale expected by the model, and
        ``tile`` is the tiled-inference patch size (None = whole image).

    Raises:
        NotImplementedError: for an unknown ``model_id``.
    """
    model_id = args.model_id
    if model_id == 0:
        # Baseline EFDN provided with the challenge.
        from models.team00_EFDN import EFDN
        name, data_range = f"{model_id:02}_EFDN_baseline", 1.0
        model_path = os.path.join('model_zoo', 'team00_EFDN.pth')
        model = EFDN()
        # Fix: map_location so loading works on CPU-only hosts instead of
        # failing when the checkpoint was saved from a CUDA device.
        model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
    elif model_id == 23:
        from models.team23_DSCF import DSCF
        name, data_range = f"{model_id:02}_DSCF", 1.0
        model_path = os.path.join('model_zoo', 'team23_DSCF.pth')
        model = DSCF(3, 3, feature_channels=26, upscale=4)
        state_dict = torch.load(model_path, map_location=device)
        # NOTE(review): strict=False silently ignores missing/unexpected keys;
        # confirm this is intentional for team 23's checkpoint.
        model.load_state_dict(state_dict, strict=False)
    else:
        raise NotImplementedError(f"Model {model_id} is not implemented.")

    # Inference only: freeze every parameter and switch to eval mode.
    model.eval()
    tile = None
    for _, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)
    return model, name, data_range, tile
| |
|
| |
|
def select_dataset(data_dir, mode):
    """Collect (LR_path, HR_path) pairs for the requested split.

    LR paths are derived from the HR paths: the ``_HR`` directory component
    becomes ``_LR`` and the file gets an ``x4`` suffix (x4 downscaling).

    Args:
        data_dir: root directory containing ``DIV2K_LSDIR_{mode}_HR``.
        mode: ``"test"`` or ``"valid"``.

    Returns:
        Sorted list of ``(lr_path, hr_path)`` tuples.

    Raises:
        NotImplementedError: for any other ``mode``.
    """
    if mode not in ("test", "valid"):
        raise NotImplementedError(f"{mode} is not implemented in select_dataset")
    # The two original branches differed only in the split name, so the
    # glob pattern is built from `mode` directly.
    hr_pattern = os.path.join(data_dir, f"DIV2K_LSDIR_{mode}_HR/*.png")
    return [
        (p.replace("_HR", "_LR").replace(".png", "x4.png"), p)
        for p in sorted(glob.glob(hr_pattern))
    ]
| |
|
| |
|
def forward(img_lq, model, tile=None, tile_overlap=32, scale=4):
    """Super-resolve ``img_lq`` with ``model``, optionally tile by tile.

    When ``tile`` is None the whole image is processed in one pass.
    Otherwise overlapping patches of size ``tile`` are processed
    independently and averaged in the overlap regions.
    """
    if tile is None:
        # Whole-image inference.
        return model(img_lq)

    b, c, h, w = img_lq.size()
    patch = min(tile, h, w)
    stride = patch - tile_overlap
    sf = scale

    # Patch top-left corners; the appended last index guarantees the
    # bottom/right border is covered even when stride doesn't divide evenly.
    h_starts = list(range(0, h - patch, stride)) + [h - patch]
    w_starts = list(range(0, w - patch, stride)) + [w - patch]

    accum = torch.zeros(b, c, h * sf, w * sf).type_as(img_lq)
    weight = torch.zeros_like(accum)

    for hs in h_starts:
        for ws in w_starts:
            out = model(img_lq[..., hs:hs + patch, ws:ws + patch])
            region = (Ellipsis, slice(hs * sf, (hs + patch) * sf), slice(ws * sf, (ws + patch) * sf))
            accum[region].add_(out)
            weight[region].add_(torch.ones_like(out))

    # Average overlapping contributions.
    return accum.div_(weight)
| |
|
def run(model, model_name, data_range, tile, logger, device, args, mode="test"):
    """Evaluate `model` on the `mode` split: log per-image PSNR (and SSIM when
    requested) and return runtime / PSNR / memory statistics in a dict keyed
    by `{mode}_*`."""
    sf = 4
    border = sf  # pixels excluded at each border for PSNR/SSIM
    results = dict()
    results[f"{mode}_runtime"] = []
    results[f"{mode}_psnr"] = []
    if args.ssim:
        results[f"{mode}_ssim"] = []

    data_path = select_dataset(args.data_dir, mode)
    save_path = os.path.join(args.save_dir, model_name, mode)
    util.mkdir(save_path)

    # CUDA events measure GPU-side wall time for each forward pass.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    for i, (img_lr, img_hr) in enumerate(data_path):

        img_name, ext = os.path.splitext(os.path.basename(img_hr))
        img_lr = util.imread_uint(img_lr, n_channels=3)
        img_lr = util.uint2tensor4(img_lr, data_range)
        img_lr = img_lr.to(device)

        start.record()
        img_sr = forward(img_lr, model, tile)
        end.record()
        # Events are asynchronous: synchronize before reading elapsed time (ms).
        torch.cuda.synchronize()
        results[f"{mode}_runtime"].append(start.elapsed_time(end))
        img_sr = util.tensor2uint(img_sr, data_range)

        img_hr = util.imread_uint(img_hr, n_channels=3)
        img_hr = img_hr.squeeze()
        # Crop HR so its dimensions are exact multiples of the scale factor.
        img_hr = util.modcrop(img_hr, sf)

        psnr = util.calculate_psnr(img_sr, img_hr, border=border)
        results[f"{mode}_psnr"].append(psnr)

        if args.ssim:
            ssim = util.calculate_ssim(img_sr, img_hr, border=border)
            results[f"{mode}_ssim"].append(ssim)
            logger.info("{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.".format(img_name + ext, psnr, ssim))
        else:
            logger.info("{:s} - PSNR: {:.2f} dB".format(img_name + ext, psnr))

    # Peak GPU memory over the whole split, in MiB.
    results[f"{mode}_memory"] = torch.cuda.max_memory_allocated(torch.cuda.current_device()) / 1024 ** 2
    results[f"{mode}_ave_runtime"] = sum(results[f"{mode}_runtime"]) / len(results[f"{mode}_runtime"])
    results[f"{mode}_ave_psnr"] = sum(results[f"{mode}_psnr"]) / len(results[f"{mode}_psnr"])
    if args.ssim:
        results[f"{mode}_ave_ssim"] = sum(results[f"{mode}_ssim"]) / len(results[f"{mode}_ssim"])

    logger.info("{:>16s} : {:<.3f} [M]".format("Max Memory", results[f"{mode}_memory"]))
    logger.info("------> Average runtime of ({}) is : {:.6f} milliseconds".format("test" if mode == "test" else "valid", results[f"{mode}_ave_runtime"]))
    logger.info("------> Average PSNR of ({}) is : {:.6f} dB".format("test" if mode == "test" else "valid", results[f"{mode}_ave_psnr"]))

    return results
| |
|
| |
|
def main(args):
    """Benchmark the selected model on the validation (and optionally test)
    split, then persist results to ``results.json`` and a tab-separated
    summary table in ``results.txt``."""
    utils_logger.logger_info("NTIRE2025-EfficientSR", log_path="NTIRE2025-EfficientSR.log")
    logger = logging.getLogger("NTIRE2025-EfficientSR")

    # Initialize the CUDA context and clear cached memory before measuring.
    torch.cuda.current_device()
    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = False
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # results.json accumulates results across runs of different models.
    json_dir = os.path.join(os.getcwd(), "results.json")
    if not os.path.exists(json_dir):
        results = dict()
    else:
        with open(json_dir, "r") as f:
            results = json.load(f)

    model, model_name, data_range, tile = select_model(args, device)
    logger.info(model_name)

    # Fix: removed a vestigial `if True:` wrapper that guarded nothing.
    valid_results = run(model, model_name, data_range, tile, logger, device, args, mode="valid")
    results[model_name] = valid_results

    if args.include_test:
        test_results = run(model, model_name, data_range, tile, logger, device, args, mode="test")
        results[model_name].update(test_results)

    # Model complexity statistics at a fixed 256x256 RGB input.
    input_dim = (3, 256, 256)
    activations, num_conv = get_model_activation(model, input_dim)
    activations = activations / 10 ** 6  # report in millions
    logger.info("{:>16s} : {:<.4f} [M]".format("#Activations", activations))
    logger.info("{:>16s} : {:<d}".format("#Conv2d", num_conv))

    input_fake = torch.rand(1, 3, 256, 256).to(device)
    flops = FlopCountAnalysis(model, input_fake).total()
    flops = flops / 10 ** 9  # report in GFLOPs
    logger.info("{:>16s} : {:<.4f} [G]".format("FLOPs", flops))

    num_parameters = sum(map(lambda x: x.numel(), model.parameters()))
    num_parameters = num_parameters / 10 ** 6  # report in millions
    logger.info("{:>16s} : {:<.4f} [M]".format("#Params", num_parameters))
    results[model_name].update({"activations": activations, "num_conv": num_conv, "flops": flops, "num_parameters": num_parameters})

    with open(json_dir, "w") as f:
        json.dump(results, f)

    # Render a tab-separated summary table over all models recorded so far.
    if args.include_test:
        fmt = "{:20s}\t{:10s}\t{:10s}\t{:14s}\t{:14s}\t{:14s}\t{:10s}\t{:10s}\t{:8s}\t{:8s}\t{:8s}\n"
        s = fmt.format("Model", "Val PSNR", "Test PSNR", "Val Time [ms]", "Test Time [ms]", "Ave Time [ms]",
                       "Params [M]", "FLOPs [G]", "Acts [M]", "Mem [M]", "Conv")
    else:
        fmt = "{:20s}\t{:10s}\t{:14s}\t{:10s}\t{:10s}\t{:8s}\t{:8s}\t{:8s}\n"
        s = fmt.format("Model", "Val PSNR", "Val Time [ms]", "Params [M]", "FLOPs [G]", "Acts [M]", "Mem [M]", "Conv")
    for k, v in results.items():
        val_psnr = f"{v['valid_ave_psnr']:2.2f}"
        val_time = f"{v['valid_ave_runtime']:3.2f}"
        mem = f"{v['valid_memory']:2.2f}"
        num_param = f"{v['num_parameters']:2.3f}"
        flops = f"{v['flops']:2.2f}"
        acts = f"{v['activations']:2.2f}"
        conv = f"{v['num_conv']:4d}"
        if args.include_test:
            test_psnr = f"{v['test_ave_psnr']:2.2f}"
            test_time = f"{v['test_ave_runtime']:3.2f}"
            ave_time = f"{(v['valid_ave_runtime'] + v['test_ave_runtime']) / 2:3.2f}"
            s += fmt.format(k, val_psnr, test_psnr, val_time, test_time, ave_time, num_param, flops, acts, mem, conv)
        else:
            s += fmt.format(k, val_psnr, val_time, num_param, flops, acts, mem, conv)
    with open(os.path.join(os.getcwd(), 'results.txt'), "w") as f:
        f.write(s)
| |
|
| |
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser("NTIRE2025-EfficientSR")
    # Root directory holding the DIV2K_LSDIR_{valid,test}_{HR,LR} folders.
    parser.add_argument("--data_dir", default="../", type=str)
    # Where per-image SR outputs would be saved (one subdir per model/mode).
    parser.add_argument("--save_dir", default="../results", type=str)
    # Team/model identifier understood by select_model (0 = EFDN baseline).
    parser.add_argument("--model_id", default=0, type=int)
    parser.add_argument("--include_test", action="store_true", help="Inference on the `DIV2K_LSDIR_test` set")
    parser.add_argument("--ssim", action="store_true", help="Calculate SSIM")

    args = parser.parse_args()
    pprint(args)

    main(args)
| |
|