---
tags:
- super-resolution
- Image-to-Image
---
# Accelerating Image Super-Resolution Networks with Pixel-Level Classification
[![Project Page](https://img.shields.io/badge/Project-Page-green)](https://3587jjh.github.io/PCSR/)
[![arXiv](https://img.shields.io/badge/arXiv-2407.21448-b31b1b)](https://arxiv.org/abs/2407.21448)
Abstract: In recent times, the need for effective super-resolution (SR) techniques has surged, especially for large-scale images ranging from 2K to 8K resolution. For DNN-based single image super-resolution (SISR), decomposing images into overlapping patches is typically necessary due to computational constraints. In such a patch-decomposing scheme, one can allocate computational resources differently based on each patch's difficulty to further improve efficiency while maintaining SR performance. However, this approach has a limitation: computational resources are uniformly allocated within a patch, leading to lower efficiency when the patch contains pixels with varying levels of restoration difficulty. To address the issue, we propose the Pixel-level Classifier for Single Image Super-Resolution (PCSR), a novel method designed to distribute computational resources adaptively at the pixel level. A PCSR model comprises a backbone, a pixel-level classifier, and a set of pixel-level upsamplers with varying capacities. The pixel-level classifier assigns each pixel to an appropriate upsampler based on its restoration difficulty, thereby optimizing computational resource usage. Our method allows the balance between performance and computational cost to be adjusted during inference without re-training. Our experiments demonstrate PCSR's advantage over existing patch-distributing methods in PSNR-FLOP trade-offs across different backbone models and benchmarks.
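To make the limitation concrete, the toy calculation below compares one routing decision per patch with one decision per pixel. Everything in it is an illustrative assumption rather than PCSR's measured behavior: the per-pixel costs (`LIGHT_COST`, `HEAVY_COST`), the hypothetical helper names, and the simplified patch rule that sends a whole patch to the heavy branch if any of its pixels is hard. The shared backbone cost is ignored.

```python
# Toy cost comparison: patch-level vs pixel-level allocation.
# Costs are made-up per-pixel units, not measurements of PCSR.

LIGHT_COST = 1  # hypothetical per-pixel cost of a light upsampler
HEAVY_COST = 5  # hypothetical per-pixel cost of a heavy upsampler


def patch_level_cost(is_hard):
    """One decision per patch (simplified rule): if any pixel is hard,
    every pixel in the patch goes through the heavy upsampler."""
    per_pixel = HEAVY_COST if any(is_hard) else LIGHT_COST
    return per_pixel * len(is_hard)


def pixel_level_cost(is_hard):
    """One decision per pixel: only hard pixels pay the heavy cost."""
    return sum(HEAVY_COST if hard else LIGHT_COST for hard in is_hard)


if __name__ == "__main__":
    # A 10x10 patch (100 pixels) in which 10 pixels are hard, e.g. edges.
    patch = [True] * 10 + [False] * 90
    print("patch-level:", patch_level_cost(patch))
    print("pixel-level:", pixel_level_cost(patch))
```

With 10 % hard pixels, pixel-level routing spends 140 units against 500 for patch-level routing; the gap shrinks as the fraction of hard pixels in the patch grows.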
## Dependencies
- Python 3.7
- PyTorch 1.13 and torchvision
- NVIDIA GPU + CUDA
- Python packages: `pip install numpy opencv-python pandas tqdm fast_pytorch_kmeans`
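Some of these packages are imported under names that differ from their pip names (`opencv-python` is imported as `cv2`, and Pillow, which the demo uses through `PIL`, as `PIL`). The hypothetical helper below is not part of the repository; it is a minimal sketch, assuming the import names listed in `REQUIRED_MODULES`, for checking which dependencies are still missing before running the demo.

```python
# Hypothetical helper (not part of the repository): report which
# required modules cannot be found in the current environment.
import importlib.util

# Import names for the packages above plus those the demo imports.
# Note: opencv-python -> cv2, Pillow -> PIL.
REQUIRED_MODULES = ["torch", "torchvision", "numpy", "cv2", "pandas",
                    "tqdm", "fast_pytorch_kmeans", "PIL"]


def missing_modules(names):
    """Return the top-level module names that cannot be located."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    print("missing:", ", ".join(missing) if missing else "none")
```

For top-level names, `find_spec` only locates a module without importing it, so the check stays fast even for large packages such as `torch`.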
## How to Use
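Clone the repository with `git clone https://huggingface.co/3587jjh/pcsr_carn`, then run the script below from inside the cloned `pcsr_carn` directory, since it imports the repository's `models` and `utils` modules and loads the checkpoint `carn-pcsr-phase1.pth` by a relative path.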
```python
####### demo.py #######
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

import models        # from this repository
from utils import *  # from this repository: make_coord, get_model_flops, tensor2numpy

img_path = 'myimage.png'  # only .png input is supported
scale = 4                 # only x4 upscaling is supported

# Inference options passed to the model:
# k: hyperparameter that traverses the PSNR-FLOPs trade-off; a smaller k gives
#    larger FLOPs and higher PSNR. The useful range is about [-1, 2].
# adaptive_cluster: whether to decide k automatically
# refinement: whether to apply pixel-wise refinement (post-processing that reduces artifacts)
# pixel_batch_size: batch size for processing HR pixels (default 300000)
opacity = 0.65  # weight of the SR image in the colored visualization

resume_path = 'carn-pcsr-phase1.pth'
sv_file = torch.load(resume_path)
model = models.make(sv_file['model'], load_sd=True).cuda()
model.eval()

# per-channel statistics used to normalize the input and de-normalize the output
rgb_mean = torch.tensor([0.4488, 0.4371, 0.4040], device='cuda').view(1,3,1,1)
rgb_std = torch.tensor([1.0, 1.0, 1.0], device='cuda').view(1,3,1,1)

with torch.no_grad():
    # prepare inputs; convert('RGB') drops a possible alpha channel from the .png
    lr = transforms.ToTensor()(Image.open(img_path).convert('RGB')).unsqueeze(0).cuda()  # (1,3,h,w), range=[0,1]
    h, w = lr.shape[-2:]
    H, W = h*scale, w*scale
    coord = make_coord((H,W), flatten=True, device='cuda').unsqueeze(0)  # HR query coordinates
    cell = torch.ones_like(coord)  # size of one HR pixel in the same coordinate space
    cell[:,:,0] *= 2/H
    cell[:,:,1] *= 2/W
    inp_lr = (lr - rgb_mean) / rgb_std

    # pred: per-pixel RGB values; flag: index of the upsampler assigned to each pixel
    pred, flag = model(inp_lr, coord=coord, cell=cell, scale=scale, k=0,
                       pixel_batch_size=300000, adaptive_cluster=True, refinement=True)

    # FLOPs at this setting, relative to a reference run with k=-25 and no
    # adaptive clustering, which is reported as the 100 % baseline
    flops = get_model_flops(model, inp_lr, coord=coord, cell=cell, scale=scale, k=0,
                            pixel_batch_size=300000, adaptive_cluster=True, refinement=True)
    max_flops = get_model_flops(model, inp_lr, coord=coord, cell=cell, scale=scale, k=-25,
                                pixel_batch_size=300000, adaptive_cluster=False, refinement=True)
    print('flops: {:.1f}G ({:.1f} %) | max_flops: {:.1f}G (100 %)'.format(
        flops/1e9, (flops/max_flops)*100, max_flops/1e9))

    # reshape (1, H*W, 3) predictions into an image and undo the normalization
    pred = pred.transpose(1,2).view(-1,3,H,W)
    pred = pred * rgb_std + rgb_mean
    pred = tensor2numpy(pred)  # (H,W,3) array
    Image.fromarray(pred).save('output.png')

    # color each pixel by its assigned upsampler: flag 0 -> green, flag 1 -> red
    flag = flag.reshape(H, W).cpu().numpy()
    vis_img = np.zeros_like(pred)
    vis_img[flag == 0] = np.array([0,255,0])
    vis_img[flag == 1] = np.array([255,0,0])
    vis_img = vis_img*(1 - opacity) + pred*opacity
    Image.fromarray(vis_img.astype('uint8')).save('output_vis.png')
```
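The demo queries the model at HR coordinates (`coord`) together with the size of one HR pixel (`cell`). The sketch below assumes that `make_coord` from `utils` follows the LIIF-style convention of pixel centres in [-1, 1], in (row, col) order and flattened row-major; `make_coord_sketch` and `make_cell_sketch` are hypothetical pure-Python stand-ins written for illustration, not the repository's code.

```python
# Hypothetical pure-Python stand-ins for the `coord` and `cell` tensors in the demo.
# Assumption: make_coord returns pixel centres in [-1, 1], (row, col) order,
# flattened row-major, as in LIIF-style implicit upsamplers.


def make_coord_sketch(shape):
    """Return a flat list of (row, col) pixel-centre coordinates in [-1, 1]."""
    axes = []
    for n in shape:
        half_pixel = 1.0 / n  # one pixel spans 2 / n units, so half of it is 1 / n
        axes.append([-1.0 + half_pixel * (2 * i + 1) for i in range(n)])
    rows, cols = axes
    return [(y, x) for y in rows for x in cols]  # row-major flattening


def make_cell_sketch(shape):
    """Return the (height, width) of one pixel in the same [-1, 1] space."""
    H, W = shape
    return (2.0 / H, 2.0 / W)


if __name__ == "__main__":
    h, w, scale = 2, 3, 4        # a tiny 2x3 LR input upscaled x4
    H, W = h * scale, w * scale  # 8x12 HR grid
    coord = make_coord_sketch((H, W))
    cell = [make_cell_sketch((H, W))] * len(coord)  # same pair for every query
    print(len(coord), coord[0], cell[0])
```

In the demo's tensors, this flat list corresponds to `coord` of shape (1, H*W, 2) after `unsqueeze(0)`, and every row of `cell` holds the same (2/H, 2/W) pair.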
## Citation
```
@misc{jeong2024acceleratingimagesuperresolutionnetworks,
title={Accelerating Image Super-Resolution Networks with Pixel-Level Classification},
author={Jinho Jeong and Jinwoo Kim and Younghyun Jo and Seon Joo Kim},
year={2024},
eprint={2407.21448},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2407.21448},
}
```