Spaces:
Runtime error
Runtime error
Add files
Browse files- .gitattributes +1 -0
- .pre-commit-config.yaml +46 -0
- .style.yapf +5 -0
- README.md +1 -0
- app.py +140 -0
- model.py +109 -0
- requirements.txt +4 -0
- samples/ddpm-128-exp000.png +3 -0
- style.css +11 -0
.gitattributes
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
|
|
1 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
2 |
*.7z filter=lfs diff=lfs merge=lfs -text
|
3 |
*.arrow filter=lfs diff=lfs merge=lfs -text
|
4 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
exclude: ^stylegan_xl
|
2 |
+
repos:
|
3 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
4 |
+
rev: v4.2.0
|
5 |
+
hooks:
|
6 |
+
- id: check-executables-have-shebangs
|
7 |
+
- id: check-json
|
8 |
+
- id: check-merge-conflict
|
9 |
+
- id: check-shebang-scripts-are-executable
|
10 |
+
- id: check-toml
|
11 |
+
- id: check-yaml
|
12 |
+
- id: double-quote-string-fixer
|
13 |
+
- id: end-of-file-fixer
|
14 |
+
- id: mixed-line-ending
|
15 |
+
args: ['--fix=lf']
|
16 |
+
- id: requirements-txt-fixer
|
17 |
+
- id: trailing-whitespace
|
18 |
+
- repo: https://github.com/myint/docformatter
|
19 |
+
rev: v1.4
|
20 |
+
hooks:
|
21 |
+
- id: docformatter
|
22 |
+
args: ['--in-place']
|
23 |
+
- repo: https://github.com/pycqa/isort
|
24 |
+
rev: 5.10.1
|
25 |
+
hooks:
|
26 |
+
- id: isort
|
27 |
+
- repo: https://github.com/pre-commit/mirrors-mypy
|
28 |
+
rev: v0.812
|
29 |
+
hooks:
|
30 |
+
- id: mypy
|
31 |
+
args: ['--ignore-missing-imports']
|
32 |
+
- repo: https://github.com/google/yapf
|
33 |
+
rev: v0.32.0
|
34 |
+
hooks:
|
35 |
+
- id: yapf
|
36 |
+
args: ['--parallel', '--in-place']
|
37 |
+
- repo: https://github.com/kynan/nbstripout
|
38 |
+
rev: 0.5.0
|
39 |
+
hooks:
|
40 |
+
- id: nbstripout
|
41 |
+
args: ['--extra-keys', 'metadata.interpreter metadata.kernelspec cell.metadata.pycharm']
|
42 |
+
- repo: https://github.com/nbQA-dev/nbQA
|
43 |
+
rev: 1.3.1
|
44 |
+
hooks:
|
45 |
+
- id: nbqa-isort
|
46 |
+
- id: nbqa-yapf
|
.style.yapf
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[style]
|
2 |
+
based_on_style = pep8
|
3 |
+
blank_line_before_nested_class_or_def = false
|
4 |
+
spaces_before_comment = 2
|
5 |
+
split_before_logical_operator = true
|
README.md
CHANGED
@@ -3,6 +3,7 @@ title: Diffusers Anime Faces
|
|
3 |
emoji: 🚀
|
4 |
colorFrom: pink
|
5 |
colorTo: purple
|
|
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.1.4
|
8 |
app_file: app.py
|
|
|
3 |
emoji: 🚀
|
4 |
colorFrom: pink
|
5 |
colorTo: purple
|
6 |
+
python_version: 3.9.13
|
7 |
sdk: gradio
|
8 |
sdk_version: 3.1.4
|
9 |
app_file: app.py
|
app.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
|
9 |
+
from model import Model
|
10 |
+
|
11 |
+
TITLE = '# Anime Face Generation with [Diffusers](https://github.com/huggingface/diffusers)'
|
12 |
+
DESCRIPTION = '''
|
13 |
+
|
14 |
+
Expected execution time on Hugging Face Spaces: 13s (DDIM, 50 steps), 6s (PNDM, 20 steps), 247s (DDPM, 1000 steps)
|
15 |
+
'''
|
16 |
+
FOOTER = '<img id="visitor-badge" src="https://visitor-badge.glitch.me/badge?page_id=hysts.diffusers-anime-faces" alt="visitor badge" />'
|
17 |
+
|
18 |
+
|
19 |
+
def get_sample_image_url(name: str) -> str:
|
20 |
+
sample_image_dir = 'https://huggingface.co/spaces/hysts/diffusers-anime-faces/resolve/main/samples'
|
21 |
+
return f'{sample_image_dir}/{name}.png'
|
22 |
+
|
23 |
+
|
24 |
+
def get_sample_image_markdown(name: str) -> str:
|
25 |
+
model_name = name.split()[0]
|
26 |
+
url = get_sample_image_url(model_name)
|
27 |
+
if name == 'ddpm-128-exp000 (DDPM)':
|
28 |
+
text = f'''
|
29 |
+
- size: 128x128
|
30 |
+
- seed: 0-99
|
31 |
+
- scheduler: DDPM
|
32 |
+
|
33 |
+
![sample images]({url})'''
|
34 |
+
else:
|
35 |
+
raise ValueError
|
36 |
+
return text
|
37 |
+
|
38 |
+
|
39 |
+
def update_scheduler_type(name: str) -> dict:
|
40 |
+
visible = name != 'DDPM'
|
41 |
+
if name == 'PNDM':
|
42 |
+
minimum = 4
|
43 |
+
maximum = 100
|
44 |
+
value = 20
|
45 |
+
else:
|
46 |
+
minimum = 1
|
47 |
+
maximum = 200
|
48 |
+
value = 50
|
49 |
+
return gr.Slider.update(visible=visible,
|
50 |
+
minimum=minimum,
|
51 |
+
maximum=maximum,
|
52 |
+
value=value)
|
53 |
+
|
54 |
+
|
55 |
+
def main():
|
56 |
+
parser = argparse.ArgumentParser()
|
57 |
+
parser.add_argument('--device', type=str, default='cpu')
|
58 |
+
args = parser.parse_args()
|
59 |
+
|
60 |
+
model = Model(args.device)
|
61 |
+
|
62 |
+
with gr.Blocks(css='style.css') as demo:
|
63 |
+
gr.Markdown(TITLE)
|
64 |
+
gr.Markdown(DESCRIPTION)
|
65 |
+
|
66 |
+
with gr.Tabs():
|
67 |
+
with gr.TabItem('App'):
|
68 |
+
with gr.Row():
|
69 |
+
with gr.Column():
|
70 |
+
with gr.Group():
|
71 |
+
model_name = gr.Dropdown(
|
72 |
+
model.MODEL_NAMES,
|
73 |
+
value=model.MODEL_NAMES[0],
|
74 |
+
label='Model',
|
75 |
+
interactive=False)
|
76 |
+
scheduler_type = gr.Radio(
|
77 |
+
choices=['DDPM', 'DDIM', 'PNDM'],
|
78 |
+
value='DDIM',
|
79 |
+
label='Scheduler')
|
80 |
+
num_steps = gr.Slider(1,
|
81 |
+
200,
|
82 |
+
step=1,
|
83 |
+
value=50,
|
84 |
+
label='Number of Steps')
|
85 |
+
seed = gr.Slider(0,
|
86 |
+
100000,
|
87 |
+
step=1,
|
88 |
+
value=1234,
|
89 |
+
label='Seed')
|
90 |
+
run_button = gr.Button('Run')
|
91 |
+
with gr.Column():
|
92 |
+
result = gr.Image(label='Result', elem_id='result')
|
93 |
+
|
94 |
+
with gr.TabItem('Sample Images'):
|
95 |
+
with gr.Row():
|
96 |
+
model_name2 = gr.Dropdown([
|
97 |
+
'ddpm-128-exp000 (DDPM)',
|
98 |
+
],
|
99 |
+
value='ddpm-128-exp000 (DDPM)',
|
100 |
+
label='Model',
|
101 |
+
interactive=False)
|
102 |
+
with gr.Row():
|
103 |
+
text = get_sample_image_markdown(model_name2.value)
|
104 |
+
sample_images = gr.Markdown(text)
|
105 |
+
|
106 |
+
gr.Markdown(FOOTER)
|
107 |
+
|
108 |
+
model_name.change(fn=model.set_pipeline,
|
109 |
+
inputs=[
|
110 |
+
model_name,
|
111 |
+
scheduler_type,
|
112 |
+
],
|
113 |
+
outputs=None)
|
114 |
+
scheduler_type.change(fn=update_scheduler_type,
|
115 |
+
inputs=scheduler_type,
|
116 |
+
outputs=num_steps,
|
117 |
+
queue=False)
|
118 |
+
scheduler_type.change(fn=model.set_pipeline,
|
119 |
+
inputs=[
|
120 |
+
model_name,
|
121 |
+
scheduler_type,
|
122 |
+
],
|
123 |
+
outputs=None)
|
124 |
+
run_button.click(fn=model.run,
|
125 |
+
inputs=[
|
126 |
+
model_name,
|
127 |
+
scheduler_type,
|
128 |
+
num_steps,
|
129 |
+
seed,
|
130 |
+
],
|
131 |
+
outputs=result)
|
132 |
+
model_name2.change(fn=get_sample_image_markdown,
|
133 |
+
inputs=model_name2,
|
134 |
+
outputs=sample_images)
|
135 |
+
|
136 |
+
demo.launch(enable_queue=True, share=False)
|
137 |
+
|
138 |
+
|
139 |
+
if __name__ == '__main__':
|
140 |
+
main()
|
model.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
|
7 |
+
import PIL.Image
|
8 |
+
import torch
|
9 |
+
from diffusers import (DDIMPipeline, DDIMScheduler, DDPMPipeline, PNDMPipeline,
|
10 |
+
PNDMScheduler)
|
11 |
+
|
12 |
+
HF_TOKEN = os.environ['HF_TOKEN']
|
13 |
+
|
14 |
+
formatter = logging.Formatter(
|
15 |
+
'[%(asctime)s] %(name)s %(levelname)s: %(message)s',
|
16 |
+
datefmt='%Y-%m-%d %H:%M:%S')
|
17 |
+
stream_handler = logging.StreamHandler(stream=sys.stdout)
|
18 |
+
stream_handler.setLevel(logging.INFO)
|
19 |
+
stream_handler.setFormatter(formatter)
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
+
logger.setLevel(logging.INFO)
|
22 |
+
logger.propagate = False
|
23 |
+
logger.addHandler(stream_handler)
|
24 |
+
|
25 |
+
|
26 |
+
class Model:
|
27 |
+
|
28 |
+
MODEL_NAMES = [
|
29 |
+
'ddpm-128-exp000',
|
30 |
+
]
|
31 |
+
|
32 |
+
def __init__(self, device: str | torch.device):
|
33 |
+
self.device = torch.device(device)
|
34 |
+
self._download_all_models()
|
35 |
+
|
36 |
+
self.model_name = self.MODEL_NAMES[0]
|
37 |
+
self.scheduler_type = 'DDIM'
|
38 |
+
self.pipeline = self._load_pipeline(self.model_name,
|
39 |
+
self.scheduler_type)
|
40 |
+
|
41 |
+
def _load_pipeline(self, model_name: str,
|
42 |
+
scheduler_type: str) -> DDIMPipeline | DDPMPipeline:
|
43 |
+
repo_id = f'hysts/diffusers-anime-faces-{model_name}'
|
44 |
+
if scheduler_type == 'DDPM':
|
45 |
+
pipeline = DDPMPipeline.from_pretrained(repo_id,
|
46 |
+
use_auth_token=HF_TOKEN)
|
47 |
+
elif scheduler_type == 'DDIM':
|
48 |
+
pipeline = DDIMPipeline.from_pretrained(repo_id,
|
49 |
+
use_auth_token=HF_TOKEN)
|
50 |
+
config, _ = DDIMScheduler.extract_init_dict(
|
51 |
+
dict(pipeline.scheduler.config))
|
52 |
+
pipeline.scheduler = DDIMScheduler(**config)
|
53 |
+
elif scheduler_type == 'PNDM':
|
54 |
+
pipeline = PNDMPipeline.from_pretrained(repo_id,
|
55 |
+
use_auth_token=HF_TOKEN)
|
56 |
+
config, _ = PNDMScheduler.extract_init_dict(
|
57 |
+
dict(pipeline.scheduler.config))
|
58 |
+
pipeline.scheduler = PNDMScheduler(**config)
|
59 |
+
else:
|
60 |
+
raise ValueError
|
61 |
+
return pipeline
|
62 |
+
|
63 |
+
def set_pipeline(self, model_name: str, scheduler_type: str) -> None:
|
64 |
+
logger.info('--- set_pipeline ---')
|
65 |
+
logger.info(f'{model_name=}, {scheduler_type=}')
|
66 |
+
|
67 |
+
if model_name == self.model_name and scheduler_type == self.scheduler_type:
|
68 |
+
logger.info('Skipping')
|
69 |
+
logger.info('--- done ---')
|
70 |
+
return
|
71 |
+
self.model_name = model_name
|
72 |
+
self.scheduler_type = scheduler_type
|
73 |
+
self.pipeline = self._load_pipeline(model_name, scheduler_type)
|
74 |
+
|
75 |
+
logger.info('--- done ---')
|
76 |
+
|
77 |
+
def _download_all_models(self):
|
78 |
+
for name in self.MODEL_NAMES:
|
79 |
+
self._load_pipeline(name, 'DDPM')
|
80 |
+
|
81 |
+
def generate(self, seed: int, num_steps: int) -> PIL.Image.Image:
|
82 |
+
logger.info('--- generate ---')
|
83 |
+
logger.info(f'{seed=}, {num_steps=}')
|
84 |
+
|
85 |
+
torch.manual_seed(seed)
|
86 |
+
if self.scheduler_type == 'DDPM':
|
87 |
+
res = self.pipeline(batch_size=1,
|
88 |
+
torch_device=self.device)['sample'][0]
|
89 |
+
elif self.scheduler_type in ['DDIM', 'PNDM']:
|
90 |
+
res = self.pipeline(batch_size=1,
|
91 |
+
torch_device=self.device,
|
92 |
+
num_inference_steps=num_steps)['sample'][0]
|
93 |
+
else:
|
94 |
+
raise ValueError
|
95 |
+
|
96 |
+
logger.info('--- done ---')
|
97 |
+
return res
|
98 |
+
|
99 |
+
def run(
|
100 |
+
self,
|
101 |
+
model_name: str,
|
102 |
+
scheduler_type: str,
|
103 |
+
num_steps: int,
|
104 |
+
seed: int,
|
105 |
+
) -> PIL.Image.Image:
|
106 |
+
self.set_pipeline(model_name, scheduler_type)
|
107 |
+
if scheduler_type == 'PNDM':
|
108 |
+
num_steps = max(4, min(num_steps, 100))
|
109 |
+
return self.generate(seed, num_steps)
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
diffusers==0.1.3
|
2 |
+
diffusers==0.1.3
|
3 |
+
torch==1.12.1
|
4 |
+
torchvision==0.13.1
|
samples/ddpm-128-exp000.png
ADDED
Git LFS Details
|
style.css
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
h1 {
|
2 |
+
text-align: center;
|
3 |
+
}
|
4 |
+
div#result {
|
5 |
+
max-width: 400px;
|
6 |
+
max-height: 400px;
|
7 |
+
}
|
8 |
+
img#visitor-badge {
|
9 |
+
display: block;
|
10 |
+
margin: auto;
|
11 |
+
}
|