hysts HF staff commited on
Commit
91b1804
1 Parent(s): 3603c98
Files changed (4) hide show
  1. .pre-commit-config.yaml +35 -0
  2. .style.yapf +5 -0
  3. README.md +1 -1
  4. app.py +17 -49
.pre-commit-config.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.2.0
4
+ hooks:
5
+ - id: check-executables-have-shebangs
6
+ - id: check-json
7
+ - id: check-merge-conflict
8
+ - id: check-shebang-scripts-are-executable
9
+ - id: check-toml
10
+ - id: check-yaml
11
+ - id: double-quote-string-fixer
12
+ - id: end-of-file-fixer
13
+ - id: mixed-line-ending
14
+ args: ['--fix=lf']
15
+ - id: requirements-txt-fixer
16
+ - id: trailing-whitespace
17
+ - repo: https://github.com/myint/docformatter
18
+ rev: v1.4
19
+ hooks:
20
+ - id: docformatter
21
+ args: ['--in-place']
22
+ - repo: https://github.com/pycqa/isort
23
+ rev: 5.12.0
24
+ hooks:
25
+ - id: isort
26
+ - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v0.991
28
+ hooks:
29
+ - id: mypy
30
+ args: ['--ignore-missing-imports']
31
+ - repo: https://github.com/google/yapf
32
+ rev: v0.32.0
33
+ hooks:
34
+ - id: yapf
35
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 😻
4
  colorFrom: green
5
  colorTo: gray
6
  sdk: gradio
7
- sdk_version: 3.0.5
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: green
5
  colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 3.19.1
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -2,7 +2,6 @@
2
 
3
  from __future__ import annotations
4
 
5
- import argparse
6
  import functools
7
  import os
8
  import sys
@@ -18,25 +17,10 @@ sys.path.insert(0, 'Anime2Sketch')
18
  from data import read_img_path, tensor_to_img
19
  from model import UnetGenerator
20
 
21
- TITLE = 'Mukosame/Anime2Sketch'
22
  DESCRIPTION = 'This is an unofficial demo for https://github.com/Mukosame/Anime2Sketch.'
23
- ARTICLE = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.anime2sketch" alt="visitor badge"/></center>'
24
 
25
- TOKEN = os.environ['TOKEN']
26
-
27
-
28
- def parse_args() -> argparse.Namespace:
29
- parser = argparse.ArgumentParser()
30
- parser.add_argument('--device', type=str, default='cpu')
31
- parser.add_argument('--theme', type=str)
32
- parser.add_argument('--live', action='store_true')
33
- parser.add_argument('--share', action='store_true')
34
- parser.add_argument('--port', type=int)
35
- parser.add_argument('--disable-queue',
36
- dest='enable_queue',
37
- action='store_false')
38
- parser.add_argument('--allow-flagging', type=str, default='never')
39
- return parser.parse_args()
40
 
41
 
42
  def load_model(device: torch.device) -> nn.Module:
@@ -52,7 +36,7 @@ def load_model(device: torch.device) -> nn.Module:
52
 
53
  path = huggingface_hub.hf_hub_download('hysts/Anime2Sketch',
54
  'netG.pth',
55
- use_auth_token=TOKEN)
56
  ckpt = torch.load(path)
57
  for key in list(ckpt.keys()):
58
  if 'module.' in key:
@@ -65,11 +49,11 @@ def load_model(device: torch.device) -> nn.Module:
65
 
66
 
67
  @torch.inference_mode()
68
- def run(image_file,
69
  model: nn.Module,
70
  device: torch.device,
71
  load_size: int = 512) -> PIL.Image.Image:
72
- tensor, orig_size = read_img_path(image_file.name, load_size)
73
  tensor = tensor.to(device)
74
  out = model(tensor)
75
  res = tensor_to_img(out)
@@ -78,34 +62,18 @@ def run(image_file,
78
  return res
79
 
80
 
81
- def main():
82
- args = parse_args()
83
- device = torch.device(args.device)
84
-
85
- model = load_model(device)
86
-
87
- func = functools.partial(run, model=model, device=device)
88
- func = functools.update_wrapper(func, run)
89
-
90
- examples = [['Anime2Sketch/test_samples/madoka.jpg']]
91
 
92
- gr.Interface(
93
- func,
94
- gr.inputs.Image(type='file', label='Input'),
95
- gr.outputs.Image(type='pil', label='Output'),
96
- examples=examples,
97
- title=TITLE,
98
- description=DESCRIPTION,
99
- article=ARTICLE,
100
- theme=args.theme,
101
- allow_flagging=args.allow_flagging,
102
- live=args.live,
103
- ).launch(
104
- enable_queue=args.enable_queue,
105
- server_port=args.port,
106
- share=args.share,
107
- )
108
 
 
109
 
110
- if __name__ == '__main__':
111
- main()
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
5
  import functools
6
  import os
7
  import sys
 
17
  from data import read_img_path, tensor_to_img
18
  from model import UnetGenerator
19
 
20
+ TITLE = 'Anime2Sketch'
21
  DESCRIPTION = 'This is an unofficial demo for https://github.com/Mukosame/Anime2Sketch.'
 
22
 
23
+ HF_TOKEN = os.getenv('HF_TOKEN')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
 
26
  def load_model(device: torch.device) -> nn.Module:
 
36
 
37
  path = huggingface_hub.hf_hub_download('hysts/Anime2Sketch',
38
  'netG.pth',
39
+ use_auth_token=HF_TOKEN)
40
  ckpt = torch.load(path)
41
  for key in list(ckpt.keys()):
42
  if 'module.' in key:
 
49
 
50
 
51
  @torch.inference_mode()
52
+ def run(image_file: str,
53
  model: nn.Module,
54
  device: torch.device,
55
  load_size: int = 512) -> PIL.Image.Image:
56
+ tensor, orig_size = read_img_path(image_file, load_size)
57
  tensor = tensor.to(device)
58
  out = model(tensor)
59
  res = tensor_to_img(out)
 
62
  return res
63
 
64
 
65
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
66
+ model = load_model(device)
 
 
 
 
 
 
 
 
67
 
68
+ func = functools.partial(run, model=model, device=device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
+ examples = [['Anime2Sketch/test_samples/madoka.jpg']]
71
 
72
+ gr.Interface(
73
+ fn=func,
74
+ inputs=gr.Image(label='Input', type='filepath'),
75
+ outputs=gr.Image(label='Output', type='pil'),
76
+ examples=examples,
77
+ title=TITLE,
78
+ description=DESCRIPTION,
79
+ ).queue().launch(show_api=False)