p1atdev commited on
Commit
c5d1577
·
1 Parent(s): f67d7a2

initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Created by https://www.toptal.com/developers/gitignore/api/macos,windows,linux,python
2
+ # Edit at https://www.toptal.com/developers/gitignore?templates=macos,windows,linux,python
3
+
4
+ ### Linux ###
5
+ *~
6
+
7
+ # temporary files which can be created if a process still has a handle open of a deleted file
8
+ .fuse_hidden*
9
+
10
+ # KDE directory preferences
11
+ .directory
12
+
13
+ # Linux trash folder which might appear on any partition or disk
14
+ .Trash-*
15
+
16
+ # .nfs files are created when an open file is removed but is still being accessed
17
+ .nfs*
18
+
19
+ ### macOS ###
20
+ # General
21
+ .DS_Store
22
+ .AppleDouble
23
+ .LSOverride
24
+
25
+ # Icon must end with two \r
26
+ Icon
27
+
28
+ # Thumbnails
29
+ ._*
30
+
31
+ # Files that might appear in the root of a volume
32
+ .DocumentRevisions-V100
33
+ .fseventsd
34
+ .Spotlight-V100
35
+ .TemporaryItems
36
+ .Trashes
37
+ .VolumeIcon.icns
38
+ .com.apple.timemachine.donotpresent
39
+
40
+ # Directories potentially created on remote AFP share
41
+ .AppleDB
42
+ .AppleDesktop
43
+ Network Trash Folder
44
+ Temporary Items
45
+ .apdisk
46
+
47
+ ### macOS Patch ###
48
+ # iCloud generated files
49
+ *.icloud
50
+
51
+ ### Python ###
52
+ # Byte-compiled / optimized / DLL files
53
+ __pycache__/
54
+ *.py[cod]
55
+ *$py.class
56
+
57
+ # C extensions
58
+ *.so
59
+
60
+ # Distribution / packaging
61
+ .Python
62
+ build/
63
+ develop-eggs/
64
+ dist/
65
+ downloads/
66
+ eggs/
67
+ .eggs/
68
+ lib/
69
+ lib64/
70
+ parts/
71
+ sdist/
72
+ var/
73
+ wheels/
74
+ share/python-wheels/
75
+ *.egg-info/
76
+ .installed.cfg
77
+ *.egg
78
+ MANIFEST
79
+
80
+ # PyInstaller
81
+ # Usually these files are written by a python script from a template
82
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
83
+ *.manifest
84
+ *.spec
85
+
86
+ # Installer logs
87
+ pip-log.txt
88
+ pip-delete-this-directory.txt
89
+
90
+ # Unit test / coverage reports
91
+ htmlcov/
92
+ .tox/
93
+ .nox/
94
+ .coverage
95
+ .coverage.*
96
+ .cache
97
+ nosetests.xml
98
+ coverage.xml
99
+ *.cover
100
+ *.py,cover
101
+ .hypothesis/
102
+ .pytest_cache/
103
+ cover/
104
+
105
+ # Translations
106
+ *.mo
107
+ *.pot
108
+
109
+ # Django stuff:
110
+ *.log
111
+ local_settings.py
112
+ db.sqlite3
113
+ db.sqlite3-journal
114
+
115
+ # Flask stuff:
116
+ instance/
117
+ .webassets-cache
118
+
119
+ # Scrapy stuff:
120
+ .scrapy
121
+
122
+ # Sphinx documentation
123
+ docs/_build/
124
+
125
+ # PyBuilder
126
+ .pybuilder/
127
+ target/
128
+
129
+ # Jupyter Notebook
130
+ .ipynb_checkpoints
131
+
132
+ # IPython
133
+ profile_default/
134
+ ipython_config.py
135
+
136
+ # pyenv
137
+ # For a library or package, you might want to ignore these files since the code is
138
+ # intended to run in multiple environments; otherwise, check them in:
139
+ # .python-version
140
+
141
+ # pipenv
142
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
143
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
144
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
145
+ # install all needed dependencies.
146
+ #Pipfile.lock
147
+
148
+ # poetry
149
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
150
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
151
+ # commonly ignored for libraries.
152
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
153
+ #poetry.lock
154
+
155
+ # pdm
156
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
157
+ #pdm.lock
158
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
159
+ # in version control.
160
+ # https://pdm.fming.dev/#use-with-ide
161
+ .pdm.toml
162
+
163
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
164
+ __pypackages__/
165
+
166
+ # Celery stuff
167
+ celerybeat-schedule
168
+ celerybeat.pid
169
+
170
+ # SageMath parsed files
171
+ *.sage.py
172
+
173
+ # Environments
174
+ .env
175
+ .venv
176
+ env/
177
+ venv/
178
+ ENV/
179
+ env.bak/
180
+ venv.bak/
181
+
182
+ # Spyder project settings
183
+ .spyderproject
184
+ .spyproject
185
+
186
+ # Rope project settings
187
+ .ropeproject
188
+
189
+ # mkdocs documentation
190
+ /site
191
+
192
+ # mypy
193
+ .mypy_cache/
194
+ .dmypy.json
195
+ dmypy.json
196
+
197
+ # Pyre type checker
198
+ .pyre/
199
+
200
+ # pytype static type analyzer
201
+ .pytype/
202
+
203
+ # Cython debug symbols
204
+ cython_debug/
205
+
206
+ # PyCharm
207
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
208
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
209
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
210
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
211
+ #.idea/
212
+
213
+ ### Python Patch ###
214
+ # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
215
+ poetry.toml
216
+
217
+ # ruff
218
+ .ruff_cache/
219
+
220
+ # LSP config files
221
+ pyrightconfig.json
222
+
223
+ ### Windows ###
224
+ # Windows thumbnail cache files
225
+ Thumbs.db
226
+ Thumbs.db:encryptable
227
+ ehthumbs.db
228
+ ehthumbs_vista.db
229
+
230
+ # Dump file
231
+ *.stackdump
232
+
233
+ # Folder config file
234
+ [Dd]esktop.ini
235
+
236
+ # Recycle Bin used on file shares
237
+ $RECYCLE.BIN/
238
+
239
+ # Windows Installer files
240
+ *.cab
241
+ *.msi
242
+ *.msix
243
+ *.msm
244
+ *.msp
245
+
246
+ # Windows shortcuts
247
+ *.lnk
248
+
249
+ # End of https://www.toptal.com/developers/gitignore/api/macos,windows,linux,python
250
+
251
+ *.pth
252
+ gradio_cached_examples
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
- title: MangaLineExtraction
3
- emoji: 💻
4
  colorFrom: green
