hysts HF staff commited on
Commit
4fb3c5e
1 Parent(s): d64d4af
Files changed (9) hide show
  1. .gitattributes +1 -0
  2. .pre-commit-config.yaml +46 -0
  3. .style.yapf +5 -0
  4. README.md +1 -0
  5. app.py +140 -0
  6. model.py +109 -0
  7. requirements.txt +4 -0
  8. samples/ddpm-128-exp000.png +3 -0
  9. style.css +11 -0
.gitattributes CHANGED
@@ -1,3 +1,4 @@
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
1
+ *.png filter=lfs diff=lfs merge=lfs -text
2
  *.7z filter=lfs diff=lfs merge=lfs -text
3
  *.arrow filter=lfs diff=lfs merge=lfs -text
4
  *.bin filter=lfs diff=lfs merge=lfs -text
.pre-commit-config.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: ^stylegan_xl
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.2.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: double-quote-string-fixer
13
+ - id: end-of-file-fixer
14
+ - id: mixed-line-ending
15
+ args: ['--fix=lf']
16
+ - id: requirements-txt-fixer
17
+ - id: trailing-whitespace
18
+ - repo: https://github.com/myint/docformatter
19
+ rev: v1.4
20
+ hooks:
21
+ - id: docformatter
22
+ args: ['--in-place']
23
+ - repo: https://github.com/pycqa/isort
24
+ rev: 5.10.1
25
+ hooks:
26
+ - id: isort
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.812
29
+ hooks:
30
+ - id: mypy
31
+ args: ['--ignore-missing-imports']
32
+ - repo: https://github.com/google/yapf
33
+ rev: v0.32.0
34
+ hooks:
35
+ - id: yapf
36
+ args: ['--parallel', '--in-place']
37
+ - repo: https://github.com/kynan/nbstripout
38
+ rev: 0.5.0
39
+ hooks:
40
+ - id: nbstripout
41
+ args: ['--extra-keys', 'metadata.interpreter metadata.kernelspec cell.metadata.pycharm']
42
+ - repo: https://github.com/nbQA-dev/nbQA
43
+ rev: 1.3.1
44
+ hooks:
45
+ - id: nbqa-isort
46
+ - id: nbqa-yapf
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
README.md CHANGED
@@ -3,6 +3,7 @@ title: Diffusers Anime Faces
3
  emoji: 🚀
4
  colorFrom: pink
5
  colorTo: purple
 
6
  sdk: gradio
7
  sdk_version: 3.1.4
8
  app_file: app.py
 
3
  emoji: 🚀
4
  colorFrom: pink
5
  colorTo: purple
6
+ python_version: 3.9.13
7
  sdk: gradio
8
  sdk_version: 3.1.4
9
  app_file: app.py
