#!/usr/bin/env python
from __future__ import annotations
import argparse
import torch
import gradio as gr
from Scenimefy.options.test_options import TestOptions
from Scenimefy.models import create_model
from Scenimefy.utils.util import tensor2im
from PIL import Image
import torchvision.transforms as transforms
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='cpu')
parser.add_argument('--theme', type=str)
parser.add_argument('--live', action='store_true')
parser.add_argument('--share', action='store_true')
parser.add_argument('--port', type=int)
parser.add_argument('--disable-queue',
dest='enable_queue',
action='store_false')
parser.add_argument('--allow-flagging', type=str, default='never')
parser.add_argument('--allow-screenshot', action='store_true')
return parser.parse_args()
# Heading of the demo page; passed to gr.Interface as ``title`` in main().
TITLE = '''
Scene Stylization with Scenimefy
'''
# Usage instructions shown on the demo page; passed to gr.Interface as
# ``description`` in main(). The 256x256 note mirrors the dummy input size
# used in initialize().
DESCRIPTION = '''
Gradio Demo for Scenimefy.
To use it, simply upload your image, or click one of the examples to load them.
For best outcomes, please pick a natural landscape image similar to the examples below.
Kindly note that our model is trained on 256x256 resolution images, using much higher resolutions might affect its performance.
Read more at the links below.
'''
# Clickable example inputs for gr.Interface. Each inner list holds the value
# for the interface's single input component. The bare filenames are
# presumably resolved relative to the working directory; verify on deploy.
EXAMPLES = [
    [filename]
    for filename in (
        '0.png', '1.jpg', '2.png', '3.png', '4.jpg',
        '5.png', '6.jpg', '7.png', '8.png',
    )
]
ARTICLE = r"""
If Scenimefy is helpful, please help to ⭐ the Github Repo. Thank you!
🤟 **Citation**
If our work is useful for your research, please consider citing:
```bibtex
@inproceedings{jiang2023scenimefy,
title={Scenimefy: Learning to Craft Anime Scene via Semi-Supervised Image-to-Image Translation},
author={Jiang, Yuxin and Jiang, Liming and Yang, Shuai and Loy, Chen Change},
booktitle={ICCV},
year={2023}
}
```
🗞️ **License**
This project is licensed under S-Lab License 1.0.
Redistribution and use for non-commercial purposes should follow this license.
"""
# Shared model instance: None until initialize() (called from main()) assigns
# it; inference() reads it for every request.
model = None
def initialize():
    """Create the Scenimefy model, prepare it, and publish it as ``model``.

    Options come from ``TestOptions().parse()``. A few of them are then pinned
    for single-image, in-process use (no dataset or dataloader is built
    here). The model is created with ``create_model(opt)`` and fed one
    ones-filled 1x3x256x256 dummy pair through ``data_dependent_initialize``.
    After that, ``setup(opt)`` and ``parallelize()`` are called on it.

    Returns:
        The prepared model. It is also bound to the module-level ``model``
        so that ``inference`` can reach it.
    """
    global model
    opt = TestOptions().parse()

    # Option overrides, applied in this order.
    pinned_options = (
        ('num_threads', 0),        # presumably loader worker count; unused here
        ('batch_size', 1),         # one image per forward pass
        ('serial_batches', True),  # no data shuffling
        ('no_flip', True),         # no flip augmentation
        ('display_id', -1),        # no visdom display
    )
    for option_name, option_value in pinned_options:
        setattr(opt, option_name, option_value)

    model = create_model(opt)

    # NOTE(review): presumably data_dependent_initialize needs an input of a
    # real shape to build shape-dependent parts before setup() runs; the
    # ones-filled tensors look like pure placeholders. Confirm in the model.
    placeholder_shape = (1, 3, 256, 256)
    model.data_dependent_initialize({
        'A': torch.ones(placeholder_shape),
        'B': torch.ones(placeholder_shape),
        'A_paths': 'upload.jpg',
    })
    model.setup(opt)  # per upstream comment: load/print networks, create schedulers
    model.parallelize()
    return model
