aliabd committed
Commit 05fb2e7
1 Parent(s): 101eb22

full demo working

Files changed (6)
  1. LICENSE +21 -0
  2. app.py +43 -0
  3. data.py +91 -0
  4. model.py +121 -0
  5. requirements.txt +5 -0
  6. test.py +42 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2021 Xiaoyu Xiang
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
app.py ADDED
@@ -0,0 +1,43 @@
+ import os
+ import random
+ from data import get_image_list
+ from model import create_model
+ from data import read_img_path, tensor_to_img, save_image
+ import gradio as gr
+ import torchtext
+ from PIL import Image
+ import torch
+
+ torch.hub.download_url_to_file('https://upload.wikimedia.org/wikipedia/commons/thumb/a/a5/Tsunami_by_hokusai_19th_century.jpg/1920px-Tsunami_by_hokusai_19th_century.jpg', 'wave.jpg')
+ torch.hub.download_url_to_file('https://cdn.pixabay.com/photo/2020/10/02/13/49/bridge-5621201_1280.jpg', 'building.jpg')
+
+ torchtext.utils.download_from_url("https://drive.google.com/uc?id=1RILKwUdjjBBngB17JHwhZNBEaW4Mr-Ml", root="./weights/")
+ gpu_ids = []
+ model = create_model(gpu_ids)
+ # model.eval()
+
+ def sketch2anime(img, load_size=512):
+     img, aus_resize = read_img_path(img.name, load_size)
+     aus_tensor = model(img)
+     aus_img = tensor_to_img(aus_tensor)
+     image_pil = Image.fromarray(aus_img)
+     image_pil = image_pil.resize(aus_resize, Image.BICUBIC)
+     return image_pil
+
+
+ title = "Anime2Sketch"
+ description = "A sketch extractor for illustration, anime art and manga. Read more at the links below."
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2104.05703'>Adversarial Open Domain Adaption for Sketch-to-Photo Synthesis</a> | <a href='https://github.com/Mukosame/Anime2Sketch'>Github Repo</a></p>"
+
+ gr.Interface(
+     sketch2anime,
+     [gr.inputs.Image(type="file", label="Input")],
+     gr.outputs.Image(type="pil", label="Output"),
+     title=title,
+     description=description,
+     article=article,
+     examples=[
+         ["test_samples/madoka.jpg"],
+         ["building.jpg"],
+         ["wave.jpg"]
+     ]).launch(debug=True)
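
For reference, a minimal sketch of running the same pipeline outside the Gradio UI (it assumes the weights were already fetched into weights/ as app.py does above; the output filename sketch.png is arbitrary):

    import torch
    from PIL import Image
    from model import create_model
    from data import read_img_path, tensor_to_img

    model = create_model(gpu_ids=[])        # CPU-only, matching app.py
    model.eval()
    img, orig_size = read_img_path('test_samples/madoka.jpg', 512)
    with torch.no_grad():                   # inference only, no gradients needed
        sketch = tensor_to_img(model(img))  # uint8 HxWx3 numpy array
    Image.fromarray(sketch).resize(orig_size, Image.BICUBIC).save('sketch.png')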
data.py ADDED
@@ -0,0 +1,91 @@
+ import os
+ from PIL import Image
+ import torchvision.transforms as transforms
+ import numpy as np
+ import torch
+
+ IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
+
+ def is_image_file(filename):
+     """Check whether a given filename is a valid image file
+     Parameters:
+         filename (str) -- image filename
+     """
+     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+ def get_image_list(path):
+     """Read the paths of valid images from the given directory path
+     Parameters:
+         path (str) -- input directory path
+     """
+     assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
+     images = []
+     for dirpath, _, fnames in sorted(os.walk(path)):
+         for fname in sorted(fnames):
+             if is_image_file(fname):
+                 img_path = os.path.join(dirpath, fname)
+                 images.append(img_path)
+     assert images, '{:s} has no valid image file'.format(path)
+     return images
+
+ def get_transform(load_size=0, grayscale=False, method=Image.BICUBIC, convert=True):
+     transform_list = []
+     if grayscale:
+         transform_list.append(transforms.Grayscale(1))
+     if load_size > 0:
+         osize = [load_size, load_size]
+         transform_list.append(transforms.Resize(osize, method))
+     if convert:
+         transform_list += [transforms.ToTensor()]
+         if grayscale:
+             transform_list += [transforms.Normalize((0.5,), (0.5,))]
+         else:
+             transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+     return transforms.Compose(transform_list)
+
+ def read_img_path(path, load_size):
+     """Read a tensor from a given image path
+     Parameters:
+         path (str)      -- input image path
+         load_size (int) -- the input size. If <= 0, don't resize
+     """
+     img = Image.open(path).convert('RGB')
+     aus_resize = None
+     if load_size > 0:
+         aus_resize = img.size
+     transform = get_transform(load_size=load_size)
+     image = transform(img)
+     return image.unsqueeze(0), aus_resize
+
+ def tensor_to_img(input_image, imtype=np.uint8):
+     """Convert a Tensor array into a numpy image array.
+     Parameters:
+         input_image (tensor) -- the input image tensor array
+         imtype (type)        -- the desired type of the converted numpy array
+     """
+     if not isinstance(input_image, np.ndarray):
+         if isinstance(input_image, torch.Tensor):  # get the data from a variable
+             image_tensor = input_image.data
+         else:
+             return input_image
+         image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
+         if image_numpy.shape[0] == 1:  # grayscale to RGB
+             image_numpy = np.tile(image_numpy, (3, 1, 1))
+         image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
+     else:  # if it is a numpy array, do nothing
+         image_numpy = input_image
+     return image_numpy.astype(imtype)
+
+ def save_image(image_numpy, image_path, output_resize=None):
+     """Save a numpy image to disk
+     Parameters:
+         image_numpy (numpy array)     -- input numpy array
+         image_path (str)              -- the path of the image
+         output_resize (None or tuple) -- the output size. If None, don't resize
+     """
+     image_pil = Image.fromarray(image_numpy)
+     if output_resize:
+         image_pil = image_pil.resize(output_resize, Image.BICUBIC)
+     image_pil.save(image_path)
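
A note on the conventions above: get_transform maps pixel values into [-1, 1] (ToTensor scales to [0, 1], then Normalize shifts by 0.5 and divides by 0.5), and tensor_to_img inverts that mapping back to uint8 in [0, 255]. A quick illustrative round trip (the mid-gray test image is just an assumption for the example):

    from PIL import Image
    from data import get_transform, tensor_to_img

    gray = Image.new('RGB', (64, 64), (128, 128, 128))  # mid-gray test image
    t = get_transform(load_size=32)(gray).unsqueeze(0)  # shape (1, 3, 32, 32), values near 0.0
    back = tensor_to_img(t)                             # shape (32, 32, 3), dtype uint8
    print(t.mean().item(), back.mean())                 # ~0.004 and ~128.0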
model.py ADDED
@@ -0,0 +1,121 @@
+ import torch
+ import torch.nn as nn
+ import functools
+
+
+ class UnetGenerator(nn.Module):
+     """Create a Unet-based generator"""
+
+     def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
+         """Construct a Unet generator
+         Parameters:
+             input_nc (int)  -- the number of channels in input images
+             output_nc (int) -- the number of channels in output images
+             num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
+                                an image of size 128x128 will become of size 1x1 at the bottleneck
+             ngf (int)       -- the number of filters in the last conv layer
+             norm_layer      -- normalization layer
+         We construct the U-Net from the innermost layer to the outermost layer.
+         It is a recursive process.
+         """
+         super(UnetGenerator, self).__init__()
+         # construct unet structure
+         unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer
+         for _ in range(num_downs - 5):  # add intermediate layers with ngf * 8 filters
+             unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
+         # gradually reduce the number of filters from ngf * 8 to ngf
+         unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+         unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+         unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+         self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layer
+
+     def forward(self, input):
+         """Standard forward"""
+         return self.model(input)
+
+ class UnetSkipConnectionBlock(nn.Module):
+     """Defines the Unet submodule with skip connection.
+         X -------------------identity----------------------
+         |-- downsampling -- |submodule| -- upsampling --|
+     """
+
+     def __init__(self, outer_nc, inner_nc, input_nc=None,
+                  submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
+         """Construct a Unet submodule with skip connections.
+         Parameters:
+             outer_nc (int) -- the number of filters in the outer conv layer
+             inner_nc (int) -- the number of filters in the inner conv layer
+             input_nc (int) -- the number of channels in input images/features
+             submodule (UnetSkipConnectionBlock) -- previously defined submodules
+             outermost (bool)   -- if this module is the outermost module
+             innermost (bool)   -- if this module is the innermost module
+             norm_layer         -- normalization layer
+             use_dropout (bool) -- if use dropout layers.
+         """
+         super(UnetSkipConnectionBlock, self).__init__()
+         self.outermost = outermost
+         if type(norm_layer) == functools.partial:
+             use_bias = norm_layer.func == nn.InstanceNorm2d
+         else:
+             use_bias = norm_layer == nn.InstanceNorm2d
+         if input_nc is None:
+             input_nc = outer_nc
+         downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
+                              stride=2, padding=1, bias=use_bias)
+         downrelu = nn.LeakyReLU(0.2, True)
+         downnorm = norm_layer(inner_nc)
+         uprelu = nn.ReLU(True)
+         upnorm = norm_layer(outer_nc)
+
+         if outermost:
+             upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+                                         kernel_size=4, stride=2,
+                                         padding=1)
+             down = [downconv]
+             up = [uprelu, upconv, nn.Tanh()]
+             model = down + [submodule] + up
+         elif innermost:
+             upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
+                                         kernel_size=4, stride=2,
+                                         padding=1, bias=use_bias)
+             down = [downrelu, downconv]
+             up = [uprelu, upconv, upnorm]
+             model = down + up
+         else:
+             upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+                                         kernel_size=4, stride=2,
+                                         padding=1, bias=use_bias)
+             down = [downrelu, downconv, downnorm]
+             up = [uprelu, upconv, upnorm]
+
+             if use_dropout:
+                 model = down + [submodule] + up + [nn.Dropout(0.5)]
+             else:
+                 model = down + [submodule] + up
+
+         self.model = nn.Sequential(*model)
+
+     def forward(self, x):
+         if self.outermost:
+             return self.model(x)
+         else:  # add skip connections
+             return torch.cat([x, self.model(x)], 1)
+
+
+ def create_model(gpu_ids=[]):
+     """Create a model for anime2sketch
+     hardcoding the options for simplicity
+     """
+     norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
+     net = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
+     ckpt = torch.load('weights/netG.pth', map_location='cpu')  # load on CPU first; moved to GPU below if requested
+     for key in list(ckpt.keys()):
+         if 'module.' in key:  # strip DataParallel's 'module.' prefix
+             ckpt[key.replace('module.', '')] = ckpt[key]
+             del ckpt[key]
+     net.load_state_dict(ckpt)
+     if len(gpu_ids) > 0:
+         assert torch.cuda.is_available()
+         net.to(gpu_ids[0])
+         net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
+     return net
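
As a sanity check on the architecture (not part of the commit): instantiating the same 8-level U-Net with random weights shows that a 3-channel 512x512 input comes out as a 1-channel map of the same size, since each of the 8 blocks halves the resolution on the way down and doubles it on the way back up.

    import functools
    import torch
    import torch.nn as nn
    from model import UnetGenerator

    norm = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    net = UnetGenerator(3, 1, 8, 64, norm_layer=norm)  # same options create_model hardcodes
    with torch.no_grad():
        out = net(torch.randn(1, 3, 512, 512))
    print(out.shape)  # torch.Size([1, 1, 512, 512]), values in [-1, 1] from the final Tanh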
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch
+ torchvision
+ Pillow
+ gradio
+ torchtext
test.py ADDED
@@ -0,0 +1,42 @@
+ """Test script for anime-to-sketch translation
+ Example:
+     python3 test.py --dataroot /your_path/dir --load_size 512
+     python3 test.py --dataroot /your_path/img.jpg --load_size 512
+ """
+
+ import os
+ from data import get_image_list
+ from model import create_model
+ from data import read_img_path, tensor_to_img, save_image
+ import argparse
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='Anime-to-sketch test options.')
+     parser.add_argument('--dataroot', '-i', default='test_samples/', type=str)
+     parser.add_argument('--load_size', '-s', default=512, type=int)
+     parser.add_argument('--output_dir', '-o', default='results/', type=str)
+     parser.add_argument('--gpu_ids', '-g', default='', type=str, help="comma-separated gpu ids, e.g. '0' or '0,1,2'; leave empty for CPU")
+     opt = parser.parse_args()
+     gpu_ids = [int(i) for i in opt.gpu_ids.split(',') if i.strip()]  # parse '0,1,2' -> [0, 1, 2]
+
+     # create model
+     model = create_model(gpu_ids)  # create a model with the hardcoded options
+     model.eval()
+     # get input data
+     if os.path.isdir(opt.dataroot):
+         test_list = get_image_list(opt.dataroot)
+     elif os.path.isfile(opt.dataroot):
+         test_list = [opt.dataroot]
+     else:
+         raise Exception("{} is not a valid directory or image file.".format(opt.dataroot))
+     # save outputs
+     save_dir = opt.output_dir
+     os.makedirs(save_dir, exist_ok=True)
+
+     for test_path in test_list:
+         basename = os.path.basename(test_path)
+         aus_path = os.path.join(save_dir, basename)
+         img, aus_resize = read_img_path(test_path, opt.load_size)
+         aus_tensor = model(img)
+         aus_img = tensor_to_img(aus_tensor)
+         save_image(aus_img, aus_path, aus_resize)
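
With the argument parsing above, a CPU batch run over the bundled samples would look like (these paths are also the script's defaults):

    python3 test.py -i test_samples/ -s 512 -o results/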