Spaces:
Build error
Build error
File size: 4,621 Bytes
ad250d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import glob
import sys
from collections import OrderedDict
import tqdm
from natsort import natsort
import argparse
import models.llflow.option_ as option
from models.llflow import Measure, psnr
from models.llflow import imresize
from models import create_model
import torch
from util import opt_get
import numpy as np
import pandas as pd
import os
import cv2
from rich.console import Console
def fiFindByWildcard(wildcard):
    """Return the paths matching ``wildcard`` in natural sort order.

    The pattern is expanded with ``glob.glob(..., recursive=True)`` and the
    resulting list is ordered with ``natsort.natsorted``.
    """
    matches = glob.glob(wildcard, recursive=True)
    return natsort.natsorted(matches)
def load_model(conf_path):
    """Build the model described by the config at ``conf_path`` and load its weights.

    The config is parsed in non-training mode, its ``gpu_ids`` entry is set
    to ``None``, and the weights referenced by the ``model_path`` option are
    loaded into ``model.netG``.

    Returns:
        A ``(model, opt)`` tuple: the created model and the parsed options.
    """
    parsed = option.parse(conf_path, is_train=False)
    parsed['gpu_ids'] = None
    opt = option.dict_to_nonedict(parsed)

    model = create_model(opt)
    weights_path = opt_get(opt, ['model_path'], None)
    model.load_network(load_path=weights_path, network=model.netG)
    return model, opt
def predict(model, lr):
    """Run ``model`` on the image ``lr`` and return its output visual.

    ``lr`` is converted with ``t`` and fed to the model under the ``"LQ"``
    key. The ``'rlt'`` entry of the model's visuals is returned, falling back
    to the ``'NORMAL'`` entry when ``'rlt'`` is absent.
    """
    batch = {"LQ": t(lr)}
    model.feed_data(batch, need_GT=False)
    model.test()

    outputs = model.get_current_visuals(need_GT=False)
    fallback = outputs.get('NORMAL')
    return outputs.get('rlt', fallback)
def t(array):
    """Convert an image array into a batched float tensor scaled by 1/255.

    The axes are reordered with ``transpose([2, 0, 1])``, a leading axis of
    size one is added, the data is cast to float32, and the resulting tensor
    is divided by 255.
    """
    channels_first = array.transpose([2, 0, 1])
    batched = np.expand_dims(channels_first, axis=0).astype(np.float32)
    return torch.Tensor(batched) / 255
def rgb(t):
    """Convert a tensor into a uint8 image array.

    A 4-D tensor contributes only its first element; any other tensor is used
    as is. The data is detached, moved to the CPU, reordered with
    ``transpose([1, 2, 0])``, clipped to ``[0, 1]``, multiplied by 255 and
    cast to ``np.uint8``.
    """
    if len(t.shape) == 4:
        image = t[0]
    else:
        image = t
    channels_last = image.detach().cpu().numpy().transpose([1, 2, 0])
    scaled = np.clip(channels_last, 0, 1) * 255
    return scaled.astype(np.uint8)
def imread(path):
    """Read the image at ``path`` with OpenCV and reverse its channel order.

    OpenCV decodes colour images in B, G, R order; indexing the last axis with
    ``[2, 1, 0]`` flips that order.

    Raises:
        OSError: if OpenCV could not read the file. ``cv2.imread`` signals
            failure by returning ``None`` instead of raising, which would
            otherwise surface as an opaque ``TypeError`` on the subscript.
    """
    decoded = cv2.imread(path)
    if decoded is None:
        raise OSError(f"Could not read image: {path!r}")
    return decoded[:, :, [2, 1, 0]]
def imwrite(path, img):
    """Write ``img`` to ``path`` with OpenCV, creating the parent directory.

    The last axis of ``img`` is indexed with ``[2, 1, 0]`` (channel order
    reversed) before it is handed to ``cv2.imwrite``.
    """
    parent = os.path.dirname(path)
    # A bare filename has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError; only create the directory when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    cv2.imwrite(path, img[:, :, [2, 1, 0]])
