Diffusers documentation

JAX / Flax์—์„œ์˜ ๐Ÿงจ Stable Diffusion!

You are viewing v0.22.0 version. A newer version v0.31.0 is available.
Hugging Face's logo
Join the Hugging Face community

and get access to the augmented documentation experience

to get started

JAX / Flax์—์„œ์˜ ๐Ÿงจ Stable Diffusion!

๐Ÿค— Hugging Face [Diffusers] (https://github.com/huggingface/diffusers) ๋Š” ๋ฒ„์ „ 0.5.1๋ถ€ํ„ฐ Flax๋ฅผ ์ง€์›ํ•ฉ๋‹ˆ๋‹ค! ์ด๋ฅผ ํ†ตํ•ด Colab, Kaggle, Google Cloud Platform์—์„œ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋Š” ๊ฒƒ์ฒ˜๋Ÿผ Google TPU์—์„œ ์ดˆ๊ณ ์† ์ถ”๋ก ์ด ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค.

์ด ๋…ธํŠธ๋ถ์€ JAX / Flax๋ฅผ ์‚ฌ์šฉํ•ด ์ถ”๋ก ์„ ์‹คํ–‰ํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค. Stable Diffusion์˜ ์ž‘๋™ ๋ฐฉ์‹์— ๋Œ€ํ•œ ์ž์„ธํ•œ ๋‚ด์šฉ์„ ์›ํ•˜๊ฑฐ๋‚˜ GPU์—์„œ ์‹คํ–‰ํ•˜๋ ค๋ฉด ์ด [๋…ธํŠธ๋ถ] ](https://huggingface.co/docs/diffusers/stable_diffusion)์„ ์ฐธ์กฐํ•˜์„ธ์š”.

๋จผ์ €, TPU ๋ฐฑ์—”๋“œ๋ฅผ ์‚ฌ์šฉํ•˜๊ณ  ์žˆ๋Š”์ง€ ํ™•์ธํ•ฉ๋‹ˆ๋‹ค. Colab์—์„œ ์ด ๋…ธํŠธ๋ถ์„ ์‹คํ–‰ํ•˜๋Š” ๊ฒฝ์šฐ, ๋ฉ”๋‰ด์—์„œ ๋Ÿฐํƒ€์ž„์„ ์„ ํƒํ•œ ๋‹ค์Œ โ€œ๋Ÿฐํƒ€์ž„ ์œ ํ˜• ๋ณ€๊ฒฝโ€ ์˜ต์…˜์„ ์„ ํƒํ•œ ๋‹ค์Œ ํ•˜๋“œ์›จ์–ด ๊ฐ€์†๊ธฐ ์„ค์ •์—์„œ TPU๋ฅผ ์„ ํƒํ•ฉ๋‹ˆ๋‹ค.

JAX๋Š” TPU ์ „์šฉ์€ ์•„๋‹ˆ์ง€๋งŒ ๊ฐ TPU ์„œ๋ฒ„์—๋Š” 8๊ฐœ์˜ TPU ๊ฐ€์†๊ธฐ๊ฐ€ ๋ณ‘๋ ฌ๋กœ ์ž‘๋™ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ํ•ด๋‹น ํ•˜๋“œ์›จ์–ด์—์„œ ๋” ๋น›์„ ๋ฐœํ•œ๋‹ค๋Š” ์ ์€ ์•Œ์•„๋‘์„ธ์š”.

Setup

๋จผ์ € diffusers๊ฐ€ ์„ค์น˜๋˜์–ด ์žˆ๋Š”์ง€ ํ™•์ธํ•ฉ๋‹ˆ๋‹ค.

!pip install jax==0.3.25 jaxlib==0.3.25 flax transformers ftfy
!pip install diffusers
import jax.tools.colab_tpu

jax.tools.colab_tpu.setup_tpu()
import jax
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind

print(f"Found {num_devices} JAX devices of type {device_type}.")
assert (
    "TPU" in device_type
), "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"
Found 8 JAX devices of type Cloud TPU.

๊ทธ๋Ÿฐ ๋‹ค์Œ ๋ชจ๋“  dependencies๋ฅผ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค.

import numpy as np
import jax
import jax.numpy as jnp

from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline

๋ชจ๋ธ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ

TPU ์žฅ์น˜๋Š” ํšจ์œจ์ ์ธ half-float ์œ ํ˜•์ธ bfloat16์„ ์ง€์›ํ•ฉ๋‹ˆ๋‹ค. ํ…Œ์ŠคํŠธ์—๋Š” ์ด ์œ ํ˜•์„ ์‚ฌ์šฉํ•˜์ง€๋งŒ ๋Œ€์‹  float32๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์ „์ฒด ์ •๋ฐ€๋„(full precision)๋ฅผ ์‚ฌ์šฉํ•  ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค.

dtype = jnp.bfloat16

Flax๋Š” ํ•จ์ˆ˜ํ˜• ํ”„๋ ˆ์ž„์›Œํฌ์ด๋ฏ€๋กœ ๋ชจ๋ธ์€ ๋ฌด์ƒํƒœ(stateless)ํ˜•์ด๋ฉฐ ๋งค๊ฐœ๋ณ€์ˆ˜๋Š” ๋ชจ๋ธ ์™ธ๋ถ€์— ์ €์žฅ๋ฉ๋‹ˆ๋‹ค. ์‚ฌ์ „ํ•™์Šต๋œ Flax ํŒŒ์ดํ”„๋ผ์ธ์„ ๋ถˆ๋Ÿฌ์˜ค๋ฉด ํŒŒ์ดํ”„๋ผ์ธ ์ž์ฒด์™€ ๋ชจ๋ธ ๊ฐ€์ค‘์น˜(๋˜๋Š” ๋งค๊ฐœ๋ณ€์ˆ˜)๊ฐ€ ๋ชจ๋‘ ๋ฐ˜ํ™˜๋ฉ๋‹ˆ๋‹ค. ์ €ํฌ๋Š” bf16 ๋ฒ„์ „์˜ ๊ฐ€์ค‘์น˜๋ฅผ ์‚ฌ์šฉํ•˜๊ณ  ์žˆ์œผ๋ฏ€๋กœ ์œ ํ˜• ๊ฒฝ๊ณ ๊ฐ€ ํ‘œ์‹œ๋˜์ง€๋งŒ ๋ฌด์‹œํ•ด๋„ ๋ฉ๋‹ˆ๋‹ค.

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=dtype,
)