def __make_power_2(img, base, method=Image.BICUBIC):
    """Resize *img* so that its width and height are multiples of *base*.

    Despite the name, this does not round to a power of two. Each side is
    rounded to the nearest multiple of ``base`` using Python's ``round``.
    Each side is kept at least ``base``, so tiny inputs never produce a
    zero-sized target.

    Args:
        img: PIL image; only ``.size`` and ``.resize`` are used.
        base: Required divisor for both sides; must be positive.
        method: Resampling filter forwarded to ``img.resize``.

    Returns:
        ``img`` itself when both sides already match the target, otherwise the
        result of ``img.resize``.
    """
    ow, oh = img.size
    # round() returns 0 for sides shorter than base / 2, and resizing to a
    # zero-sized image fails, so clamp each side to one multiple of base.
    w = max(base, int(round(ow / base) * base))
    h = max(base, int(round(oh / base) * base))
    if (w, h) == (ow, oh):
        return img
    return img.resize((w, h), method)
def get_transform():
    """Build the preprocessing pipeline applied to every uploaded image.

    Steps:
      1. Snap width and height to multiples of 4 via ``__make_power_2``,
         using bicubic resampling.
      2. Convert the PIL image to a tensor with ``transforms.ToTensor``.
      3. Normalize each of the 3 channels with mean 0.5 and std 0.5, which
         maps values in [0, 1] to [-1, 1].
    """
    resample = Image.BICUBIC

    def snap_sides_to_multiple_of_4(img):
        return __make_power_2(img, base=4, method=resample)

    return transforms.Compose([
        transforms.Lambda(snap_sides_to_multiple_of_4),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
def inference(img):
    """Stylize one image with the module-level Scenimefy model.

    Args:
        img: PIL image from the Gradio input component, or ``None`` when the
            input is empty (e.g. submitted without an upload, or cleared
            while ``--live`` is on).

    Returns:
        ``tensor2im`` applied to the model's ``'fake_B'`` visual, or ``None``
        when no image was given (which leaves the output empty).

    Raises:
        RuntimeError: If ``initialize()`` has not populated ``model`` yet.
    """
    # An empty Gradio image input arrives as None. Previously this crashed
    # on None.convert(); return an empty output instead.
    if img is None:
        return None
    if model is None:
        raise RuntimeError('Model is not initialized; call initialize() first.')

    transform = get_transform()
    # convert('RGB') drops alpha/palette so the 3-channel Normalize applies.
    A = transform(img.convert('RGB')).unsqueeze(0)  # (1, 3, H, W), H and W multiples of 4
    upload_data = {
        'A': A,
        # Ones-filled placeholder with A's shape; only 'fake_B' is read below.
        'B': torch.ones_like(A),
        'A_paths': 'upload.jpg',
    }
    model.set_input(upload_data)
    model.test()
    visuals = model.get_current_visuals()
    return tensor2im(visuals['fake_B'])
def main():
    """Parse CLI flags, build the model, and serve the Gradio demo."""
    args = parse_args()
    # NOTE(review): this overwrites whatever --device was passed. Within this
    # file the value only feeds the log line below.
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'*** Now using {args.device}.')

    global model
    # initialize() already assigns the global; rebinding it here is harmless.
    model = initialize()

    input_image = gr.Image(type="pil", label='Input')
    output_image = gr.Image(type="pil", label='Output').style(height=300)
    demo = gr.Interface(
        inference,
        input_image,
        output_image,
        theme=args.theme,
        title=TITLE,
        description=DESCRIPTION,
        article=ARTICLE,
        examples=EXAMPLES,
        allow_screenshot=args.allow_screenshot,
        allow_flagging=args.allow_flagging,
        live=args.live,
    )
    demo.launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )
# Start the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()