File size: 5,502 Bytes
5bfe353
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7476935
5bfe353
693c90b
5bfe353
7476935
5bfe353
 
 
7476935
5bfe353
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#!/usr/bin/env python

from __future__ import annotations

import argparse
import torch
import gradio as gr

from Scenimefy.options.test_options import TestOptions
from Scenimefy.models import create_model
from Scenimefy.utils.util import tensor2im

from PIL import Image
import torchvision.transforms as transforms


def parse_args() -> argparse.Namespace:
    """Return CLI options controlling device selection and Gradio UI behavior."""
    p = argparse.ArgumentParser()
    p.add_argument('--device', type=str, default='cpu')
    p.add_argument('--theme', type=str)
    p.add_argument('--live', action='store_true')
    p.add_argument('--share', action='store_true')
    p.add_argument('--port', type=int)
    # '--disable-queue' stores False into args.enable_queue (default True).
    p.add_argument('--disable-queue', dest='enable_queue', action='store_false')
    p.add_argument('--allow-flagging', type=str, default='never')
    p.add_argument('--allow-screenshot', action='store_true')
    return p.parse_args()

# UI copy for the Gradio interface. TITLE/DESCRIPTION/ARTICLE are rendered as
# HTML/Markdown by Gradio; EXAMPLES lists sample images expected to sit next
# to this script.
TITLE = '''
        Scene Stylization with <a href="https://github.com/Yuxinn-J/Scenimefy">Scenimefy</a>
        '''
DESCRIPTION = '''
<div align=center>
<p> 
Gradio Demo for Scenimefy - a model transforming real-life photos into Shinkai-animation-style images. 
To use it, simply upload your image, or click one of the examples to load them.  
For best outcomes, please pick a natural landscape image similar to the examples below. 
Kindly note that our model is trained on 256x256 resolution images, using much higher resolutions might affect its performance. 
Read more in our <a href="https://arxiv.org/abs/2308.12968">paper</a>. 
</p>
</div>
'''
# Example image filenames shown as clickable presets under the input widget.
EXAMPLES = [['0.jpg'], ['1.png'], ['2.jpg'], ['3.png'], ['4.png'], ['5.png'], ['6.jpg'], ['7.png'], ['8.png']]
ARTICLE = r"""
If Scenimefy is helpful, please help to ⭐ the <a href='https://github.com/Yuxinn-J/Scenimefy' target='_blank'>Github Repo</a>. Thank you! 
🤟 **Citation**
If our work is useful for your research, please consider citing:
```bibtex
@inproceedings{jiang2023scenimefy,
  title={Scenimefy: Learning to Craft Anime Scene via Semi-Supervised Image-to-Image Translation},
  author={Jiang, Yuxin and Jiang, Liming and Yang, Shuai and Loy, Chen Change},
  booktitle={ICCV},
  year={2023}
}
```
🗞️ **License**
This project is licensed under <a rel="license" href="https://github.com/Yuxinn-J/Scenimefy/blob/main/LICENSE.md">S-Lab License 1.0</a>. 
Redistribution and use for non-commercial purposes should follow this license.
"""


# Module-level model handle; populated by initialize() and read by inference().
model = None


def initialize():
    """Build the Scenimefy model, run its one-time setup, and publish it.

    Stores the model in the module-level ``model`` global and also returns it.
    """
    opt = TestOptions().parse()

    # Hard-coded test-time settings, mirroring the repo's test script:
    # single-threaded loading, batch of one, deterministic ordering, no
    # augmentation flips, and visdom display disabled.
    opt.num_threads = 0
    opt.batch_size = 1
    opt.serial_batches = True
    opt.no_flip = True
    opt.display_id = -1

    global model
    model = create_model(opt)

    # Throwaway 256x256 batch so the model can lazily build its networks
    # before the real weights are loaded in setup().
    placeholder = {
        'A': torch.ones(1, 3, 256, 256),
        'B': torch.ones(1, 3, 256, 256),
        'A_paths': 'upload.jpg'
    }
    model.data_dependent_initialize(placeholder)
    model.setup(opt)  # load networks, create schedulers
    model.parallelize()
    return model


def __make_power_2(img, base, method=Image.BICUBIC):
    """Resize *img* so both sides are multiples of *base*.

    Returns the image unchanged when it already satisfies the constraint.
    Each side is rounded to the nearest multiple of *base*, clamped to at
    least one multiple so a side shorter than base/2 can never round down
    to zero and produce an invalid 0-sized resize target.
    """
    ow, oh = img.size
    h = max(base, int(round(oh / base) * base))
    w = max(base, int(round(ow / base) * base))
    if h == oh and w == ow:
        return img

    return img.resize((w, h), method)


def get_transform():
    """Build the preprocessing pipeline applied to every uploaded image.

    Steps: snap both sides to a multiple of 4 (bicubic), convert to a
    CHW float tensor, then normalize each channel from [0, 1] to [-1, 1].
    """
    snap_to_multiple = transforms.Lambda(
        lambda img: __make_power_2(img, base=4, method=Image.BICUBIC))
    return transforms.Compose([
        snap_to_multiple,
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])


def inference(img):
    """Stylize a PIL image with the global Scenimefy model.

    Converts the input to RGB, preprocesses it into a [1, 3, H, W] tensor,
    and returns the stylized result as a numpy image via tensor2im.
    """
    batch_input = get_transform()(img.convert('RGB')).unsqueeze(0)

    # The model's input contract expects an A/B pair; 'B' is a dummy tensor
    # since only the A->B direction is used at inference time.
    batch = {
        'A': batch_input,
        'B': torch.ones_like(batch_input),
        'A_paths': 'upload.jpg'
    }

    model.set_input(batch)
    model.test()
    visuals = model.get_current_visuals()
    return tensor2im(visuals['fake_B'])


def main():
    """Entry point: parse CLI flags, load the model, launch the Gradio app."""
    args = parse_args()
    # NOTE(review): this unconditionally overwrites any --device value the
    # user passed on the command line — confirm whether forcing CPU via CLI
    # should be honored instead.
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('*** Now using %s.'%(args.device))

    global model
    model = initialize()

    demo = gr.Interface(
        inference,
        gr.Image(type="pil", label='Input'),
        gr.Image(type="pil", label='Output').style(height=300),
        theme=args.theme,
        title=TITLE,
        description=DESCRIPTION,
        article=ARTICLE,
        examples=EXAMPLES,
        allow_screenshot=args.allow_screenshot,
        allow_flagging=args.allow_flagging,
        live=args.live,
    )
    demo.launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )


if __name__ == '__main__':
    main()