File size: 11,920 Bytes
43b7e92 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 |
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# ์ค์ผ์ค๋ฌ
diffusion ํ์ดํ๋ผ์ธ์ diffusion ๋ชจ๋ธ, ์ค์ผ์ค๋ฌ ๋ฑ์ ์ปดํฌ๋ํธ๋ค๋ก ๊ตฌ์ฑ๋ฉ๋๋ค. ๊ทธ๋ฆฌ๊ณ ํ์ดํ๋ผ์ธ ์์ ์ผ๋ถ ์ปดํฌ๋ํธ๋ฅผ ๋ค๋ฅธ ์ปดํฌ๋ํธ๋ก ๊ต์ฒดํ๋ ์์ ์ปค์คํฐ๋ง์ด์ง ์ญ์ ๊ฐ๋ฅํฉ๋๋ค. ์ด์ ๊ฐ์ ์ปดํฌ๋ํธ ์ปค์คํฐ๋ง์ด์ง์ ๊ฐ์ฅ ๋ํ์ ์ธ ์์๊ฐ ๋ฐ๋ก [์ค์ผ์ค๋ฌ](../api/schedulers/overview.md)๋ฅผ ๊ต์ฒดํ๋ ๊ฒ์
๋๋ค.
์ค์ผ์ฅด๋ฌ๋ ๋ค์๊ณผ ๊ฐ์ด diffusion ์์คํ
์ ์ ๋ฐ์ ์ธ ๋๋
ธ์ด์ง ํ๋ก์ธ์ค๋ฅผ ์ ์ํฉ๋๋ค.
- ๋๋
ธ์ด์ง ์คํ
์ ์ผ๋ง๋ ๊ฐ์ ธ๊ฐ์ผ ํ ๊น?
- ํ๋ฅ ์ ์ผ๋ก(stochastic) ํน์ ํ์ ์ ์ผ๋ก(deterministic)?
- ๋๋
ธ์ด์ง ๋ ์ํ์ ์ฐพ์๋ด๊ธฐ ์ํด ์ด๋ค ์๊ณ ๋ฆฌ์ฆ์ ์ฌ์ฉํด์ผ ํ ๊น?
์ด๋ฌํ ํ๋ก์ธ์ค๋ ๋ค์ ๋ํดํ๊ณ , ๋๋
ธ์ด์ง ์๋์ ๋๋
ธ์ด์ง ํ๋ฆฌํฐ ์ฌ์ด์ ํธ๋ ์ด๋ ์คํ๋ฅผ ์ ์ํด์ผ ํ๋ ๋ฌธ์ ๊ฐ ๋ ์ ์์ต๋๋ค. ์ฃผ์ด์ง ํ์ดํ๋ผ์ธ์ ์ด๋ค ์ค์ผ์ค๋ฌ๊ฐ ๊ฐ์ฅ ์ ํฉํ์ง๋ฅผ ์ ๋์ ์ผ๋ก ํ๋จํ๋ ๊ฒ์ ๋งค์ฐ ์ด๋ ค์ด ์ผ์
๋๋ค. ์ด๋ก ์ธํด ์ผ๋จ ํด๋น ์ค์ผ์ค๋ฌ๋ฅผ ์ง์ ์ฌ์ฉํ์ฌ, ์์ฑ๋๋ ์ด๋ฏธ์ง๋ฅผ ์ง์ ๋์ผ๋ก ๋ณด๋ฉฐ, ์ ์ฑ์ ์ผ๋ก ์ฑ๋ฅ์ ํ๋จํด๋ณด๋ ๊ฒ์ด ์ถ์ฒ๋๊ณค ํฉ๋๋ค.
## ํ์ดํ๋ผ์ธ ๋ถ๋ฌ์ค๊ธฐ
๋จผ์ ์คํ
์ด๋ธ diffusion ํ์ดํ๋ผ์ธ์ ๋ถ๋ฌ์ค๋๋ก ํด๋ณด๊ฒ ์ต๋๋ค. ๋ฌผ๋ก ์คํ
์ด๋ธ diffusion์ ์ฌ์ฉํ๊ธฐ ์ํด์๋, ํ๊น
ํ์ด์ค ํ๋ธ์ ๋ฑ๋ก๋ ์ฌ์ฉ์์ฌ์ผ ํ๋ฉฐ, ๊ด๋ จ [๋ผ์ด์ผ์ค](https://huggingface.co/runwayml/stable-diffusion-v1-5)์ ๋์ํด์ผ ํ๋ค๋ ์ ์ ์์ง ๋ง์์ฃผ์ธ์.
*์ญ์ ์ฃผ: ๋ค๋ง, ํ์ฌ ์ ๊ท๋ก ์์ฑํ ํ๊น
ํ์ด์ค ๊ณ์ ์ ๋ํด์๋ ๋ผ์ด์ผ์ค ๋์๋ฅผ ์๊ตฌํ์ง ์๋ ๊ฒ์ผ๋ก ๋ณด์
๋๋ค!*
```python
from huggingface_hub import login
from diffusers import DiffusionPipeline
import torch
# first we need to login with our access token
login()
# Now we can download the pipeline
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
```
๋ค์์ผ๋ก, GPU๋ก ์ด๋ํฉ๋๋ค.
```python
pipeline.to("cuda")
```
## ์ค์ผ์ค๋ฌ ์ก์ธ์ค
์ค์ผ์ค๋ฌ๋ ์ธ์ ๋ ํ์ดํ๋ผ์ธ์ ์ปดํฌ๋ํธ๋ก์ ์กด์ฌํ๋ฉฐ, ์ผ๋ฐ์ ์ผ๋ก ํ์ดํ๋ผ์ธ ์ธ์คํด์ค ๋ด์ `scheduler`๋ผ๋ ์ด๋ฆ์ ์์ฑ(property)์ผ๋ก ์ ์๋์ด ์์ต๋๋ค.
```python
pipeline.scheduler
```
**Output**:
```
PNDMScheduler {
"_class_name": "PNDMScheduler",
"_diffusers_version": "0.8.0.dev0",
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"clip_sample": false,
"num_train_timesteps": 1000,
"set_alpha_to_one": false,
"skip_prk_steps": true,
"steps_offset": 1,
"trained_betas": null
}
```
์ถ๋ ฅ ๊ฒฐ๊ณผ๋ฅผ ํตํด, ์ฐ๋ฆฌ๋ ํด๋น ์ค์ผ์ค๋ฌ๊ฐ [`PNDMScheduler`]์ ์ธ์คํด์ค๋ผ๋ ๊ฒ์ ์ ์ ์์ต๋๋ค. ์ด์ [`PNDMScheduler`]์ ๋ค๋ฅธ ์ค์ผ์ค๋ฌ๋ค์ ์ฑ๋ฅ์ ๋น๊ตํด๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค. ๋จผ์ ํ
์คํธ์ ์ฌ์ฉํ ํ๋กฌํํธ๋ฅผ ๋ค์๊ณผ ๊ฐ์ด ์ ์ํด๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค.
```python
prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition."
```
๋ค์์ผ๋ก ์ ์ฌํ ์ด๋ฏธ์ง ์์ฑ์ ๋ณด์ฅํ๊ธฐ ์ํด์, ๋ค์๊ณผ ๊ฐ์ด ๋๋ค์๋๋ฅผ ๊ณ ์ ํด์ฃผ๋๋ก ํ๊ฒ ์ต๋๋ค.
```python
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator).images[0]
image
```
<p align="center">
<br>
<img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_pndm.png" width="400"/>
<br>
</p>
## ์ค์ผ์ค๋ฌ ๊ต์ฒดํ๊ธฐ
๋ค์์ผ๋ก ํ์ดํ๋ผ์ธ์ ์ค์ผ์ค๋ฌ๋ฅผ ๋ค๋ฅธ ์ค์ผ์ค๋ฌ๋ก ๊ต์ฒดํ๋ ๋ฐฉ๋ฒ์ ๋ํด ์์๋ณด๊ฒ ์ต๋๋ค. ๋ชจ๋ ์ค์ผ์ค๋ฌ๋ [`SchedulerMixin.compatibles`]๋ผ๋ ์์ฑ(property)์ ๊ฐ๊ณ ์์ต๋๋ค. ํด๋น ์์ฑ์ **ํธํ ๊ฐ๋ฅํ** ์ค์ผ์ค๋ฌ๋ค์ ๋ํ ์ ๋ณด๋ฅผ ๋ด๊ณ ์์ต๋๋ค.
```python
pipeline.scheduler.compatibles
```
**Output**:
```
[diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,
diffusers.schedulers.scheduling_ddim.DDIMScheduler,
diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,
diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
diffusers.schedulers.scheduling_pndm.PNDMScheduler,
diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler]
```
ํธํ๋๋ ์ค์ผ์ค๋ฌ๋ค์ ์ดํด๋ณด๋ฉด ์๋์ ๊ฐ์ต๋๋ค.
- [`LMSDiscreteScheduler`],
- [`DDIMScheduler`],
- [`DPMSolverMultistepScheduler`],
- [`EulerDiscreteScheduler`],
- [`PNDMScheduler`],
- [`DDPMScheduler`],
- [`EulerAncestralDiscreteScheduler`].
์์ ์ ์ํ๋ ํ๋กฌํํธ๋ฅผ ์ฌ์ฉํด์ ๊ฐ๊ฐ์ ์ค์ผ์ค๋ฌ๋ค์ ๋น๊ตํด๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค.
๋จผ์ ํ์ดํ๋ผ์ธ ์์ ์ค์ผ์ค๋ฌ๋ฅผ ๋ฐ๊พธ๊ธฐ ์ํด [`ConfigMixin.config`] ์์ฑ๊ณผ [`ConfigMixin.from_config`] ๋ฉ์๋๋ฅผ ํ์ฉํด๋ณด๋ ค๊ณ ํฉ๋๋ค.
```python
pipeline.scheduler.config
```
**Output**:
```
FrozenDict([('num_train_timesteps', 1000),
('beta_start', 0.00085),
('beta_end', 0.012),
('beta_schedule', 'scaled_linear'),
('trained_betas', None),
('skip_prk_steps', True),
('set_alpha_to_one', False),
('steps_offset', 1),
('_class_name', 'PNDMScheduler'),
('_diffusers_version', '0.8.0.dev0'),
('clip_sample', False)])
```
๊ธฐ์กด ์ค์ผ์ค๋ฌ์ config๋ฅผ ํธํ ๊ฐ๋ฅํ ๋ค๋ฅธ ์ค์ผ์ค๋ฌ์ ์ด์ํ๋ ๊ฒ ์ญ์ ๊ฐ๋ฅํฉ๋๋ค.
๋ค์ ์์๋ ๊ธฐ์กด ์ค์ผ์ค๋ฌ(`pipeline.scheduler`)๋ฅผ ๋ค๋ฅธ ์ข
๋ฅ์ ์ค์ผ์ค๋ฌ(`DDIMScheduler`)๋ก ๋ฐ๊พธ๋ ์ฝ๋์
๋๋ค. ๊ธฐ์กด ์ค์ผ์ค๋ฌ๊ฐ ๊ฐ๊ณ ์๋ config๋ฅผ `.from_config` ๋ฉ์๋์ ์ธ์๋ก ์ ๋ฌํ๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
```python
from diffusers import DDIMScheduler
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
```
์ด์ ํ์ดํ๋ผ์ธ์ ์คํํด์ ๋ ์ค์ผ์ค๋ฌ ์ฌ์ด์ ์์ฑ๋ ์ด๋ฏธ์ง์ ํ๋ฆฌํฐ๋ฅผ ๋น๊ตํด๋ด
์๋ค.
```python
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator).images[0]
image
```
<p align="center">
<br>
<img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_ddim.png" width="400"/>
<br>
</p>
## ์ค์ผ์ค๋ฌ๋ค ๋น๊ตํด๋ณด๊ธฐ
์ง๊ธ๊น์ง๋ [`PNDMScheduler`]์ [`DDIMScheduler`] ์ค์ผ์ค๋ฌ๋ฅผ ์คํํด๋ณด์์ต๋๋ค. ์์ง ๋น๊ตํด๋ณผ ์ค์ผ์ค๋ฌ๋ค์ด ๋ ๋ง์ด ๋จ์์์ผ๋ ๊ณ์ ๋น๊ตํด๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค.
[`LMSDiscreteScheduler`]์ ์ผ๋ฐ์ ์ผ๋ก ๋ ์ข์ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฌ์ค๋๋ค.
```python
from diffusers import LMSDiscreteScheduler
pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator).images[0]
image
```
<p align="center">
<br>
<img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_lms.png" width="400"/>
<br>
</p>
[`EulerDiscreteScheduler`]์ [`EulerAncestralDiscreteScheduler`] ๊ณ ์ 30๋ฒ์ inference step๋ง์ผ๋ก๋ ๋์ ํ๋ฆฌํฐ์ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ ๊ฒ์ ์ ์ ์์ต๋๋ค.
```python
from diffusers import EulerDiscreteScheduler
pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0]
image
```
<p align="center">
<br>
<img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_euler_discrete.png" width="400"/>
<br>
</p>
```python
from diffusers import EulerAncestralDiscreteScheduler
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0]
image
```
<p align="center">
<br>
<img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_euler_ancestral.png" width="400"/>
<br>
</p>
์ง๊ธ ์ด ๋ฌธ์๋ฅผ ์์ฑํ๋ ํ์์ ๊ธฐ์ค์์ , [`DPMSolverMultistepScheduler`]๊ฐ ์๊ฐ ๋๋น ๊ฐ์ฅ ์ข์ ํ์ง์ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ ๊ฒ ๊ฐ์ต๋๋ค. 20๋ฒ ์ ๋์ ์คํ
๋ง์ผ๋ก๋ ์คํ๋ ์ ์์ต๋๋ค.
```python
from diffusers import DPMSolverMultistepScheduler
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]
image
```
<p align="center">
<br>
<img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_dpm.png" width="400"/>
<br>
</p>
๋ณด์๋ค์ํผ ์์ฑ๋ ์ด๋ฏธ์ง๋ค์ ๋งค์ฐ ๋น์ทํ๊ณ , ๋น์ทํ ํ๋ฆฌํฐ๋ฅผ ๋ณด์ด๋ ๊ฒ ๊ฐ์ต๋๋ค. ์ค์ ๋ก ์ด๋ค ์ค์ผ์ค๋ฌ๋ฅผ ์ ํํ ๊ฒ์ธ๊ฐ๋ ์ข
์ข
ํน์ ์ด์ฉ ์ฌ๋ก์ ๊ธฐ๋ฐํด์ ๊ฒฐ์ ๋๊ณค ํฉ๋๋ค. ๊ฒฐ๊ตญ ์ฌ๋ฌ ์ข
๋ฅ์ ์ค์ผ์ค๋ฌ๋ฅผ ์ง์ ์คํ์์ผ๋ณด๊ณ ๋์ผ๋ก ์ง์ ๋น๊ตํด์ ํ๋จํ๋ ๊ฒ ์ข์ ์ ํ์ผ ๊ฒ ๊ฐ์ต๋๋ค.
## Flax์์ ์ค์ผ์ค๋ฌ ๊ต์ฒดํ๊ธฐ
JAX/Flax ์ฌ์ฉ์์ธ ๊ฒฝ์ฐ ๊ธฐ๋ณธ ํ์ดํ๋ผ์ธ ์ค์ผ์ค๋ฌ๋ฅผ ๋ณ๊ฒฝํ ์๋ ์์ต๋๋ค. ๋ค์์ Flax Stable Diffusion ํ์ดํ๋ผ์ธ๊ณผ ์ด๊ณ ์ [DDPM-Solver++ ์ค์ผ์ค๋ฌ๋ฅผ](../api/schedulers/multistep_dpm_solver) ์ฌ์ฉํ์ฌ ์ถ๋ก ์ ์คํํ๋ ๋ฐฉ๋ฒ์ ๋ํ ์์์
๋๋ค .
```Python
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline, FlaxDPMSolverMultistepScheduler
model_id = "runwayml/stable-diffusion-v1-5"
scheduler, scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler"
)
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
model_id,
scheduler=scheduler,
revision="bf16",
dtype=jax.numpy.bfloat16,
)
params["scheduler"] = scheduler_state
# Generate 1 image per parallel device (8 on TPUv2-8 or TPUv3-8)
prompt = "a photo of an astronaut riding a horse on mars"
num_samples = jax.device_count()
prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 25
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, jax.device_count())
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
```
<Tip warning={true}>
๋ค์ Flax ์ค์ผ์ค๋ฌ๋ *์์ง* Flax Stable Diffusion ํ์ดํ๋ผ์ธ๊ณผ ํธํ๋์ง ์์ต๋๋ค.
- `FlaxLMSDiscreteScheduler`
- `FlaxDDPMScheduler`
</Tip>
|