OneRestore / test.py
gy65896's picture
Upload 51 files
73ba284 verified
raw
history blame
3.2 kB
import os, time, argparse
from PIL import Image
import numpy as np
import torch
from torchvision import transforms
from torchvision.utils import save_image as imwrite
from utils.utils import print_args, load_restore_ckpt, load_embedder_ckpt
transform_resize = transforms.Compose([
transforms.Resize([224,224]),
transforms.ToTensor()
])
def main(args):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#train
print('> Model Initialization...')
embedder = load_embedder_ckpt(device, freeze_model=True, ckpt_name=args.embedder_model_path)
restorer = load_restore_ckpt(device, freeze_model=True, ckpt_name=args.restore_model_path)
os.makedirs(args.output,exist_ok=True)
files = os.listdir(argspar.input)
time_record = []
for i in files:
lq = Image.open(f'{argspar.input}/{i}')
with torch.no_grad():
lq_re = torch.Tensor((np.array(lq)/255).transpose(2, 0, 1)).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
lq_em = transform_resize(lq).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
start_time = time.time()
if args.prompt == None:
text_embedding, _, [text] = embedder(lq_em,'image_encoder')
print(f'This is {text} degradation estimated by visual embedder.')
else:
text_embedding, _, [text] = embedder([args.prompt],'text_encoder')
print(f'This is {text} degradation generated by input text.')
out = restorer(lq_re, text_embedding)
run_time = time.time()-start_time
time_record.append(run_time)
if args.concat:
out = torch.cat((lq_re, out), dim=3)
imwrite(out, f'{args.output}/{i}', range=(0, 1))
print(f'{i} Running Time: {run_time:.4f}.')
print(f'Average time is {np.mean(np.array(run_time))}')
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
if __name__ == '__main__':
parser = argparse.ArgumentParser(description = "OneRestore Running")
# load model
parser.add_argument("--embedder-model-path", type=str, default = "./ckpts/embedder_model.tar", help = 'embedder model path')
parser.add_argument("--restore-model-path", type=str, default = "./ckpts/onerestore_cdd-11.tar", help = 'restore model path')
# select model automatic (prompt=False) or manual (prompt=True, text={'clear', 'low', 'haze', 'rain', 'snow',\
# 'low_haze', 'low_rain', 'low_snow', 'haze_rain', 'haze_snow', 'low_haze_rain', 'low_haze_snow'})
parser.add_argument("--prompt", type=str, default = None, help = 'prompt')
parser.add_argument("--input", type=str, default = "./image/", help = 'image path')
parser.add_argument("--output", type=str, default = "./output/", help = 'output path')
parser.add_argument("--concat", action='store_true', help = 'output path')
argspar = parser.parse_args()
print_args(argspar)
main(argspar)