์ถ”๋ก 

TPU์—๋Š” ์ผ๋ฐ˜์ ์œผ๋กœ 8๊ฐœ์˜ ๋””๋ฐ”์ด์Šค๊ฐ€ ๋ณ‘๋ ฌ๋กœ ์ž‘๋™ํ•˜๋ฏ€๋กœ ๋ณด์œ ํ•œ ๋””๋ฐ”์ด์Šค ์ˆ˜๋งŒํผ ํ”„๋กฌํ”„ํŠธ๋ฅผ ๋ณต์ œํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋Ÿฐ ๋‹ค์Œ ๊ฐ๊ฐ ํ•˜๋‚˜์˜ ์ด๋ฏธ์ง€ ์ƒ์„ฑ์„ ๋‹ด๋‹นํ•˜๋Š” 8๊ฐœ์˜ ๋””๋ฐ”์ด์Šค์—์„œ ํ•œ ๋ฒˆ์— ์ถ”๋ก ์„ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ ํ•˜๋‚˜์˜ ์นฉ์ด ํ•˜๋‚˜์˜ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•˜๋Š” ๋ฐ ๊ฑธ๋ฆฌ๋Š” ์‹œ๊ฐ„๊ณผ ๋™์ผํ•œ ์‹œ๊ฐ„์— 8๊ฐœ์˜ ์ด๋ฏธ์ง€๋ฅผ ์–ป์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

ํ”„๋กฌํ”„ํŠธ๋ฅผ ๋ณต์ œํ•˜๊ณ  ๋‚˜๋ฉด ํŒŒ์ดํ”„๋ผ์ธ์˜ prepare_inputs ํ•จ์ˆ˜๋ฅผ ํ˜ธ์ถœํ•˜์—ฌ ํ† ํฐํ™”๋œ ํ…์ŠคํŠธ ID๋ฅผ ์–ป์Šต๋‹ˆ๋‹ค. ํ† ํฐํ™”๋œ ํ…์ŠคํŠธ์˜ ๊ธธ์ด๋Š” ๊ธฐ๋ณธ CLIP ํ…์ŠคํŠธ ๋ชจ๋ธ์˜ ๊ตฌ์„ฑ์— ๋”ฐ๋ผ 77ํ† ํฐ์œผ๋กœ ์„ค์ •๋ฉ๋‹ˆ๋‹ค.

prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape
(8, 77)

๋ณต์‚ฌ(Replication) ๋ฐ ์ •๋ ฌํ™”

๋ชจ๋ธ ๋งค๊ฐœ๋ณ€์ˆ˜์™€ ์ž…๋ ฅ๊ฐ’์€ ์šฐ๋ฆฌ๊ฐ€ ๋ณด์œ ํ•œ 8๊ฐœ์˜ ๋ณ‘๋ ฌ ์žฅ์น˜์— ๋ณต์‚ฌ(Replication)๋˜์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ๋งค๊ฐœ๋ณ€์ˆ˜ ๋”•์…”๋„ˆ๋ฆฌ๋Š” flax.jax_utils.replicate(๋”•์…”๋„ˆ๋ฆฌ๋ฅผ ์ˆœํšŒํ•˜๋ฉฐ ๊ฐ€์ค‘์น˜์˜ ๋ชจ์–‘์„ ๋ณ€๊ฒฝํ•˜์—ฌ 8๋ฒˆ ๋ฐ˜๋ณตํ•˜๋Š” ํ•จ์ˆ˜)๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ณต์‚ฌ๋ฉ๋‹ˆ๋‹ค. ๋ฐฐ์—ด์€ shard๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ณต์ œ๋ฉ๋‹ˆ๋‹ค.

p_params = replicate(params)
prompt_ids = shard(prompt_ids)
prompt_ids.shape
(8, 1, 77)

์ด shape์€ 8๊ฐœ์˜ ๋””๋ฐ”์ด์Šค ๊ฐ๊ฐ์ด shape (1, 77)์˜ jnp ๋ฐฐ์—ด์„ ์ž…๋ ฅ๊ฐ’์œผ๋กœ ๋ฐ›๋Š”๋‹ค๋Š” ์˜๋ฏธ์ž…๋‹ˆ๋‹ค. ์ฆ‰ 1์€ ๋””๋ฐ”์ด์Šค๋‹น batch(๋ฐฐ์น˜) ํฌ๊ธฐ์ž…๋‹ˆ๋‹ค. ๋ฉ”๋ชจ๋ฆฌ๊ฐ€ ์ถฉ๋ถ„ํ•œ TPU์—์„œ๋Š” ํ•œ ๋ฒˆ์— ์—ฌ๋Ÿฌ ์ด๋ฏธ์ง€(์นฉ๋‹น)๋ฅผ ์ƒ์„ฑํ•˜๋ ค๋Š” ๊ฒฝ์šฐ 1๋ณด๋‹ค ํด ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•  ์ค€๋น„๊ฐ€ ๊ฑฐ์˜ ์™„๋ฃŒ๋˜์—ˆ์Šต๋‹ˆ๋‹ค! ์ด์ œ ์ƒ์„ฑ ํ•จ์ˆ˜์— ์ „๋‹ฌํ•  ๋‚œ์ˆ˜ ์ƒ์„ฑ๊ธฐ๋งŒ ๋งŒ๋“ค๋ฉด ๋ฉ๋‹ˆ๋‹ค. ์ด๊ฒƒ์€ ๋‚œ์ˆ˜๋ฅผ ๋‹ค๋ฃจ๋Š” ๋ชจ๋“  ํ•จ์ˆ˜์— ๋‚œ์ˆ˜ ์ƒ์„ฑ๊ธฐ๊ฐ€ ์žˆ์–ด์•ผ ํ•œ๋‹ค๋Š”, ๋‚œ์ˆ˜์— ๋Œ€ํ•ด ๋งค์šฐ ์ง„์ง€ํ•˜๊ณ  ๋…๋‹จ์ ์ธ Flax์˜ ํ‘œ์ค€ ์ ˆ์ฐจ์ž…๋‹ˆ๋‹ค. ์ด๋ ‡๊ฒŒ ํ•˜๋ฉด ์—ฌ๋Ÿฌ ๋ถ„์‚ฐ๋œ ๊ธฐ๊ธฐ์—์„œ ํ›ˆ๋ จํ•  ๋•Œ์—๋„ ์žฌํ˜„์„ฑ์ด ๋ณด์žฅ๋ฉ๋‹ˆ๋‹ค.