app.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+
7
+ import gradio as gr
8
+
9
+ from model import Model
10
+
11
+ TITLE = '# Anime Face Generation with [Diffusers](https://github.com/huggingface/diffusers)'
12
+ DESCRIPTION = '''
13
+
14
+ Expected execution time on Hugging Face Spaces: 13s (DDIM, 50 steps), 6s (PNDM, 20 steps), 247s (DDPM, 1000 steps)
15
+ '''
16
+ FOOTER = '<img id="visitor-badge" src="https://visitor-badge.glitch.me/badge?page_id=hysts.diffusers-anime-faces" alt="visitor badge" />'
17
+
18
+
19
+ def get_sample_image_url(name: str) -> str:
20
+ sample_image_dir = 'https://huggingface.co/spaces/hysts/diffusers-anime-faces/resolve/main/samples'
21
+ return f'{sample_image_dir}/{name}.png'
22
+
23
+
24
+ def get_sample_image_markdown(name: str) -> str:
25
+ model_name = name.split()[0]
26
+ url = get_sample_image_url(model_name)
27
+ if name == 'ddpm-128-exp000 (DDPM)':
28
+ text = f'''
29
+ - size: 128x128
30
+ - seed: 0-99
31
+ - scheduler: DDPM
32
+
33
+ ![sample images]({url})'''
34
+ else:
35
+ raise ValueError
36
+ return text
37
+
38
+
39
+ def update_scheduler_type(name: str) -> dict:
40
+ visible = name != 'DDPM'
41
+ if name == 'PNDM':
42
+ minimum = 4
43
+ maximum = 100
44
+ value = 20
45
+ else:
46
+ minimum = 1
47
+ maximum = 200
48
+ value = 50
49
+ return gr.Slider.update(visible=visible,
50
+ minimum=minimum,
51
+ maximum=maximum,
52
+ value=value)
53
+
54
+
55
+ def main():
56
+ parser = argparse.ArgumentParser()
57
+ parser.add_argument('--device', type=str, default='cpu')
58
+ args = parser.parse_args()
59
+
60
+ model = Model(args.device)
61
+
62
+ with gr.Blocks(css='style.css') as demo:
63
+ gr.Markdown(TITLE)
64
+ gr.Markdown(DESCRIPTION)
65
+
66
+ with gr.Tabs():
67
+ with gr.TabItem('App'):
68
+ with gr.Row():
69
+ with gr.Column():
70
+ with gr.Group():
71
+ model_name = gr.Dropdown(
72
+ model.MODEL_NAMES,
73
+ value=model.MODEL_NAMES[0],
74
+ label='Model',
75
+ interactive=False)
76
+ scheduler_type = gr.Radio(
77
+ choices=['DDPM', 'DDIM', 'PNDM'],
78
+ value='DDIM',
79
+ label='Scheduler')
80
+ num_steps = gr.Slider(1,
81
+ 200,
82
+ step=1,
83
+ value=50,
84
+ label='Number of Steps')
85
+ seed = gr.Slider(0,
86
+ 100000,
87
+ step=1,
88
+ value=1234,
89
+ label='Seed')
90
+ run_button = gr.Button('Run')
91
+ with gr.Column():
92
+ result = gr.Image(label='Result', elem_id='result')
93
+
94
+ with gr.TabItem('Sample Images'):
95
+ with gr.Row():
96
+ model_name2 = gr.Dropdown([
97
+ 'ddpm-128-exp000 (DDPM)',
98
+ ],
99
+ value='ddpm-128-exp000 (DDPM)',
100
+ label='Model',
101
+ interactive=False)
102
+ with gr.Row():
103
+ text = get_sample_image_markdown(model_name2.value)
104
+ sample_images = gr.Markdown(text)
105
+
106
+ gr.Markdown(FOOTER)
107
+
108
+ model_name.change(fn=model.set_pipeline,
109
+ inputs=[
110
+ model_name,
111
+ scheduler_type,
112
+ ],
113
+ outputs=None)
114
+ scheduler_type.change(fn=update_scheduler_type,
115
+ inputs=scheduler_type,
116
+ outputs=num_steps,
117
+ queue=False)
118
+ scheduler_type.change(fn=model.set_pipeline,
119
+ inputs=[
120
+ model_name,
121
+ scheduler_type,
122
+ ],
123
+ outputs=None)
124
+ run_button.click(fn=model.run,
125
+ inputs=[
126
+ model_name,
127
+ scheduler_type,
128
+ num_steps,
129
+ seed,
130
+ ],
131
+ outputs=result)
132
+ model_name2.change(fn=get_sample_image_markdown,
133
+ inputs=model_name2,
134
+ outputs=sample_images)
135
+
136
+ demo.launch(enable_queue=True, share=False)
137
+
138
+
139
+ if __name__ == '__main__':
140
+ main()
model.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import os
5
+ import sys
6
+
7
+ import PIL.Image
8
+ import torch
9
+ from diffusers import (DDIMPipeline, DDIMScheduler, DDPMPipeline, PNDMPipeline,
10
+ PNDMScheduler)
11
+
12
+ HF_TOKEN = os.environ['HF_TOKEN']
13
+
14
+ formatter = logging.Formatter(
15
+ '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
16
+ datefmt='%Y-%m-%d %H:%M:%S')
17
+ stream_handler = logging.StreamHandler(stream=sys.stdout)
18
+ stream_handler.setLevel(logging.INFO)
19
+ stream_handler.setFormatter(formatter)
20
+ logger = logging.getLogger(__name__)
21
+ logger.setLevel(logging.INFO)
22
+ logger.propagate = False
23
+ logger.addHandler(stream_handler)
24
+
25
+
26
+ class Model:
27
+
28
+ MODEL_NAMES = [
29
+ 'ddpm-128-exp000',
30
+ ]
31
+
32
+ def __init__(self, device: str | torch.device):
33
+ self.device = torch.device(device)
34
+ self._download_all_models()
35
+
36
+ self.model_name = self.MODEL_NAMES[0]
37
+ self.scheduler_type = 'DDIM'
38
+ self.pipeline = self._load_pipeline(self.model_name,
39
+ self.scheduler_type)
40
+
41
+ def _load_pipeline(self, model_name: str,
42
+ scheduler_type: str) -> DDIMPipeline | DDPMPipeline:
43
+ repo_id = f'hysts/diffusers-anime-faces-{model_name}'
44
+ if scheduler_type == 'DDPM':
45
+ pipeline = DDPMPipeline.from_pretrained(repo_id,
46
+ use_auth_token=HF_TOKEN)
47
+ elif scheduler_type == 'DDIM':
48
+ pipeline = DDIMPipeline.from_pretrained(repo_id,
49
+ use_auth_token=HF_TOKEN)
50
+ config, _ = DDIMScheduler.extract_init_dict(
51
+ dict(pipeline.scheduler.config))
52
+ pipeline.scheduler = DDIMScheduler(**config)
53
+ elif scheduler_type == 'PNDM':
54
+ pipeline = PNDMPipeline.from_pretrained(repo_id,
55
+ use_auth_token=HF_TOKEN)
56
+ config, _ = PNDMScheduler.extract_init_dict(
57
+ dict(pipeline.scheduler.config))
58
+ pipeline.scheduler = PNDMScheduler(**config)
59
+ else:
60
+ raise ValueError
61
+ return pipeline
62
+
63
+ def set_pipeline(self, model_name: str, scheduler_type: str) -> None:
64
+ logger.info('--- set_pipeline ---')
65
+ logger.info(f'{model_name=}, {scheduler_type=}')
66
+
67
+ if model_name == self.model_name and scheduler_type == self.scheduler_type:
68
+ logger.info('Skipping')
69
+ logger.info('--- done ---')
70
+ return
71
+ self.model_name = model_name
72
+ self.scheduler_type = scheduler_type
73
+ self.pipeline = self._load_pipeline(model_name, scheduler_type)
74
+
75
+ logger.info('--- done ---')
76
+
77
+ def _download_all_models(self):
78
+ for name in self.MODEL_NAMES:
79
+ self._load_pipeline(name, 'DDPM')
80
+
81
+ def generate(self, seed: int, num_steps: int) -> PIL.Image.Image:
82
+ logger.info('--- generate ---')
83
+ logger.info(f'{seed=}, {num_steps=}')
84
+
85
+ torch.manual_seed(seed)
86
+ if self.scheduler_type == 'DDPM':
87
+ res = self.pipeline(batch_size=1,
88
+ torch_device=self.device)['sample'][0]
89
+ elif self.scheduler_type in ['DDIM', 'PNDM']:
90
+ res = self.pipeline(batch_size=1,
91
+ torch_device=self.device,
92
+ num_inference_steps=num_steps)['sample'][0]
93
+ else:
94
+ raise ValueError
95
+
96
+ logger.info('--- done ---')
97
+ return res
98
+
99
+ def run(
100
+ self,
101
+ model_name: str,
102
+ scheduler_type: str,
103
+ num_steps: int,
104
+ seed: int,
105
+ ) -> PIL.Image.Image:
106
+ self.set_pipeline(model_name, scheduler_type)
107
+ if scheduler_type == 'PNDM':
108
+ num_steps = max(4, min(num_steps, 100))
109
+ return self.generate(seed, num_steps)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ diffusers==0.1.3
2
+ diffusers==0.1.3
3
+ torch==1.12.1
4
+ torchvision==0.13.1
samples/ddpm-128-exp000.png ADDED

Git LFS Details

  • SHA256: 83050d9266b9ccb442a227cba68db7c174a0179023a1b2797c592df62515a58e
  • Pointer size: 132 Bytes
  • Size of remote file: 3.05 MB
style.css ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
4
+ div#result {
5
+ max-width: 400px;
6
+ max-height: 400px;
7
+ }
8
+ img#visitor-badge {
9
+ display: block;
10
+ margin: auto;
11
+ }