TalhaUsuf
llflow(fix): works from llflow directory
a152a04
import os
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'
import glob
import sys
from collections import OrderedDict
import tqdm
from natsort import natsort
import argparse
import option_ as option
# import option_ as option
from Measure import Measure, psnr
from imresize import imresize
from models import create_model
import torch
from util import opt_get
import numpy as np
import pandas as pd
import os
import cv2
from rich.console import Console
def fiFindByWildcard(wildcard):
return natsort.natsorted(glob.glob(wildcard, recursive=True))
def load_model(conf_path):
opt = option.parse(conf_path, is_train=False)
opt['gpu_ids'] = None
opt = option.dict_to_nonedict(opt)
model = create_model(opt)
model_path = opt_get(opt, ['model_path'], None)
model.load_network(load_path=model_path, network=model.netG)
return model, opt
def predict(model, lr):
model.feed_data({"LQ": t(lr)}, need_GT=False)
model.test()
visuals = model.get_current_visuals(need_GT=False)
return visuals.get('rlt', visuals.get('NORMAL'))
def t(array): return torch.Tensor(np.expand_dims(array.transpose([2, 0, 1]), axis=0).astype(np.float32)) / 255
def rgb(t): return (
np.clip((t[0] if len(t.shape) == 4 else t).detach().cpu().numpy().transpose([1, 2, 0]), 0, 1) * 255).astype(
np.uint8)
def imread(path):
return cv2.imread(path)[:, :, [2, 1, 0]]
def imwrite(path, img):
os.makedirs(os.path.dirname(path), exist_ok=True)
cv2.imwrite(path, img[:, :, [2, 1, 0]])
def imCropCenter(img, size):
h, w, c = img.shape
h_start = max(h // 2 - size // 2, 0)
h_end = min(h_start + size, h)
w_start = max(w // 2 - size // 2, 0)
w_end = min(w_start + size, w)
return img[h_start:h_end, w_start:w_end]
def impad(img, top=0, bottom=0, left=0, right=0, color=255):
return np.pad(img, [(top, bottom), (left, right), (0, 0)], 'reflect')
def hiseq_color_cv2_img(img):
(b, g, r) = cv2.split(img)
bH = cv2.equalizeHist(b)
gH = cv2.equalizeHist(g)
rH = cv2.equalizeHist(r)
result = cv2.merge((bH, gH, rH))
return result
def auto_padding(img, times=16):
# img: numpy image with shape H*W*C
h, w, _ = img.shape
h1, w1 = (times - h % times) // 2, (times - w % times) // 2
h2, w2 = (times - h % times) - h1, (times - w % times) - w1
img = cv2.copyMakeBorder(img, h1, h2, w1, w2, cv2.BORDER_REFLECT)
return img, [h1, h2, w1, w2]
def main(path:str):
parser = argparse.ArgumentParser()
parser.add_argument("--opt", default="./confs/LOL_smallNet.yml")
# parser.add_argument("--opt", default="E:\startup\image_enhancement\models\llflow\LOL_smallNet.yml")
parser.add_argument("-n", "--name", default="unpaired")
# Namespace(opt="./models/llflow/LOL_smallNet.yml", name="unpaired")
# args = parser.parse_args()
args = parser.parse_args()
Console().log(f"🛠️\tLoading model from {args.opt}")
conf_path = args.opt
conf = conf_path.split('/')[-1].replace('.yml', '')
model, opt = load_model(conf_path)
model.netG = model.netG.cuda()
lr_dir = opt['dataroot_unpaired']
# lr_paths = fiFindByWildcard(os.path.join(lr_dir, '*.*'))
lr_paths = path
this_dir = os.path.dirname(os.path.realpath(__file__))
test_dir = os.path.join(this_dir, '..', 'results', conf, args.name)
print(f"Out dir: {test_dir}")
# for lr_path, idx_test in tqdm.tqdm(zip(lr_paths, range(len(lr_paths))), colour='green'):
lr_path = lr_paths
lr = imread(lr_path)
raw_shape = lr.shape
lr, padding_params = auto_padding(lr)
his = hiseq_color_cv2_img(lr)
if opt.get("histeq_as_input", False):
lr = his
lr_t = t(lr)
if opt["datasets"]["train"].get("log_low", False):
lr_t = torch.log(torch.clamp(lr_t + 1e-3, min=1e-3))
if opt.get("concat_histeq", False):
his = t(his)
lr_t = torch.cat([lr_t, his], dim=1)
heat = opt['heat']
with torch.cuda.amp.autocast():
sr_t = model.get_sr(lq=lr_t.cuda(), heat=None)
sr = rgb(torch.clamp(sr_t, 0, 1)[:, :, padding_params[0]:sr_t.shape[2] - padding_params[1],
padding_params[2]:sr_t.shape[3] - padding_params[3]])
assert raw_shape == sr.shape
path_out_sr = os.path.join(test_dir, os.path.basename(lr_path))
# imwrite(path_out_sr, sr)
cv2.imwrite(path_out_sr, sr[:, :, [2, 1, 0]])
Console().print(f"image saved 💾 to {path_out_sr}")
return sr[:, :, [2, 1, 0]]
def format_measurements(meas):
s_out = []
for k, v in meas.items():
v = f"{v:0.2f}" if isinstance(v, float) else v
s_out.append(f"{k}: {v}")
str_out = ", ".join(s_out)
return str_out
if __name__ == "__main__":
main(r"E:\startup\image_enhancement\n000129\0026_01.jpg")