์•„๋ž˜ ํ—ฌํผ ํ•จ์ˆ˜๋Š” ์‹œ๋“œ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋‚œ์ˆ˜ ์ƒ์„ฑ๊ธฐ๋ฅผ ์ดˆ๊ธฐํ™”ํ•ฉ๋‹ˆ๋‹ค. ๋™์ผํ•œ ์‹œ๋“œ๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ํ•œ ์ •ํ™•ํžˆ ๋™์ผํ•œ ๊ฒฐ๊ณผ๋ฅผ ์–ป์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋‚˜์ค‘์— ๋…ธํŠธ๋ถ์—์„œ ๊ฒฐ๊ณผ๋ฅผ ํƒ์ƒ‰ํ•  ๋•Œ์—” ๋‹ค๋ฅธ ์‹œ๋“œ๋ฅผ ์ž์œ ๋กญ๊ฒŒ ์‚ฌ์šฉํ•˜์„ธ์š”.

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

rng๋ฅผ ์–ป์€ ๋‹ค์Œ 8๋ฒˆ โ€˜๋ถ„ํ• โ€™ํ•˜์—ฌ ๊ฐ ๋””๋ฐ”์ด์Šค๊ฐ€ ๋‹ค๋ฅธ ์ œ๋„ˆ๋ ˆ์ดํ„ฐ๋ฅผ ์ˆ˜์‹ ํ•˜๋„๋ก ํ•ฉ๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ ๊ฐ ๋””๋ฐ”์ด์Šค๋งˆ๋‹ค ๋‹ค๋ฅธ ์ด๋ฏธ์ง€๊ฐ€ ์ƒ์„ฑ๋˜๋ฉฐ ์ „์ฒด ํ”„๋กœ์„ธ์Šค๋ฅผ ์žฌํ˜„ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())

JAX ์ฝ”๋“œ๋Š” ๋งค์šฐ ๋น ๋ฅด๊ฒŒ ์‹คํ–‰๋˜๋Š” ํšจ์œจ์ ์ธ ํ‘œํ˜„์œผ๋กœ ์ปดํŒŒ์ผํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ํ•˜์ง€๋งŒ ํ›„์† ํ˜ธ์ถœ์—์„œ ๋ชจ๋“  ์ž…๋ ฅ์ด ๋™์ผํ•œ ๋ชจ์–‘์„ ๊ฐ–๋„๋ก ํ•ด์•ผ ํ•˜๋ฉฐ, ๊ทธ๋ ‡์ง€ ์•Š์œผ๋ฉด JAX๊ฐ€ ์ฝ”๋“œ๋ฅผ ๋‹ค์‹œ ์ปดํŒŒ์ผํ•ด์•ผ ํ•˜๋ฏ€๋กœ ์ตœ์ ํ™”๋œ ์†๋„๋ฅผ ํ™œ์šฉํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.

jit = True๋ฅผ ์ธ์ˆ˜๋กœ ์ „๋‹ฌํ•˜๋ฉด Flax ํŒŒ์ดํ”„๋ผ์ธ์ด ์ฝ”๋“œ๋ฅผ ์ปดํŒŒ์ผํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋˜ํ•œ ๋ชจ๋ธ์ด ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ 8๊ฐœ์˜ ๋””๋ฐ”์ด์Šค์—์„œ ๋ณ‘๋ ฌ๋กœ ์‹คํ–‰๋˜๋„๋ก ๋ณด์žฅํ•ฉ๋‹ˆ๋‹ค.

๋‹ค์Œ ์…€์„ ์ฒ˜์Œ ์‹คํ–‰ํ•˜๋ฉด ์ปดํŒŒ์ผํ•˜๋Š” ๋ฐ ์‹œ๊ฐ„์ด ์˜ค๋ž˜ ๊ฑธ๋ฆฌ์ง€๋งŒ ์ดํ›„ ํ˜ธ์ถœ(์ž…๋ ฅ์ด ๋‹ค๋ฅธ ๊ฒฝ์šฐ์—๋„)์€ ํ›จ์”ฌ ๋นจ๋ผ์ง‘๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด, ํ…Œ์ŠคํŠธํ–ˆ์„ ๋•Œ TPU v2-8์—์„œ ์ปดํŒŒ์ผํ•˜๋Š” ๋ฐ 1๋ถ„ ์ด์ƒ ๊ฑธ๋ฆฌ์ง€๋งŒ ์ดํ›„ ์ถ”๋ก  ์‹คํ–‰์—๋Š” ์•ฝ 7์ดˆ๊ฐ€ ๊ฑธ๋ฆฝ๋‹ˆ๋‹ค.

