hysts HF staff commited on
Commit
9f9c100
1 Parent(s): b99b941
Files changed (4) hide show
  1. .gitmodules +3 -0
  2. Anime2Sketch +1 -0
  3. app.py +115 -0
  4. requirements.txt +4 -0
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "Anime2Sketch"]
2
+ path = Anime2Sketch
3
+ url = https://github.com/Mukosame/Anime2Sketch
Anime2Sketch ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 46342c2c7c19a15f907f2b5005721b13636b659a
app.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import functools
7
+ import os
8
+ import sys
9
+
10
+ import gradio as gr
11
+ import huggingface_hub
12
+ import PIL.Image
13
+ import torch
14
+ import torch.nn as nn
15
+
16
+ sys.path.insert(0, 'Anime2Sketch')
17
+
18
+ from data import read_img_path, tensor_to_img
19
+ from model import UnetGenerator
20
+
21
+ TITLE = 'Mukosame/Anime2Sketch'
22
+ DESCRIPTION = 'This is a demo for https://github.com/Mukosame/Anime2Sketch.'
23
+ ARTICLE = None
24
+
25
+ TOKEN = os.environ['TOKEN']
26
+
27
+
28
+ def parse_args() -> argparse.Namespace:
29
+ parser = argparse.ArgumentParser()
30
+ parser.add_argument('--device', type=str, default='cpu')
31
+ parser.add_argument('--theme', type=str)
32
+ parser.add_argument('--live', action='store_true')
33
+ parser.add_argument('--share', action='store_true')
34
+ parser.add_argument('--port', type=int)
35
+ parser.add_argument('--disable-queue',
36
+ dest='enable_queue',
37
+ action='store_false')
38
+ parser.add_argument('--allow-flagging', type=str, default='never')
39
+ parser.add_argument('--allow-screenshot', action='store_true')
40
+ return parser.parse_args()
41
+
42
+
43
+ def load_model(device: torch.device) -> nn.Module:
44
+ norm_layer = functools.partial(nn.InstanceNorm2d,
45
+ affine=False,
46
+ track_running_stats=False)
47
+ model = UnetGenerator(3,
48
+ 1,
49
+ 8,
50
+ 64,
51
+ norm_layer=norm_layer,
52
+ use_dropout=False)
53
+
54
+ path = huggingface_hub.hf_hub_download('hysts/Anime2Sketch',
55
+ 'netG.pth',
56
+ use_auth_token=TOKEN)
57
+ ckpt = torch.load(path)
58
+ for key in list(ckpt.keys()):
59
+ if 'module.' in key:
60
+ ckpt[key.replace('module.', '')] = ckpt[key]
61
+ del ckpt[key]
62
+ model.load_state_dict(ckpt)
63
+ model.to(device)
64
+ model.eval()
65
+ return model
66
+
67
+
68
+ @torch.inference_mode()
69
+ def run(image_file,
70
+ model: nn.Module,
71
+ device: torch.device,
72
+ load_size: int = 512) -> PIL.Image.Image:
73
+ tensor, orig_size = read_img_path(image_file.name, load_size)
74
+ tensor = tensor.to(device)
75
+ out = model(tensor)
76
+ res = tensor_to_img(out)
77
+ res = PIL.Image.fromarray(res)
78
+ res = res.resize(orig_size, PIL.Image.Resampling.BICUBIC)
79
+ return res
80
+
81
+
82
+ def main():
83
+ gr.close_all()
84
+
85
+ args = parse_args()
86
+ device = torch.device(args.device)
87
+
88
+ model = load_model(device)
89
+
90
+ func = functools.partial(run, model=model, device=device)
91
+ func = functools.update_wrapper(func, run)
92
+
93
+ examples = [['Anime2Sketch/test_samples/madoka.jpg']]
94
+
95
+ gr.Interface(
96
+ func,
97
+ gr.inputs.Image(type='file', label='Input'),
98
+ gr.outputs.Image(type='pil', label='Output'),
99
+ examples=examples,
100
+ title=TITLE,
101
+ description=DESCRIPTION,
102
+ article=ARTICLE,
103
+ theme=args.theme,
104
+ allow_screenshot=args.allow_screenshot,
105
+ allow_flagging=args.allow_flagging,
106
+ live=args.live,
107
+ ).launch(
108
+ enable_queue=args.enable_queue,
109
+ server_port=args.port,
110
+ share=args.share,
111
+ )
112
+
113
+
114
+ if __name__ == '__main__':
115
+ main()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ numpy==1.22.3
2
+ Pillow==9.1.0
3
+ torch==1.11.0
4
+ torchvision==0.12.0