def imCropCenter(img, size):
    """Return a centred crop of ``img`` at most ``size`` pixels on each side.

    Args:
        img: Array whose first two axes are height and width. Any trailing
            axes (e.g. channels) are kept untouched, so both H x W and
            H x W x C arrays are accepted.
        size: Requested side length of the square crop.

    Returns:
        A view of ``img`` cropped around its centre. When ``size`` exceeds a
        dimension, that dimension is returned in full.
    """
    # Only height and width are needed; reading just the first two dims also
    # lets 2-D (single-channel) arrays through.
    h, w = img.shape[:2]
    h_start = max(h // 2 - size // 2, 0)
    h_end = min(h_start + size, h)
    w_start = max(w // 2 - size // 2, 0)
    w_end = min(w_start + size, w)
    return img[h_start:h_end, w_start:w_end]
def impad(img, top=0, bottom=0, left=0, right=0, color=255):
    """Pad the first two axes of a 3-D array using NumPy's ``'reflect'`` mode.

    ``top``/``bottom`` pad the first axis and ``left``/``right`` the second;
    the third axis is never padded. ``color`` is accepted but not used.
    """
    pad_width = [(top, bottom), (left, right), (0, 0)]
    return np.pad(img, pad_width, mode='reflect')
def hiseq_color_cv2_img(img):
    """Histogram-equalize each of the three channels of ``img`` independently.

    The image is split into exactly three channels, ``cv2.equalizeHist`` is
    applied to each one, and the results are merged back in the same order.
    """
    first, second, third = cv2.split(img)
    equalized = tuple(
        cv2.equalizeHist(channel) for channel in (first, second, third)
    )
    return cv2.merge(equalized)
def auto_padding(img, times=16):
    """Reflect-pad ``img`` so its height and width are multiples of ``times``.

    Args:
        img: NumPy image with shape H*W*C.
        times: The multiple that height and width are padded up to.

    Returns:
        A ``(padded, [h1, h2, w1, w2])`` tuple where ``h1``/``h2`` are the
        rows added at the top/bottom and ``w1``/``w2`` the columns added at
        the left/right, so callers can crop the padding off again.
    """
    h, w, _ = img.shape
    # (-x) % times is the distance up to the next multiple and is 0 when x is
    # already aligned; `times - x % times` would add a full extra `times`
    # pixels in that case.
    pad_h = (-h) % times
    pad_w = (-w) % times
    # Split the padding as evenly as possible; the odd pixel goes to the
    # bottom/right side.
    h1, w1 = pad_h // 2, pad_w // 2
    h2, w2 = pad_h - h1, pad_w - w1
    img = cv2.copyMakeBorder(img, h1, h2, w1, w2, cv2.BORDER_REFLECT)
    return img, [h1, h2, w1, w2]
def main(path: str = None):
    """Enhance a single image with the configured network and return the result.

    The model configuration is taken from the ``--opt`` command-line option
    and loaded with ``load_model``. The image is read, padded with
    ``auto_padding``, optionally histogram-equalized / log-transformed
    according to the options, run through ``model.get_sr`` under CUDA
    autocast, and cropped back to the input's shape.

    Requires a CUDA device: both the network and the input tensor are moved
    with ``.cuda()``.

    Args:
        path: Path of the input image. When omitted, the optional positional
            command-line argument is used instead, so the function can be
            called as ``main()`` from the script entry point.

    Returns:
        The enhanced image as a uint8 array of the same shape as the input,
        in the same channel order that ``cv2.imread`` produces (``imread``
        reverses the channels on the way in and they are reversed again here).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--opt", default="./models/llflow/LOL_smallNet.yml")
    parser.add_argument("-n", "--name", default="unpaired")
    # Optional positional so that a bare `main()` call can still be given an
    # image on the command line.
    parser.add_argument("path", nargs="?", default=None,
                        help="input image (used when main() gets no path)")
    args = parser.parse_args()

    if path is None:
        path = args.path
    if path is None:
        parser.error("no input image given: pass a path to main() "
                     "or on the command line")

    Console().log(f"🛠️\tLoading model from {args.opt}")
    conf_path = args.opt
    conf = conf_path.split('/')[-1].replace('.yml', '')
    model, opt = load_model(conf_path)
    model.netG = model.netG.cuda()

    # The output directory is only reported; writing the result to disk is
    # left to the caller, who receives the image as the return value.
    this_dir = os.path.dirname(os.path.realpath(__file__))
    test_dir = os.path.join(this_dir, '..', 'results', conf, args.name)
    print(f"Out dir: {test_dir}")

    lr = imread(path)
    raw_shape = lr.shape
    lr, padding_params = auto_padding(lr)

    his = hiseq_color_cv2_img(lr)
    if opt.get("histeq_as_input", False):
        lr = his
    lr_t = t(lr)
    if opt["datasets"]["train"].get("log_low", False):
        # Clamp before the log so zero-valued pixels do not produce -inf.
        lr_t = torch.log(torch.clamp(lr_t + 1e-3, min=1e-3))
    if opt.get("concat_histeq", False):
        # Append the equalized image along the channel axis.
        lr_t = torch.cat([lr_t, t(his)], dim=1)

    with torch.cuda.amp.autocast():
        sr_t = model.get_sr(lq=lr_t.cuda(), heat=None)

    # Strip the rows/columns that auto_padding added.
    top, bottom, left, right = padding_params
    cropped = torch.clamp(sr_t, 0, 1)[:, :,
                                      top:sr_t.shape[2] - bottom,
                                      left:sr_t.shape[3] - right]
    sr = rgb(cropped)
    assert raw_shape == sr.shape, (
        f"output shape {sr.shape} does not match input shape {raw_shape}"
    )
    return sr[:, :, [2, 1, 0]]
def format_measurements(meas):
    """Render a mapping of measurements as ``"key: value, key: value"``.

    Float values are formatted with two decimal places; every other value is
    rendered with its default string form. Items appear in the mapping's
    iteration order.
    """
    def _render(value):
        if isinstance(value, float):
            return f"{value:0.2f}"
        return value

    return ", ".join(f"{key}: {_render(val)}" for key, val in meas.items())
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|