Spaces:
Paused
Paused
| import argparse | |
| import subprocess | |
| from tqdm import tqdm | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import DataLoader | |
| import os | |
| import torch.nn as nn | |
| # from utils.dataset_utils import DenoiseTestDataset, DerainDehazeDataset | |
| # from utils.val_utils import AverageMeter, compute_psnr_ssim | |
| # from utils.image_io import save_image_tensor | |
| from PIL import Image | |
| from torchvision.transforms import ToTensor | |
| import lightning.pytorch as pl | |
| import torch.nn.functional as F | |
| from net.prompt_xrestormer import PromptXRestormer | |
| import json | |
# crop an image to the multiple of base
def crop_img(image, base=64):
    """Center-crop an array so its first two dimensions are multiples of ``base``.

    The remainder ``dim % base`` is removed from each of the first two axes.
    Half of it (rounded down) is taken from the start of the axis and the rest
    from the end. Only the first two axes are sliced, so both ``H x W`` and
    ``H x W x C`` arrays are accepted and any trailing axes are kept whole.

    :param image: array exposing ``.shape`` and supporting 2-D slicing.
    :param base: the multiple each of the first two dimensions is cropped to.
    :return: the cropped slice of ``image``.
    """
    h = image.shape[0]
    w = image.shape[1]
    crop_h = h % base
    crop_w = w % base
    top = crop_h // 2
    left = crop_w // 2
    # Index only the first two axes: identical to the former ``[..., :]`` form
    # for 3-D input, and it no longer fails on 2-D (single-channel) arrays.
    return image[top:top + h - crop_h, left:left + w - crop_w]
class PromptXRestormerIRModel(pl.LightningModule):
    """Lightning module that wraps a ``PromptXRestormer`` network.

    The network is built with a fixed configuration (3 input and 3 output
    channels, prompts enabled, scale 1) and stored as ``self.net``; an L1
    loss is stored as ``self.loss_fn``.
    """

    def __init__(self):
        super().__init__()
        # Fixed architecture settings handed to PromptXRestormer as keywords.
        net_config = dict(
            inp_channels=3,
            out_channels=3,
            dim=48,
            num_blocks=[2, 4, 4, 4],
            num_refinement_blocks=4,
            channel_heads=[1, 1, 1, 1],
            spatial_heads=[1, 2, 4, 8],
            overlap_ratio=[0.5, 0.5, 0.5, 0.5],
            ffn_expansion_factor=2.66,
            bias=False,
            # Other option: 'BiasFree'
            LayerNorm_type='WithBias',
            # True for dual-pixel defocus deblurring only (then also set
            # inp_channels=6).
            dual_pixel_task=False,
            scale=1,
            prompt=True,
        )
        self.net = PromptXRestormer(**net_config)
        self.loss_fn = nn.L1Loss()

    def forward(self, x):
        """Pass ``x`` through the wrapped network and return its output."""
        return self.net(x)
def np_to_pil(img_np):
    """
    Build a PIL image from a channel-first numpy array.

    Values are multiplied by 255, clipped to [0, 255] and cast to ``uint8``.
    A single-channel input is passed to PIL as a 2-D array; a three-channel
    input is transposed from channel-first to channel-last first.

    :param img_np: array whose first axis has length 1 or 3.
    :return: the resulting ``PIL.Image.Image``.
    """
    pixels = np.clip(img_np * 255, 0, 255).astype(np.uint8)
    channels = img_np.shape[0]
    if channels == 1:
        # Drop the singleton channel axis for a grayscale image.
        return Image.fromarray(pixels[0])
    assert channels == 3, img_np.shape
    return Image.fromarray(pixels.transpose(1, 2, 0))
def torch_to_np(img_var):
    """
    Return the first batch element of a tensor as a numpy array.

    The tensor is detached from the autograd graph and moved to the CPU
    before conversion, then index 0 along the leading axis is taken.

    :param img_var: tensor with a leading batch axis.
    :return: numpy array holding the first batch element.
    """
    host_tensor = img_var.detach().cpu()
    batch_array = host_tensor.numpy()
    return batch_array[0]
def save_image_tensor(image_tensor, output_path="output/"):
    """
    Convert a batched image tensor to a PIL image and save it.

    The tensor goes through ``torch_to_np`` and ``np_to_pil``, and the
    resulting image is written to ``output_path``.

    NOTE(review): the default ``output_path`` has no file extension; PIL
    normally infers the format from the extension, so confirm the default
    is actually usable.

    :param image_tensor: tensor accepted by ``torch_to_np``.
    :param output_path: destination passed to ``PIL.Image.Image.save``.
    """
    np_to_pil(torch_to_np(image_tensor)).save(output_path)
if __name__ == '__main__':
    # Single-image inference: load a checkpoint, restore one degraded image
    # on GPU 0 and write the result to "output.png".
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.set_device(0)

    ckpt_path = "/home/jiachen/MyGradio/ckpt/promptxrestormer_epoch=64-step=578630.ckpt"
    print("CKPT name : {}".format(ckpt_path))
    # load_from_checkpoint is a classmethod: call it on the class. Calling it
    # on an instance builds a throwaway network first, and Lightning 2.x
    # rejects instance calls outright.
    net = PromptXRestormerIRModel.load_from_checkpoint(ckpt_path).cuda()
    net.eval()

    degraded_path = "/home/jiachen/MyGradio/test_images/noisy_myimage.jpg"
    # Close the image file as soon as the pixels have been copied out.
    with Image.open(degraded_path) as source_image:
        rgb_pixels = np.array(source_image.convert('RGB'))
    degraded_img = crop_img(rgb_pixels, base=16)
    toTensor = ToTensor()
    degraded_img = toTensor(degraded_img)
    print(degraded_img.shape)

    with torch.no_grad():
        degraded_img = degraded_img.unsqueeze(0).cuda()
        _, _, H_old, W_old = degraded_img.shape
        # Pad up to the next multiple of 64 strictly above the current size
        # (a full extra 64 when already aligned). Presumably 64 is what the
        # network expects -- confirm against the model definition.
        h_pad = (H_old // 64 + 1) * 64 - H_old
        w_pad = (W_old // 64 + 1) * 64 - W_old
        # Extend the bottom and right edges by appending a flipped copy and
        # trimming to the target size. NOTE: this can at most double a side,
        # so a side shorter than its padding ends up below the target.
        degraded_img = torch.cat([degraded_img, torch.flip(degraded_img, [2])], 2)[:, :, :H_old + h_pad, :]
        degraded_img = torch.cat([degraded_img, torch.flip(degraded_img, [3])], 3)[:, :, :, :W_old + w_pad]
        print("inputImage size", degraded_img.shape)
        restored = net(degraded_img)
        # Drop the padded rows/columns again.
        restored = restored[:, :, :H_old, :W_old]

    save_image_tensor(restored, "output.png")