%%time
images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s
Wall time: 1min 29s

๋ฐ˜ํ™˜๋œ ๋ฐฐ์—ด์˜ shape์€ (8, 1, 512, 512, 3)์ž…๋‹ˆ๋‹ค. ์ด๋ฅผ ์žฌ๊ตฌ์„ฑํ•˜์—ฌ ๋‘ ๋ฒˆ์งธ ์ฐจ์›์„ ์ œ๊ฑฐํ•˜๊ณ  512 ร— 512 ร— 3์˜ ์ด๋ฏธ์ง€ 8๊ฐœ๋ฅผ ์–ป์€ ๋‹ค์Œ PIL๋กœ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค.

images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)

์‹œ๊ฐํ™”

์ด๋ฏธ์ง€๋ฅผ ๊ทธ๋ฆฌ๋“œ์— ํ‘œ์‹œํ•˜๋Š” ๋„์šฐ๋ฏธ ํ•จ์ˆ˜๋ฅผ ๋งŒ๋“ค์–ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.

def image_grid(imgs, rows, cols):
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
image_grid(images, 2, 4)

img

๋‹ค๋ฅธ ํ”„๋กฌํ”„ํŠธ ์‚ฌ์šฉ

๋ชจ๋“  ๋””๋ฐ”์ด์Šค์—์„œ ๋™์ผํ•œ ํ”„๋กฌํ”„ํŠธ๋ฅผ ๋ณต์ œํ•  ํ•„์š”๋Š” ์—†์Šต๋‹ˆ๋‹ค. ํ”„๋กฌํ”„ํŠธ 2๊ฐœ๋ฅผ ๊ฐ๊ฐ 4๋ฒˆ์”ฉ ์ƒ์„ฑํ•˜๊ฑฐ๋‚˜ ํ•œ ๋ฒˆ์— 8๊ฐœ์˜ ์„œ๋กœ ๋‹ค๋ฅธ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ƒ์„ฑํ•˜๋Š” ๋“ฑ ์›ํ•˜๋Š” ๊ฒƒ์€ ๋ฌด์—‡์ด๋“  ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ํ•œ๋ฒˆ ํ•ด๋ณด์„ธ์š”!

๋จผ์ € ์ž…๋ ฅ ์ค€๋น„ ์ฝ”๋“œ๋ฅผ ํŽธ๋ฆฌํ•œ ํ•จ์ˆ˜๋กœ ๋ฆฌํŒฉํ„ฐ๋งํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค:

prompts = [
    "Labrador in the style of Hokusai",
    "Painting of a squirrel skating in New York",
    "HAL-9000 in the style of Van Gogh",
    "Times Square under water, with fish and a dolphin swimming around",
    "Ancient Roman fresco showing a man working on his laptop",
    "Close-up photograph of young black woman against urban background, high quality, bokeh",
    "Armchair in the shape of an avocado",
    "Clown astronaut in space, with Earth in the background",
]
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)

image_grid(images, 2, 4)

img

๋ณ‘๋ ฌํ™”(parallelization)๋Š” ์–ด๋–ป๊ฒŒ ์ž‘๋™ํ•˜๋Š”๊ฐ€?

์•ž์„œ diffusers Flax ํŒŒ์ดํ”„๋ผ์ธ์ด ๋ชจ๋ธ์„ ์ž๋™์œผ๋กœ ์ปดํŒŒ์ผํ•˜๊ณ  ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ๋ชจ๋“  ๊ธฐ๊ธฐ์—์„œ ๋ณ‘๋ ฌ๋กœ ์‹คํ–‰ํ•œ๋‹ค๊ณ  ๋ง์”€๋“œ๋ ธ์Šต๋‹ˆ๋‹ค. ์ด์ œ ๊ทธ ํ”„๋กœ์„ธ์Šค๋ฅผ ๊ฐ„๋žตํ•˜๊ฒŒ ์‚ดํŽด๋ณด๊ณ  ์ž‘๋™ ๋ฐฉ์‹์„ ๋ณด์—ฌ๋“œ๋ฆฌ๊ฒ ์Šต๋‹ˆ๋‹ค.

