import os, time, argparse from PIL import Image import numpy as np import torch from torchvision import transforms from torchvision.utils import save_image as imwrite from utils.utils import print_args, load_restore_ckpt, load_embedder_ckpt transform_resize = transforms.Compose([ transforms.Resize([224,224]), transforms.ToTensor() ]) def main(args): device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #train print('> Model Initialization...') embedder = load_embedder_ckpt(device, freeze_model=True, ckpt_name=args.embedder_model_path) restorer = load_restore_ckpt(device, freeze_model=True, ckpt_name=args.restore_model_path) os.makedirs(args.output,exist_ok=True) files = os.listdir(argspar.input) time_record = [] for i in files: lq = Image.open(f'{argspar.input}/{i}') with torch.no_grad(): lq_re = torch.Tensor((np.array(lq)/255).transpose(2, 0, 1)).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu") lq_em = transform_resize(lq).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu") start_time = time.time() if args.prompt == None: text_embedding, _, [text] = embedder(lq_em,'image_encoder') print(f'This is {text} degradation estimated by visual embedder.') else: text_embedding, _, [text] = embedder([args.prompt],'text_encoder') print(f'This is {text} degradation generated by input text.') out = restorer(lq_re, text_embedding) run_time = time.time()-start_time time_record.append(run_time) if args.concat: out = torch.cat((lq_re, out), dim=3) imwrite(out, f'{args.output}/{i}', range=(0, 1)) print(f'{i} Running Time: {run_time:.4f}.') print(f'Average time is {np.mean(np.array(run_time))}') os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"] = "0" if __name__ == '__main__': parser = argparse.ArgumentParser(description = "OneRestore Running") # load model parser.add_argument("--embedder-model-path", type=str, default = "./ckpts/embedder_model.tar", help = 'embedder model path') parser.add_argument("--restore-model-path", type=str, default = "./ckpts/onerestore_cdd-11.tar", help = 'restore model path') # select model automatic (prompt=False) or manual (prompt=True, text={'clear', 'low', 'haze', 'rain', 'snow',\ # 'low_haze', 'low_rain', 'low_snow', 'haze_rain', 'haze_snow', 'low_haze_rain', 'low_haze_snow'}) parser.add_argument("--prompt", type=str, default = None, help = 'prompt') parser.add_argument("--input", type=str, default = "./image/", help = 'image path') parser.add_argument("--output", type=str, default = "./output/", help = 'output path') parser.add_argument("--concat", action='store_true', help = 'output path') argspar = parser.parse_args() print_args(argspar) main(argspar)