5
  colorTo: yellow
6
  sdk: gradio
@@ -10,4 +10,7 @@ pinned: false
10
  license: mit
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
1
  ---
2
+ title: Anime to Sketch
3
+ emoji: 💭
4
  colorFrom: green
5
  colorTo: yellow
6
  sdk: gradio
 
10
  license: mit
11
  ---
12
 
13
+ Original repo:
14
+ - MangaLineExtraction: https://github.com/ljsabc/MangaLineExtraction_PyTorch
15
+ - Anime2Sketch: https://github.com/Mukosame/Anime2Sketch
16
+
anime2sketch/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2021 Xiaoyu Xiang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
anime2sketch/model.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image
3
+ import torchvision.transforms as transforms
4
+
5
+ try:
6
+ from torchvision.transforms import InterpolationMode
7
+
8
+ bic = InterpolationMode.BICUBIC
9
+ except ImportError:
10
+ bic = Image.BICUBIC
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import functools
16
+
17
+ IMG_EXTENSIONS = [".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".webp"]
18
+
19
+
20
+ class UnetGenerator(nn.Module):
21
+ """Create a Unet-based generator"""
22
+
23
+ def __init__(
24
+ self,
25
+ input_nc,
26
+ output_nc,
27
+ num_downs,
28
+ ngf=64,
29
+ norm_layer=nn.BatchNorm2d,
30
+ use_dropout=False,
31
+ ):
32
+ """Construct a Unet generator
33
+ Parameters:
34
+ input_nc (int) -- the number of channels in input images
35
+ output_nc (int) -- the number of channels in output images
36
+ num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
37
+ image of size 128x128 will become of size 1x1 # at the bottleneck
38
+ ngf (int) -- the number of filters in the last conv layer
39
+ norm_layer -- normalization layer
40
+ We construct the U-Net from the innermost layer to the outermost layer.
41
+ It is a recursive process.
42
+ """
43
+ super(UnetGenerator, self).__init__()
44
+ # construct unet structure
45
+ unet_block = UnetSkipConnectionBlock(
46
+ ngf * 8,
47
+ ngf * 8,
48
+ input_nc=None,
49
+ submodule=None,
50
+ norm_layer=norm_layer,
51
+ innermost=True,
52
+ ) # add the innermost layer
53
+ for _ in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
54
+ unet_block = UnetSkipConnectionBlock(
55
+ ngf * 8,
56
+ ngf * 8,
57
+ input_nc=None,
58
+ submodule=unet_block,
59
+ norm_layer=norm_layer,
60
+ use_dropout=use_dropout,
61
+ )
62
+ # gradually reduce the number of filters from ngf * 8 to ngf
63
+ unet_block = UnetSkipConnectionBlock(
64
+ ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer
65
+ )
66
+ unet_block = UnetSkipConnectionBlock(
67
+ ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer
68
+ )
69
+ unet_block = UnetSkipConnectionBlock(
70
+ ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer
71
+ )
72
+ self.model = UnetSkipConnectionBlock(
73
+ output_nc,
74
+ ngf,
75
+ input_nc=input_nc,
76
+ submodule=unet_block,
77
+ outermost=True,
78
+ norm_layer=norm_layer,
79
+ ) # add the outermost layer
80
+
81
+ def forward(self, input):
82
+ """Standard forward"""
83
+ return self.model(input)
84
+
85
+
86
+ class UnetSkipConnectionBlock(nn.Module):
87
+ """Defines the Unet submodule with skip connection.
88
+ X -------------------identity----------------------
89
+ |-- downsampling -- |submodule| -- upsampling --|
90
+ """
91
+
92
+ def __init__(
93
+ self,
94
+ outer_nc,
95
+ inner_nc,
96
+ input_nc=None,
97
+ submodule=None,
98
+ outermost=False,
99
+ innermost=False,
100
+ norm_layer=nn.BatchNorm2d,
101
+ use_dropout=False,
102
+ ):
103
+ """Construct a Unet submodule with skip connections.
104
+ Parameters:
105
+ outer_nc (int) -- the number of filters in the outer conv layer
106
+ inner_nc (int) -- the number of filters in the inner conv layer
107
+ input_nc (int) -- the number of channels in input images/features
108
+ submodule (UnetSkipConnectionBlock) -- previously defined submodules
109
+ outermost (bool) -- if this module is the outermost module
110
+ innermost (bool) -- if this module is the innermost module
111
+ norm_layer -- normalization layer
112
+ use_dropout (bool) -- if use dropout layers.
113
+ """
114
+ super(UnetSkipConnectionBlock, self).__init__()
115
+ self.outermost = outermost
116
+ if type(norm_layer) == functools.partial:
117
+ use_bias = norm_layer.func == nn.InstanceNorm2d
118
+ else:
119
+ use_bias = norm_layer == nn.InstanceNorm2d
120
+ if input_nc is None:
121
+ input_nc = outer_nc
122
+ downconv = nn.Conv2d(
123
+ input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
124
+ )
125
+ downrelu = nn.LeakyReLU(0.2, True)
126
+ downnorm = norm_layer(inner_nc)
127
+ uprelu = nn.ReLU(True)
128
+ upnorm = norm_layer(outer_nc)
129
+
130
+ if outermost:
131
+ upconv = nn.ConvTranspose2d(
132
+ inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1
133
+ )
134
+ down = [downconv]
135
+ up = [uprelu, upconv, nn.Tanh()]
136
+ model = down + [submodule] + up
137
+ elif innermost:
138
+ upconv = nn.ConvTranspose2d(
139
+ inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
140
+ )
141
+ down = [downrelu, downconv]
142
+ up = [uprelu, upconv, upnorm]
143
+ model = down + up
144
+ else:
145
+ upconv = nn.ConvTranspose2d(
146
+ inner_nc * 2,
147
+ outer_nc,
148
+ kernel_size=4,
149
+ stride=2,
150
+ padding=1,
151
+ bias=use_bias,
152
+ )
153
+ down = [downrelu, downconv, downnorm]
154
+ up = [uprelu, upconv, upnorm]
155
+
156
+ if use_dropout:
157
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
158
+ else:
159
+ model = down + [submodule] + up
160
+
161
+ self.model = nn.Sequential(*model)
162
+
163
+ def forward(self, x):
164
+ if self.outermost:
165
+ return self.model(x)
166
+ else: # add skip connections
167
+ return torch.cat([x, self.model(x)], 1)
168
+
169
+
170
+ class Anime2Sketch:
171
+ def __init__(
172
+ self, model_path: str = "./models/netG.pth", device: str = "cpu"
173
+ ) -> None:
174
+ norm_layer = functools.partial(
175
+ nn.InstanceNorm2d, affine=False, track_running_stats=False
176
+ )
177
+ net = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
178
+ ckpt = torch.load(model_path)
179
+
180
+ for key in list(ckpt.keys()):
181
+ if "module." in key:
182
+ ckpt[key.replace("module.", "")] = ckpt[key]
183
+ del ckpt[key]
184
+
185
+ net.load_state_dict(ckpt)
186
+
187
+ self.model = net
188
+
189
+ if torch.cuda.is_available() and device == "cuda":
190
+ self.device = "cuda"
191
+ self.model.to(device)
192
+ else:
193
+ self.device = "cpu"
194
+ self.model.to("cpu")
195
+
196
+ def predict(self, image: Image.Image, load_size: int = 512) -> Image:
197
+ try:
198
+ aus_resize = None
199
+ if load_size > 0:
200
+ aus_resize = image.size
201
+ transform = self.get_transform(load_size=load_size)
202
+ image = transform(image)
203
+ img = image.unsqueeze(0)
204
+ except:
205
+ raise Exception("Error in reading image {}".format(image.filename))
206
+
207
+ aus_tensor = self.model(img.to(self.device))
208
+ aus_img = self.tensor_to_img(aus_tensor)
209
+
210
+ image_pil = Image.fromarray(aus_img)
211
+ if aus_resize:
212
+ bic = Image.BICUBIC
213
+ image_pil = image_pil.resize(aus_resize, bic)
214
+
215
+ return image_pil
216
+
217
+ def get_transform(self, load_size=0, grayscale=False, method=bic, convert=True):
218
+ transform_list = []
219
+ if grayscale:
220
+ transform_list.append(transforms.Grayscale(1))
221
+ if load_size > 0:
222
+ osize = [load_size, load_size]
223
+ transform_list.append(transforms.Resize(osize, method))
224
+ if convert:
225
+ transform_list += [transforms.ToTensor()]
226
+ if grayscale:
227
+ transform_list += [transforms.Normalize((0.5,), (0.5,))]
228
+ else:
229
+ transform_list += [
230
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
231
+ ]
232
+ return transforms.Compose(transform_list)
233
+
234
+ def tensor_to_img(self, input_image, imtype=np.uint8):
235
+ """ "Converts a Tensor array into a numpy image array.
236
+ Parameters:
237
+ input_image (tensor) -- the input image tensor array
238
+ imtype (type) -- the desired type of the converted numpy array
239
+ """
240
+
241
+ if not isinstance(input_image, np.ndarray):
242
+ if isinstance(input_image, torch.Tensor): # get the data from a variable
243
+ image_tensor = input_image.data
244
+ else:
245
+ return input_image
246
+ image_numpy = (
247
+ image_tensor[0].cpu().float().numpy()
248
+ ) # convert it into a numpy array
249
+ if image_numpy.shape[0] == 1: # grayscale to RGB
250
+ image_numpy = np.tile(image_numpy, (3, 1, 1))
251
+ image_numpy = (
252
+ (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
253
+ ) # post-processing: tranpose and scaling
254
+ else: # if it is a numpy array, do nothing
255
+ image_numpy = input_image
256
+ return image_numpy.astype(imtype)
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from setup import setup
3
+ import cv2
4
+ from PIL import Image
5
+ from manga_line_extraction.model import MangaLineExtractor
6
+ from anime2sketch.model import Anime2Sketch
7
+
8
+ setup()
9
+
10
+ print("Setup finished")
11
+
12
+ extractor = MangaLineExtractor("./models/erika.pth", "cpu")
13
+ to_sketch = Anime2Sketch("./models/netG.pth", "cpu")
14
+
15
+ print("Model loaded")
16
+
17
+
18
+ def extract(image):
19
+ return extractor.predict(image)
20
+
21
+
22
+ def convert_to_sketch(image):
23
+ return to_sketch.predict(image)
24
+
25
+
26
+ def start(image):
27
+ return [extract(image), convert_to_sketch(Image.fromarray(image).convert("RGB"))]
28
+
29
+
30
+ def ui():
31
+ with gr.Blocks() as blocks:
32
+ gr.Markdown(
33
+ """
34
+ # Anime to Sketch
35
+ Unofficial demo for converting illustrations into sketches.
36
+ Original repos:
37
+ - [MangaLineExtraction_PyTorch](https://github.com/ljsabc/MangaLineExtraction_PyTorch)
38
+ - [Anime2Sketch](https://github.com/Mukosame/Anime2Sketch)
39
+ """
40
+ )
41
+
42
+ with gr.Row():
43
+ with gr.Column():
44
+ input_img = gr.Image(label="Input", interactive=True)
45
+
46
+ extract_btn = gr.Button("Extract", variant="primary")
47
+
48
+ with gr.Column():
49
+ # with gr.Row():
50
+ extract_output_img = gr.Image(
51
+ label="MangaLineExtraction", interactive=False
52
+ )
53
+ to_sketch_output_img = gr.Image(label="Anime2Sketch", interactive=False)
54
+
55
+ gr.Examples(
56
+ fn=start,
57
+ examples=[
58
+ ["./examples/1.jpg"],
59
+ ["./examples/2.jpg"],
60
+ ["./examples/3.jpg"],
61
+ ["./examples/4.jpg"],
62
+ ["./examples/5.jpg"],
63
+ ["./examples/6.jpg"],
64
+ ["./examples/7.jpg"],
65
+ ["./examples/8.jpg"],
66
+ ],
67
+ inputs=[input_img],
68
+ outputs=[extract_output_img, to_sketch_output_img],
69
+ label="Examples",
70
+ cache_examples=True,
71
+ )
72
+
73
+ gr.Markdown("Images are from nijijourney.")
74
+
75
+ extract_btn.click(
76
+ fn=start,
77
+ inputs=[input_img],
78
+ outputs=[extract_output_img, to_sketch_output_img],
79
+ )
80
+
81
+ return blocks
82
+
83
+
84
+ if __name__ == "__main__":
85
+ ui().launch()
examples/1.jpg ADDED
examples/2.jpg ADDED
examples/3.jpg ADDED
examples/4.jpg ADDED
examples/5.jpg ADDED
examples/6.jpg ADDED
examples/7.jpg ADDED
examples/8.jpg ADDED
manga_line_extraction/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2021 Miaomiao Li
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
manga_line_extraction/model.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.utils.data.dataset import Dataset
5
+ from PIL import Image
6
+ import fnmatch
7
+ import cv2
8
+
9
+ import sys
10
+
11
+ import numpy as np
12
+
13
+ # torch.set_printoptions(precision=10)
14
+
15
+
16
+ class _bn_relu_conv(nn.Module):
17
+ def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
18
+ super(_bn_relu_conv, self).__init__()
19
+ self.model = nn.Sequential(
20
+ nn.BatchNorm2d(in_filters, eps=1e-3),
21
+ nn.LeakyReLU(0.2),
22
+ nn.Conv2d(
23
+ in_filters,
24
+ nb_filters,
25
+ (fw, fh),
26
+ stride=subsample,
27
+ padding=(fw // 2, fh // 2),
28
+ padding_mode="zeros",
29
+ ),
30
+ )
31
+
32
+ def forward(self, x):
33
+ return self.model(x)
34
+
35
+ # the following are for debugs
36
+ print(
37
+ "****",
38
+ np.max(x.cpu().numpy()),
39
+ np.min(x.cpu().numpy()),
40
+ np.mean(x.cpu().numpy()),
41
+ np.std(x.cpu().numpy()),
42
+ x.shape,
43
+ )
44
+ for i, layer in enumerate(self.model):
45
+ if i != 2:
46
+ x = layer(x)
47
+ else:
48
+ x = layer(x)
49
+ # x = nn.functional.pad(x, (1, 1, 1, 1), mode='constant', value=0)
50
+ print(
51
+ "____",
52
+ np.max(x.cpu().numpy()),
53
+ np.min(x.cpu().numpy()),
54
+ np.mean(x.cpu().numpy()),
55
+ np.std(x.cpu().numpy()),
56
+ x.shape,
57
+ )
58
+ print(x[0])
59
+ return x
60
+
61
+
62
+ class _u_bn_relu_conv(nn.Module):
63
+ def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
64
+ super(_u_bn_relu_conv, self).__init__()
65
+ self.model = nn.Sequential(
66
+ nn.BatchNorm2d(in_filters, eps=1e-3),
67
+ nn.LeakyReLU(0.2),
68
+ nn.Conv2d(
69
+ in_filters,
70
+ nb_filters,
71
+ (fw, fh),
72
+ stride=subsample,
73
+ padding=(fw // 2, fh // 2),
74
+ ),
75
+ nn.Upsample(scale_factor=2, mode="nearest"),
76
+ )
77
+
78
+ def forward(self, x):
79
+ return self.model(x)
80
+
81
+
82
+ class _shortcut(nn.Module):
83
+ def __init__(self, in_filters, nb_filters, subsample=1):
84
+ super(_shortcut, self).__init__()
85
+ self.process = False
86
+ self.model = None
87
+ if in_filters != nb_filters or subsample != 1:
88
+ self.process = True
89
+ self.model = nn.Sequential(
90
+ nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample)
91
+ )
92
+
93
+ def forward(self, x, y):
94
+ # print(x.size(), y.size(), self.process)
95
+ if self.process:
96
+ y0 = self.model(x)
97
+ # print("merge+", torch.max(y0+y), torch.min(y0+y),torch.mean(y0+y), torch.std(y0+y), y0.shape)
98
+ return y0 + y
99
+ else:
100
+ # print("merge", torch.max(x+y), torch.min(x+y),torch.mean(x+y), torch.std(x+y), y.shape)
101
+ return x + y
102
+
103
+
104
+ class _u_shortcut(nn.Module):
105
+ def __init__(self, in_filters, nb_filters, subsample):
106
+ super(_u_shortcut, self).__init__()
107
+ self.process = False
108
+ self.model = None
109
+ if in_filters != nb_filters:
110
+ self.process = True
111
+ self.model = nn.Sequential(
112
+ nn.Conv2d(
113
+ in_filters,
114
+ nb_filters,
115
+ (1, 1),
116
+ stride=subsample,
117
+ padding_mode="zeros",
118
+ ),
119
+ nn.Upsample(scale_factor=2, mode="nearest"),
120
+ )
121
+
122
+ def forward(self, x, y):
123
+ if self.process:
124
+ return self.model(x) + y
125
+ else:
126
+ return x + y
127
+
128
+
129
+ class basic_block(nn.Module):
130
+ def __init__(self, in_filters, nb_filters, init_subsample=1):
131
+ super(basic_block, self).__init__()
132
+ self.conv1 = _bn_relu_conv(
133
+ in_filters, nb_filters, 3, 3, subsample=init_subsample
134
+ )
135
+ self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
136
+ self.shortcut = _shortcut(in_filters, nb_filters, subsample=init_subsample)
137
+
138
+ def forward(self, x):
139
+ x1 = self.conv1(x)
140
+ x2 = self.residual(x1)
141
+ return self.shortcut(x, x2)
142
+
143
+
144
+ class _u_basic_block(nn.Module):
145
+ def __init__(self, in_filters, nb_filters, init_subsample=1):
146
+ super(_u_basic_block, self).__init__()
147
+ self.conv1 = _u_bn_relu_conv(
148
+ in_filters, nb_filters, 3, 3, subsample=init_subsample
149
+ )
150
+ self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
151
+ self.shortcut = _u_shortcut(in_filters, nb_filters, subsample=init_subsample)
152
+
153
+ def forward(self, x):
154
+ y = self.residual(self.conv1(x))
155
+ return self.shortcut(x, y)
156
+
157
+
158
+ class _residual_block(nn.Module):
159
+ def __init__(self, in_filters, nb_filters, repetitions, is_first_layer=False):
160
+ super(_residual_block, self).__init__()
161
+ layers = []
162
+ for i in range(repetitions):
163
+ init_subsample = 1
164
+ if i == repetitions - 1 and not is_first_layer:
165
+ init_subsample = 2
166
+ if i == 0:
167
+ l = basic_block(
168
+ in_filters=in_filters,
169
+ nb_filters=nb_filters,
170
+ init_subsample=init_subsample,
171
+ )
172
+ else:
173
+ l = basic_block(
174
+ in_filters=nb_filters,
175
+ nb_filters=nb_filters,
176
+ init_subsample=init_subsample,
177
+ )
178
+ layers.append(l)
179
+
180
+ self.model = nn.Sequential(*layers)
181
+
182
+ def forward(self, x):
183
+ return self.model(x)
184
+
185
+
186
+ class _upsampling_residual_block(nn.Module):
187
+ def __init__(self, in_filters, nb_filters, repetitions):
188
+ super(_upsampling_residual_block, self).__init__()
189
+ layers = []
190
+ for i in range(repetitions):
191
+ l = None
192
+ if i == 0:
193
+ l = _u_basic_block(
194
+ in_filters=in_filters, nb_filters=nb_filters
195
+ ) # (input)
196
+ else:
197
+ l = basic_block(in_filters=nb_filters, nb_filters=nb_filters) # (input)
198
+ layers.append(l)
199
+
200
+ self.model = nn.Sequential(*layers)
201
+
202
+ def forward(self, x):
203
+ return self.model(x)
204
+
205
+
206
+ class res_skip(nn.Module):
207
+ def __init__(self):
208
+ super(res_skip, self).__init__()
209
+ self.block0 = _residual_block(
210
+ in_filters=1, nb_filters=24, repetitions=2, is_first_layer=True
211
+ ) # (input)
212
+ self.block1 = _residual_block(
213
+ in_filters=24, nb_filters=48, repetitions=3
214
+ ) # (block0)
215
+ self.block2 = _residual_block(
216
+ in_filters=48, nb_filters=96, repetitions=5
217
+ ) # (block1)
218
+ self.block3 = _residual_block(
219
+ in_filters=96, nb_filters=192, repetitions=7
220
+ ) # (block2)
221
+ self.block4 = _residual_block(
222
+ in_filters=192, nb_filters=384, repetitions=12
223
+ ) # (block3)
224
+
225
+ self.block5 = _upsampling_residual_block(
226
+ in_filters=384, nb_filters=192, repetitions=7
227
+ ) # (block4)
228
+ self.res1 = _shortcut(
229
+ in_filters=192, nb_filters=192
230
+ ) # (block3, block5, subsample=(1,1))
231
+
232
+ self.block6 = _upsampling_residual_block(
233
+ in_filters=192, nb_filters=96, repetitions=5
234
+ ) # (res1)
235
+ self.res2 = _shortcut(
236
+ in_filters=96, nb_filters=96
237
+ ) # (block2, block6, subsample=(1,1))
238
+
239
+ self.block7 = _upsampling_residual_block(
240
+ in_filters=96, nb_filters=48, repetitions=3
241
+ ) # (res2)
242
+ self.res3 = _shortcut(
243
+ in_filters=48, nb_filters=48
244
+ ) # (block1, block7, subsample=(1,1))
245
+
246
+ self.block8 = _upsampling_residual_block(
247
+ in_filters=48, nb_filters=24, repetitions=2
248
+ ) # (res3)
249
+ self.res4 = _shortcut(
250
+ in_filters=24, nb_filters=24
251
+ ) # (block0,block8, subsample=(1,1))
252
+
253
+ self.block9 = _residual_block(
254
+ in_filters=24, nb_filters=16, repetitions=2, is_first_layer=True
255
+ ) # (res4)
256
+ self.conv15 = _bn_relu_conv(
257
+ in_filters=16, nb_filters=1, fh=1, fw=1, subsample=1
258
+ ) # (block7)
259
+
260
+ def forward(self, x):
261
+ x0 = self.block0(x)
262
+ x1 = self.block1(x0)
263
+ x2 = self.block2(x1)
264
+ x3 = self.block3(x2)
265
+ x4 = self.block4(x3)
266
+
267
+ x5 = self.block5(x4)
268
+ res1 = self.res1(x3, x5)
269
+
270
+ x6 = self.block6(res1)
271
+ res2 = self.res2(x2, x6)
272
+
273
+ x7 = self.block7(res2)
274
+ res3 = self.res3(x1, x7)
275
+
276
+ x8 = self.block8(res3)
277
+ res4 = self.res4(x0, x8)
278
+
279
+ x9 = self.block9(res4)
280
+ y = self.conv15(x9)
281
+
282
+ return y
283
+
284
+
285
+ class MangaLineExtractor:
286
+ def __init__(self, model_path: str = "erika.pth", device: str = "cpu"):
287
+ self.model = res_skip()
288
+ self.model.load_state_dict(torch.load(model_path))
289
+
290
+ self.is_cuda = torch.cuda.is_available() and device == "cuda"
291
+ if self.is_cuda:
292
+ self.model.cuda()
293
+ else:
294
+ self.model.cpu()
295
+
296
+ self.model.eval()
297
+
298
+ def predict(self, image):
299
+ src = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
300
+
301
+ rows = int(np.ceil(src.shape[0] / 16)) * 16
302
+ cols = int(np.ceil(src.shape[1] / 16)) * 16
303
+
304
+ # manually construct a batch. You can change it based on your usecases.
305
+ patch = np.ones((1, 1, rows, cols), dtype=np.float32)
306
+ patch[0, 0, 0 : src.shape[0], 0 : src.shape[1]] = src
307
+
308
+ if self.is_cuda:
309
+ tensor = torch.from_numpy(patch).cuda()
310
+ else:
311
+ tensor = torch.from_numpy(patch).cpu()
312
+
313
+ y = self.model(tensor)
314
+
315
+ yc = y.detach().numpy()[0, 0, :, :]
316
+ yc[yc > 255] = 255
317
+ yc[yc < 0] = 0
318
+ yc = yc / 255.0
319
+
320
+ output = yc[0 : src.shape[0], 0 : src.shape[1]]
321
+ output = cv2.cvtColor(output, cv2.COLOR_GRAY2BGR)
322
+
323
+ return output
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ numpy
4
+ opencv-python
5
+ huggingface_hub
setup.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import os
3
+ from huggingface_hub import hf_hub_download
4
+ from utils import custom_drive_cache_dir, get_drive
5
+
6
+ HF_TOKEN = os.getenv("HF_TOKEN")
7
+
8
+ MANGA_LINE_EXTRACTION_MODEL = "https://github.com/ljsabc/MangaLineExtraction_PyTorch/releases/download/v1/erika.pth"
9
+ ANIME2SKETCH_MODEL = {"REPO_ID": "p1atdev/Anime2Sketch", "FILENAME": "netG.pth"}
10
+
11
+
12
+ def download_manga_line_extraction_model():
13
+ if os.path.exists("./models/erika.pth"):
14
+ return
15
+
16
+
17
+ def download_anime2sketch_model():
18
+ if os.path.exists("./models/netG.pth"):
19
+ return
20
+
21
+ drive = get_drive("./models/netG.pth")
22
+ with custom_drive_cache_dir(drive) as cache_dir:
23
+ hf_hub_download(
24
+ repo_id=ANIME2SKETCH_MODEL["REPO_ID"],
25
+ filename=ANIME2SKETCH_MODEL["FILENAME"],
26
+ local_dir="./models",
27
+ use_auth_token=HF_TOKEN,
28
+ local_dir_use_symlinks=False,
29
+ cache_dir=cache_dir,
30
+ )
31
+
32
+
33
+ def setup():
34
+ download_manga_line_extraction_model()
35
+ download_anime2sketch_model()
utils.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import tempfile
3
+ from contextlib import contextmanager
4
+ import os
5
+
6
+
7
+ def get_drive(path: str):
8
+ path = Path(path).resolve()
9
+ drive = path.drive
10
+ root = path.root
11
+ return drive + root
12
+
13
+
14
+ @contextmanager
15
+ def custom_drive_cache_dir(drive: str):
16
+ drive = Path(drive)
17
+ base_dir = Path(drive) / "tmp"
18
+ if not base_dir.exists():
19
+ os.makedirs(base_dir)
20
+ print(f"Using {base_dir.resolve()} as cache dir")
21
+ with tempfile.TemporaryDirectory(dir=base_dir) as tmp_dir:
22
+ yield tmp_dir