JAX ๋ณ‘๋ ฌํ™”๋Š” ์—ฌ๋Ÿฌ ๊ฐ€์ง€ ๋ฐฉ๋ฒ•์œผ๋กœ ์ˆ˜ํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๊ฐ€์žฅ ์‰ฌ์šด ๋ฐฉ๋ฒ•์€ jax.pmap ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋‹จ์ผ ํ”„๋กœ๊ทธ๋žจ, ๋‹ค์ค‘ ๋ฐ์ดํ„ฐ(SPMD) ๋ณ‘๋ ฌํ™”๋ฅผ ๋‹ฌ์„ฑํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ์ฆ‰, ๋™์ผํ•œ ์ฝ”๋“œ์˜ ๋ณต์‚ฌ๋ณธ์„ ๊ฐ๊ฐ ๋‹ค๋ฅธ ๋ฐ์ดํ„ฐ ์ž…๋ ฅ์— ๋Œ€ํ•ด ์—ฌ๋Ÿฌ ๊ฐœ ์‹คํ–‰ํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ๋” ์ •๊ตํ•œ ์ ‘๊ทผ ๋ฐฉ์‹๋„ ๊ฐ€๋Šฅํ•˜๋ฏ€๋กœ ๊ด€์‹ฌ์ด ์žˆ์œผ์‹œ๋‹ค๋ฉด JAX ๋ฌธ์„œ์™€ pjit ํŽ˜์ด์ง€์—์„œ ์ด ์ฃผ์ œ๋ฅผ ์‚ดํŽด๋ณด์‹œ๊ธฐ ๋ฐ”๋ž๋‹ˆ๋‹ค!

jax.pmap์€ ๋‘ ๊ฐ€์ง€ ๊ธฐ๋Šฅ์„ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค:

  • jax.jit()๋ฅผ ํ˜ธ์ถœํ•œ ๊ฒƒ์ฒ˜๋Ÿผ ์ฝ”๋“œ๋ฅผ ์ปดํŒŒ์ผ(๋˜๋Š” jit)ํ•ฉ๋‹ˆ๋‹ค. ์ด ์ž‘์—…์€ pmap์„ ํ˜ธ์ถœํ•  ๋•Œ๊ฐ€ ์•„๋‹ˆ๋ผ pmapped ํ•จ์ˆ˜๊ฐ€ ์ฒ˜์Œ ํ˜ธ์ถœ๋  ๋•Œ ์ˆ˜ํ–‰๋ฉ๋‹ˆ๋‹ค.
  • ์ปดํŒŒ์ผ๋œ ์ฝ”๋“œ๊ฐ€ ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ๋ชจ๋“  ๊ธฐ๊ธฐ์—์„œ ๋ณ‘๋ ฌ๋กœ ์‹คํ–‰๋˜๋„๋ก ํ•ฉ๋‹ˆ๋‹ค.

์ž‘๋™ ๋ฐฉ์‹์„ ๋ณด์—ฌ๋“œ๋ฆฌ๊ธฐ ์œ„ํ•ด ์ด๋ฏธ์ง€ ์ƒ์„ฑ์„ ์‹คํ–‰ํ•˜๋Š” ๋น„๊ณต๊ฐœ ๋ฉ”์„œ๋“œ์ธ ํŒŒ์ดํ”„๋ผ์ธ์˜ _generate ๋ฉ”์„œ๋“œ๋ฅผ pmapํ•ฉ๋‹ˆ๋‹ค. ์ด ๋ฉ”์„œ๋“œ๋Š” ํ–ฅํ›„ Diffusers ๋ฆด๋ฆฌ์Šค์—์„œ ์ด๋ฆ„์ด ๋ณ€๊ฒฝ๋˜๊ฑฐ๋‚˜ ์ œ๊ฑฐ๋  ์ˆ˜ ์žˆ๋‹ค๋Š” ์ ์— ์œ ์˜ํ•˜์„ธ์š”.

p_generate = pmap(pipeline._generate)

pmap์„ ์‚ฌ์šฉํ•œ ํ›„ ์ค€๋น„๋œ ํ•จ์ˆ˜ p_generate๋Š” ๊ฐœ๋…์ ์œผ๋กœ ๋‹ค์Œ์„ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค:

  • ๊ฐ ์žฅ์น˜์—์„œ ๊ธฐ๋ณธ ํ•จ์ˆ˜ pipeline._generate์˜ ๋ณต์‚ฌ๋ณธ์„ ํ˜ธ์ถœํ•ฉ๋‹ˆ๋‹ค.
  • ๊ฐ ์žฅ์น˜์— ์ž…๋ ฅ ์ธ์ˆ˜์˜ ๋‹ค๋ฅธ ๋ถ€๋ถ„์„ ๋ณด๋ƒ…๋‹ˆ๋‹ค. ์ด๊ฒƒ์ด ๋ฐ”๋กœ ์ƒค๋”ฉ์ด ์‚ฌ์šฉ๋˜๋Š” ์ด์œ ์ž…๋‹ˆ๋‹ค. ์ด ๊ฒฝ์šฐ prompt_ids์˜ shape์€ (8, 1, 77, 768)์ž…๋‹ˆ๋‹ค. ์ด ๋ฐฐ์—ด์€ 8๊ฐœ๋กœ ๋ถ„ํ• ๋˜๊ณ  _generate์˜ ๊ฐ ๋ณต์‚ฌ๋ณธ์€ (1, 77, 768)์˜ shape์„ ๊ฐ€์ง„ ์ž…๋ ฅ์„ ๋ฐ›๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.

๋ณ‘๋ ฌ๋กœ ํ˜ธ์ถœ๋œ๋‹ค๋Š” ์‚ฌ์‹ค์„ ์™„์ „ํžˆ ๋ฌด์‹œํ•˜๊ณ  _generate๋ฅผ ์ฝ”๋”ฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. batch(๋ฐฐ์น˜) ํฌ๊ธฐ(์ด ์˜ˆ์ œ์—์„œ๋Š” 1)์™€ ์ฝ”๋“œ์— ์ ํ•ฉํ•œ ์ฐจ์›๋งŒ ์‹ ๊ฒฝ ์“ฐ๋ฉด ๋˜๋ฉฐ, ๋ณ‘๋ ฌ๋กœ ์ž‘๋™ํ•˜๊ธฐ ์œ„ํ•ด ์•„๋ฌด๊ฒƒ๋„ ๋ณ€๊ฒฝํ•  ํ•„์š”๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค.

ํŒŒ์ดํ”„๋ผ์ธ ํ˜ธ์ถœ์„ ์‚ฌ์šฉํ•  ๋•Œ์™€ ๋งˆ์ฐฌ๊ฐ€์ง€๋กœ, ๋‹ค์Œ ์…€์„ ์ฒ˜์Œ ์‹คํ–‰ํ•  ๋•Œ๋Š” ์‹œ๊ฐ„์ด ๊ฑธ๋ฆฌ์ง€๋งŒ ๊ทธ ์ดํ›„์—๋Š” ํ›จ์”ฌ ๋นจ๋ผ์ง‘๋‹ˆ๋‹ค.

%%time
images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
images.shape
CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s
Wall time: 1min 15s
images.shape
(8, 1, 512, 512, 3)

JAX๋Š” ๋น„๋™๊ธฐ ๋””์ŠคํŒจ์น˜๋ฅผ ์‚ฌ์šฉํ•˜๊ณ  ๊ฐ€๋Šฅํ•œ ํ•œ ๋นจ๋ฆฌ ์ œ์–ด๊ถŒ์„ Python ๋ฃจํ”„์— ๋ฐ˜ํ™˜ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ์ถ”๋ก  ์‹œ๊ฐ„์„ ์ •ํ™•ํ•˜๊ฒŒ ์ธก์ •ํ•˜๊ธฐ ์œ„ํ•ด block_until_ready()๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. ์•„์ง ๊ตฌ์ฒดํ™”๋˜์ง€ ์•Š์€ ๊ณ„์‚ฐ ๊ฒฐ๊ณผ๋ฅผ ์‚ฌ์šฉํ•˜๋ ค๋Š” ๊ฒฝ์šฐ ์ž๋™์œผ๋กœ ์ฐจ๋‹จ์ด ์ˆ˜ํ–‰๋˜๋ฏ€๋กœ ์ฝ”๋“œ์—์„œ ์ด ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•  ํ•„์š”๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค.