Diffusers documentation

JAX / Flax์—์„œ์˜ ๐Ÿงจ Stable Diffusion!

You are viewing main version, which requires installation from source. If you'd like regular pip install, checkout the latest stable version (v0.31.0).
Hugging Face's logo
Join the Hugging Face community

and get access to the augmented documentation experience

to get started

JAX / Flax์—์„œ์˜ ๐Ÿงจ Stable Diffusion!

๐Ÿค— Hugging Face [Diffusers] (https://github.com/huggingface/diffusers) ๋Š” ๋ฒ„์ „ 0.5.1๋ถ€ํ„ฐ Flax๋ฅผ ์ง€์›ํ•ฉ๋‹ˆ๋‹ค! ์ด๋ฅผ ํ†ตํ•ด Colab, Kaggle, Google Cloud Platform์—์„œ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋Š” ๊ฒƒ์ฒ˜๋Ÿผ Google TPU์—์„œ ์ดˆ๊ณ ์† ์ถ”๋ก ์ด ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค.

์ด ๋…ธํŠธ๋ถ์€ JAX / Flax๋ฅผ ์‚ฌ์šฉํ•ด ์ถ”๋ก ์„ ์‹คํ–‰ํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค. Stable Diffusion์˜ ์ž‘๋™ ๋ฐฉ์‹์— ๋Œ€ํ•œ ์ž์„ธํ•œ ๋‚ด์šฉ์„ ์›ํ•˜๊ฑฐ๋‚˜ GPU์—์„œ ์‹คํ–‰ํ•˜๋ ค๋ฉด ์ด [๋…ธํŠธ๋ถ] ](https://huggingface.co/docs/diffusers/stable_diffusion)์„ ์ฐธ์กฐํ•˜์„ธ์š”.

๋จผ์ €, TPU ๋ฐฑ์—”๋“œ๋ฅผ ์‚ฌ์šฉํ•˜๊ณ  ์žˆ๋Š”์ง€ ํ™•์ธํ•ฉ๋‹ˆ๋‹ค. Colab์—์„œ ์ด ๋…ธํŠธ๋ถ์„ ์‹คํ–‰ํ•˜๋Š” ๊ฒฝ์šฐ, ๋ฉ”๋‰ด์—์„œ ๋Ÿฐํƒ€์ž„์„ ์„ ํƒํ•œ ๋‹ค์Œ โ€œ๋Ÿฐํƒ€์ž„ ์œ ํ˜• ๋ณ€๊ฒฝโ€ ์˜ต์…˜์„ ์„ ํƒํ•œ ๋‹ค์Œ ํ•˜๋“œ์›จ์–ด ๊ฐ€์†๊ธฐ ์„ค์ •์—์„œ TPU๋ฅผ ์„ ํƒํ•ฉ๋‹ˆ๋‹ค.

JAX๋Š” TPU ์ „์šฉ์€ ์•„๋‹ˆ์ง€๋งŒ ๊ฐ TPU ์„œ๋ฒ„์—๋Š” 8๊ฐœ์˜ TPU ๊ฐ€์†๊ธฐ๊ฐ€ ๋ณ‘๋ ฌ๋กœ ์ž‘๋™ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ํ•ด๋‹น ํ•˜๋“œ์›จ์–ด์—์„œ ๋” ๋น›์„ ๋ฐœํ•œ๋‹ค๋Š” ์ ์€ ์•Œ์•„๋‘์„ธ์š”.

Setup

๋จผ์ € diffusers๊ฐ€ ์„ค์น˜๋˜์–ด ์žˆ๋Š”์ง€ ํ™•์ธํ•ฉ๋‹ˆ๋‹ค.

!pip install jax==0.3.25 jaxlib==0.3.25 flax transformers ftfy
!pip install diffusers
import jax.tools.colab_tpu

jax.tools.colab_tpu.setup_tpu()
import jax
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind

print(f"Found {num_devices} JAX devices of type {device_type}.")
assert (
    "TPU" in device_type
), "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"
Found 8 JAX devices of type Cloud TPU.

๊ทธ๋Ÿฐ ๋‹ค์Œ ๋ชจ๋“  dependencies๋ฅผ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค.

import numpy as np
import jax
import jax.numpy as jnp

from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline

๋ชจ๋ธ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ

TPU ์žฅ์น˜๋Š” ํšจ์œจ์ ์ธ half-float ์œ ํ˜•์ธ bfloat16์„ ์ง€์›ํ•ฉ๋‹ˆ๋‹ค. ํ…Œ์ŠคํŠธ์—๋Š” ์ด ์œ ํ˜•์„ ์‚ฌ์šฉํ•˜์ง€๋งŒ ๋Œ€์‹  float32๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์ „์ฒด ์ •๋ฐ€๋„(full precision)๋ฅผ ์‚ฌ์šฉํ•  ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค.

dtype = jnp.bfloat16

Flax๋Š” ํ•จ์ˆ˜ํ˜• ํ”„๋ ˆ์ž„์›Œํฌ์ด๋ฏ€๋กœ ๋ชจ๋ธ์€ ๋ฌด์ƒํƒœ(stateless)ํ˜•์ด๋ฉฐ ๋งค๊ฐœ๋ณ€์ˆ˜๋Š” ๋ชจ๋ธ ์™ธ๋ถ€์— ์ €์žฅ๋ฉ๋‹ˆ๋‹ค. ์‚ฌ์ „ํ•™์Šต๋œ Flax ํŒŒ์ดํ”„๋ผ์ธ์„ ๋ถˆ๋Ÿฌ์˜ค๋ฉด ํŒŒ์ดํ”„๋ผ์ธ ์ž์ฒด์™€ ๋ชจ๋ธ ๊ฐ€์ค‘์น˜(๋˜๋Š” ๋งค๊ฐœ๋ณ€์ˆ˜)๊ฐ€ ๋ชจ๋‘ ๋ฐ˜ํ™˜๋ฉ๋‹ˆ๋‹ค. ์ €ํฌ๋Š” bf16 ๋ฒ„์ „์˜ ๊ฐ€์ค‘์น˜๋ฅผ ์‚ฌ์šฉํ•˜๊ณ  ์žˆ์œผ๋ฏ€๋กœ ์œ ํ˜• ๊ฒฝ๊ณ ๊ฐ€ ํ‘œ์‹œ๋˜์ง€๋งŒ ๋ฌด์‹œํ•ด๋„ ๋ฉ๋‹ˆ๋‹ค.

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    variant="bf16",
    dtype=dtype,
)

์ถ”๋ก 

TPU์—๋Š” ์ผ๋ฐ˜์ ์œผ๋กœ 8๊ฐœ์˜ ๋””๋ฐ”์ด์Šค๊ฐ€ ๋ณ‘๋ ฌ๋กœ ์ž‘๋™ํ•˜๋ฏ€๋กœ ๋ณด์œ ํ•œ ๋””๋ฐ”์ด์Šค ์ˆ˜๋งŒํผ ํ”„๋กฌํ”„ํŠธ๋ฅผ ๋ณต์ œํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋Ÿฐ ๋‹ค์Œ ๊ฐ๊ฐ ํ•˜๋‚˜์˜ ์ด๋ฏธ์ง€ ์ƒ์„ฑ์„ ๋‹ด๋‹นํ•˜๋Š” 8๊ฐœ์˜ ๋””๋ฐ”์ด์Šค์—์„œ ํ•œ ๋ฒˆ์— ์ถ”๋ก ์„ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ ํ•˜๋‚˜์˜ ์นฉ์ด ํ•˜๋‚˜์˜ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•˜๋Š” ๋ฐ ๊ฑธ๋ฆฌ๋Š” ์‹œ๊ฐ„๊ณผ ๋™์ผํ•œ ์‹œ๊ฐ„์— 8๊ฐœ์˜ ์ด๋ฏธ์ง€๋ฅผ ์–ป์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

ํ”„๋กฌํ”„ํŠธ๋ฅผ ๋ณต์ œํ•˜๊ณ  ๋‚˜๋ฉด ํŒŒ์ดํ”„๋ผ์ธ์˜ prepare_inputs ํ•จ์ˆ˜๋ฅผ ํ˜ธ์ถœํ•˜์—ฌ ํ† ํฐํ™”๋œ ํ…์ŠคํŠธ ID๋ฅผ ์–ป์Šต๋‹ˆ๋‹ค. ํ† ํฐํ™”๋œ ํ…์ŠคํŠธ์˜ ๊ธธ์ด๋Š” ๊ธฐ๋ณธ CLIP ํ…์ŠคํŠธ ๋ชจ๋ธ์˜ ๊ตฌ์„ฑ์— ๋”ฐ๋ผ 77ํ† ํฐ์œผ๋กœ ์„ค์ •๋ฉ๋‹ˆ๋‹ค.

prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape
(8, 77)

๋ณต์‚ฌ(Replication) ๋ฐ ์ •๋ ฌํ™”

๋ชจ๋ธ ๋งค๊ฐœ๋ณ€์ˆ˜์™€ ์ž…๋ ฅ๊ฐ’์€ ์šฐ๋ฆฌ๊ฐ€ ๋ณด์œ ํ•œ 8๊ฐœ์˜ ๋ณ‘๋ ฌ ์žฅ์น˜์— ๋ณต์‚ฌ(Replication)๋˜์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ๋งค๊ฐœ๋ณ€์ˆ˜ ๋”•์…”๋„ˆ๋ฆฌ๋Š” flax.jax_utils.replicate(๋”•์…”๋„ˆ๋ฆฌ๋ฅผ ์ˆœํšŒํ•˜๋ฉฐ ๊ฐ€์ค‘์น˜์˜ ๋ชจ์–‘์„ ๋ณ€๊ฒฝํ•˜์—ฌ 8๋ฒˆ ๋ฐ˜๋ณตํ•˜๋Š” ํ•จ์ˆ˜)๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ณต์‚ฌ๋ฉ๋‹ˆ๋‹ค. ๋ฐฐ์—ด์€ shard๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ณต์ œ๋ฉ๋‹ˆ๋‹ค.

p_params = replicate(params)
prompt_ids = shard(prompt_ids)
prompt_ids.shape
(8, 1, 77)

์ด shape์€ 8๊ฐœ์˜ ๋””๋ฐ”์ด์Šค ๊ฐ๊ฐ์ด shape (1, 77)์˜ jnp ๋ฐฐ์—ด์„ ์ž…๋ ฅ๊ฐ’์œผ๋กœ ๋ฐ›๋Š”๋‹ค๋Š” ์˜๋ฏธ์ž…๋‹ˆ๋‹ค. ์ฆ‰ 1์€ ๋””๋ฐ”์ด์Šค๋‹น batch(๋ฐฐ์น˜) ํฌ๊ธฐ์ž…๋‹ˆ๋‹ค. ๋ฉ”๋ชจ๋ฆฌ๊ฐ€ ์ถฉ๋ถ„ํ•œ TPU์—์„œ๋Š” ํ•œ ๋ฒˆ์— ์—ฌ๋Ÿฌ ์ด๋ฏธ์ง€(์นฉ๋‹น)๋ฅผ ์ƒ์„ฑํ•˜๋ ค๋Š” ๊ฒฝ์šฐ 1๋ณด๋‹ค ํด ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•  ์ค€๋น„๊ฐ€ ๊ฑฐ์˜ ์™„๋ฃŒ๋˜์—ˆ์Šต๋‹ˆ๋‹ค! ์ด์ œ ์ƒ์„ฑ ํ•จ์ˆ˜์— ์ „๋‹ฌํ•  ๋‚œ์ˆ˜ ์ƒ์„ฑ๊ธฐ๋งŒ ๋งŒ๋“ค๋ฉด ๋ฉ๋‹ˆ๋‹ค. ์ด๊ฒƒ์€ ๋‚œ์ˆ˜๋ฅผ ๋‹ค๋ฃจ๋Š” ๋ชจ๋“  ํ•จ์ˆ˜์— ๋‚œ์ˆ˜ ์ƒ์„ฑ๊ธฐ๊ฐ€ ์žˆ์–ด์•ผ ํ•œ๋‹ค๋Š”, ๋‚œ์ˆ˜์— ๋Œ€ํ•ด ๋งค์šฐ ์ง„์ง€ํ•˜๊ณ  ๋…๋‹จ์ ์ธ Flax์˜ ํ‘œ์ค€ ์ ˆ์ฐจ์ž…๋‹ˆ๋‹ค. ์ด๋ ‡๊ฒŒ ํ•˜๋ฉด ์—ฌ๋Ÿฌ ๋ถ„์‚ฐ๋œ ๊ธฐ๊ธฐ์—์„œ ํ›ˆ๋ จํ•  ๋•Œ์—๋„ ์žฌํ˜„์„ฑ์ด ๋ณด์žฅ๋ฉ๋‹ˆ๋‹ค.

์•„๋ž˜ ํ—ฌํผ ํ•จ์ˆ˜๋Š” ์‹œ๋“œ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋‚œ์ˆ˜ ์ƒ์„ฑ๊ธฐ๋ฅผ ์ดˆ๊ธฐํ™”ํ•ฉ๋‹ˆ๋‹ค. ๋™์ผํ•œ ์‹œ๋“œ๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ํ•œ ์ •ํ™•ํžˆ ๋™์ผํ•œ ๊ฒฐ๊ณผ๋ฅผ ์–ป์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋‚˜์ค‘์— ๋…ธํŠธ๋ถ์—์„œ ๊ฒฐ๊ณผ๋ฅผ ํƒ์ƒ‰ํ•  ๋•Œ์—” ๋‹ค๋ฅธ ์‹œ๋“œ๋ฅผ ์ž์œ ๋กญ๊ฒŒ ์‚ฌ์šฉํ•˜์„ธ์š”.

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

rng๋ฅผ ์–ป์€ ๋‹ค์Œ 8๋ฒˆ โ€˜๋ถ„ํ• โ€™ํ•˜์—ฌ ๊ฐ ๋””๋ฐ”์ด์Šค๊ฐ€ ๋‹ค๋ฅธ ์ œ๋„ˆ๋ ˆ์ดํ„ฐ๋ฅผ ์ˆ˜์‹ ํ•˜๋„๋ก ํ•ฉ๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ ๊ฐ ๋””๋ฐ”์ด์Šค๋งˆ๋‹ค ๋‹ค๋ฅธ ์ด๋ฏธ์ง€๊ฐ€ ์ƒ์„ฑ๋˜๋ฉฐ ์ „์ฒด ํ”„๋กœ์„ธ์Šค๋ฅผ ์žฌํ˜„ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())

JAX ์ฝ”๋“œ๋Š” ๋งค์šฐ ๋น ๋ฅด๊ฒŒ ์‹คํ–‰๋˜๋Š” ํšจ์œจ์ ์ธ ํ‘œํ˜„์œผ๋กœ ์ปดํŒŒ์ผํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ํ•˜์ง€๋งŒ ํ›„์† ํ˜ธ์ถœ์—์„œ ๋ชจ๋“  ์ž…๋ ฅ์ด ๋™์ผํ•œ ๋ชจ์–‘์„ ๊ฐ–๋„๋ก ํ•ด์•ผ ํ•˜๋ฉฐ, ๊ทธ๋ ‡์ง€ ์•Š์œผ๋ฉด JAX๊ฐ€ ์ฝ”๋“œ๋ฅผ ๋‹ค์‹œ ์ปดํŒŒ์ผํ•ด์•ผ ํ•˜๋ฏ€๋กœ ์ตœ์ ํ™”๋œ ์†๋„๋ฅผ ํ™œ์šฉํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.

jit = True๋ฅผ ์ธ์ˆ˜๋กœ ์ „๋‹ฌํ•˜๋ฉด Flax ํŒŒ์ดํ”„๋ผ์ธ์ด ์ฝ”๋“œ๋ฅผ ์ปดํŒŒ์ผํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋˜ํ•œ ๋ชจ๋ธ์ด ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ 8๊ฐœ์˜ ๋””๋ฐ”์ด์Šค์—์„œ ๋ณ‘๋ ฌ๋กœ ์‹คํ–‰๋˜๋„๋ก ๋ณด์žฅํ•ฉ๋‹ˆ๋‹ค.

๋‹ค์Œ ์…€์„ ์ฒ˜์Œ ์‹คํ–‰ํ•˜๋ฉด ์ปดํŒŒ์ผํ•˜๋Š” ๋ฐ ์‹œ๊ฐ„์ด ์˜ค๋ž˜ ๊ฑธ๋ฆฌ์ง€๋งŒ ์ดํ›„ ํ˜ธ์ถœ(์ž…๋ ฅ์ด ๋‹ค๋ฅธ ๊ฒฝ์šฐ์—๋„)์€ ํ›จ์”ฌ ๋นจ๋ผ์ง‘๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด, ํ…Œ์ŠคํŠธํ–ˆ์„ ๋•Œ TPU v2-8์—์„œ ์ปดํŒŒ์ผํ•˜๋Š” ๋ฐ 1๋ถ„ ์ด์ƒ ๊ฑธ๋ฆฌ์ง€๋งŒ ์ดํ›„ ์ถ”๋ก  ์‹คํ–‰์—๋Š” ์•ฝ 7์ดˆ๊ฐ€ ๊ฑธ๋ฆฝ๋‹ˆ๋‹ค.

%%time
images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s
Wall time: 1min 29s

๋ฐ˜ํ™˜๋œ ๋ฐฐ์—ด์˜ shape์€ (8, 1, 512, 512, 3)์ž…๋‹ˆ๋‹ค. ์ด๋ฅผ ์žฌ๊ตฌ์„ฑํ•˜์—ฌ ๋‘ ๋ฒˆ์งธ ์ฐจ์›์„ ์ œ๊ฑฐํ•˜๊ณ  512 ร— 512 ร— 3์˜ ์ด๋ฏธ์ง€ 8๊ฐœ๋ฅผ ์–ป์€ ๋‹ค์Œ PIL๋กœ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค.

images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)

์‹œ๊ฐํ™”

์ด๋ฏธ์ง€๋ฅผ ๊ทธ๋ฆฌ๋“œ์— ํ‘œ์‹œํ•˜๋Š” ๋„์šฐ๋ฏธ ํ•จ์ˆ˜๋ฅผ ๋งŒ๋“ค์–ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.

def image_grid(imgs, rows, cols):
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
image_grid(images, 2, 4)

img

๋‹ค๋ฅธ ํ”„๋กฌํ”„ํŠธ ์‚ฌ์šฉ

๋ชจ๋“  ๋””๋ฐ”์ด์Šค์—์„œ ๋™์ผํ•œ ํ”„๋กฌํ”„ํŠธ๋ฅผ ๋ณต์ œํ•  ํ•„์š”๋Š” ์—†์Šต๋‹ˆ๋‹ค. ํ”„๋กฌํ”„ํŠธ 2๊ฐœ๋ฅผ ๊ฐ๊ฐ 4๋ฒˆ์”ฉ ์ƒ์„ฑํ•˜๊ฑฐ๋‚˜ ํ•œ ๋ฒˆ์— 8๊ฐœ์˜ ์„œ๋กœ ๋‹ค๋ฅธ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ƒ์„ฑํ•˜๋Š” ๋“ฑ ์›ํ•˜๋Š” ๊ฒƒ์€ ๋ฌด์—‡์ด๋“  ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ํ•œ๋ฒˆ ํ•ด๋ณด์„ธ์š”!

๋จผ์ € ์ž…๋ ฅ ์ค€๋น„ ์ฝ”๋“œ๋ฅผ ํŽธ๋ฆฌํ•œ ํ•จ์ˆ˜๋กœ ๋ฆฌํŒฉํ„ฐ๋งํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค:

prompts = [
    "Labrador in the style of Hokusai",
    "Painting of a squirrel skating in New York",
    "HAL-9000 in the style of Van Gogh",
    "Times Square under water, with fish and a dolphin swimming around",
    "Ancient Roman fresco showing a man working on his laptop",
    "Close-up photograph of young black woman against urban background, high quality, bokeh",
    "Armchair in the shape of an avocado",
    "Clown astronaut in space, with Earth in the background",
]
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)

image_grid(images, 2, 4)

img

๋ณ‘๋ ฌํ™”(parallelization)๋Š” ์–ด๋–ป๊ฒŒ ์ž‘๋™ํ•˜๋Š”๊ฐ€?

์•ž์„œ diffusers Flax ํŒŒ์ดํ”„๋ผ์ธ์ด ๋ชจ๋ธ์„ ์ž๋™์œผ๋กœ ์ปดํŒŒ์ผํ•˜๊ณ  ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ๋ชจ๋“  ๊ธฐ๊ธฐ์—์„œ ๋ณ‘๋ ฌ๋กœ ์‹คํ–‰ํ•œ๋‹ค๊ณ  ๋ง์”€๋“œ๋ ธ์Šต๋‹ˆ๋‹ค. ์ด์ œ ๊ทธ ํ”„๋กœ์„ธ์Šค๋ฅผ ๊ฐ„๋žตํ•˜๊ฒŒ ์‚ดํŽด๋ณด๊ณ  ์ž‘๋™ ๋ฐฉ์‹์„ ๋ณด์—ฌ๋“œ๋ฆฌ๊ฒ ์Šต๋‹ˆ๋‹ค.

JAX ๋ณ‘๋ ฌํ™”๋Š” ์—ฌ๋Ÿฌ ๊ฐ€์ง€ ๋ฐฉ๋ฒ•์œผ๋กœ ์ˆ˜ํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๊ฐ€์žฅ ์‰ฌ์šด ๋ฐฉ๋ฒ•์€ jax.pmap ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋‹จ์ผ ํ”„๋กœ๊ทธ๋žจ, ๋‹ค์ค‘ ๋ฐ์ดํ„ฐ(SPMD) ๋ณ‘๋ ฌํ™”๋ฅผ ๋‹ฌ์„ฑํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ์ฆ‰, ๋™์ผํ•œ ์ฝ”๋“œ์˜ ๋ณต์‚ฌ๋ณธ์„ ๊ฐ๊ฐ ๋‹ค๋ฅธ ๋ฐ์ดํ„ฐ ์ž…๋ ฅ์— ๋Œ€ํ•ด ์—ฌ๋Ÿฌ ๊ฐœ ์‹คํ–‰ํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ๋” ์ •๊ตํ•œ ์ ‘๊ทผ ๋ฐฉ์‹๋„ ๊ฐ€๋Šฅํ•˜๋ฏ€๋กœ ๊ด€์‹ฌ์ด ์žˆ์œผ์‹œ๋‹ค๋ฉด JAX ๋ฌธ์„œ์™€ pjit ํŽ˜์ด์ง€์—์„œ ์ด ์ฃผ์ œ๋ฅผ ์‚ดํŽด๋ณด์‹œ๊ธฐ ๋ฐ”๋ž๋‹ˆ๋‹ค!

jax.pmap์€ ๋‘ ๊ฐ€์ง€ ๊ธฐ๋Šฅ์„ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค:

  • jax.jit()๋ฅผ ํ˜ธ์ถœํ•œ ๊ฒƒ์ฒ˜๋Ÿผ ์ฝ”๋“œ๋ฅผ ์ปดํŒŒ์ผ(๋˜๋Š” jit)ํ•ฉ๋‹ˆ๋‹ค. ์ด ์ž‘์—…์€ pmap์„ ํ˜ธ์ถœํ•  ๋•Œ๊ฐ€ ์•„๋‹ˆ๋ผ pmapped ํ•จ์ˆ˜๊ฐ€ ์ฒ˜์Œ ํ˜ธ์ถœ๋  ๋•Œ ์ˆ˜ํ–‰๋ฉ๋‹ˆ๋‹ค.
  • ์ปดํŒŒ์ผ๋œ ์ฝ”๋“œ๊ฐ€ ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ๋ชจ๋“  ๊ธฐ๊ธฐ์—์„œ ๋ณ‘๋ ฌ๋กœ ์‹คํ–‰๋˜๋„๋ก ํ•ฉ๋‹ˆ๋‹ค.

์ž‘๋™ ๋ฐฉ์‹์„ ๋ณด์—ฌ๋“œ๋ฆฌ๊ธฐ ์œ„ํ•ด ์ด๋ฏธ์ง€ ์ƒ์„ฑ์„ ์‹คํ–‰ํ•˜๋Š” ๋น„๊ณต๊ฐœ ๋ฉ”์„œ๋“œ์ธ ํŒŒ์ดํ”„๋ผ์ธ์˜ _generate ๋ฉ”์„œ๋“œ๋ฅผ pmapํ•ฉ๋‹ˆ๋‹ค. ์ด ๋ฉ”์„œ๋“œ๋Š” ํ–ฅํ›„ Diffusers ๋ฆด๋ฆฌ์Šค์—์„œ ์ด๋ฆ„์ด ๋ณ€๊ฒฝ๋˜๊ฑฐ๋‚˜ ์ œ๊ฑฐ๋  ์ˆ˜ ์žˆ๋‹ค๋Š” ์ ์— ์œ ์˜ํ•˜์„ธ์š”.

p_generate = pmap(pipeline._generate)

pmap์„ ์‚ฌ์šฉํ•œ ํ›„ ์ค€๋น„๋œ ํ•จ์ˆ˜ p_generate๋Š” ๊ฐœ๋…์ ์œผ๋กœ ๋‹ค์Œ์„ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค:

  • ๊ฐ ์žฅ์น˜์—์„œ ๊ธฐ๋ณธ ํ•จ์ˆ˜ pipeline._generate์˜ ๋ณต์‚ฌ๋ณธ์„ ํ˜ธ์ถœํ•ฉ๋‹ˆ๋‹ค.
  • ๊ฐ ์žฅ์น˜์— ์ž…๋ ฅ ์ธ์ˆ˜์˜ ๋‹ค๋ฅธ ๋ถ€๋ถ„์„ ๋ณด๋ƒ…๋‹ˆ๋‹ค. ์ด๊ฒƒ์ด ๋ฐ”๋กœ ์ƒค๋”ฉ์ด ์‚ฌ์šฉ๋˜๋Š” ์ด์œ ์ž…๋‹ˆ๋‹ค. ์ด ๊ฒฝ์šฐ prompt_ids์˜ shape์€ (8, 1, 77, 768)์ž…๋‹ˆ๋‹ค. ์ด ๋ฐฐ์—ด์€ 8๊ฐœ๋กœ ๋ถ„ํ• ๋˜๊ณ  _generate์˜ ๊ฐ ๋ณต์‚ฌ๋ณธ์€ (1, 77, 768)์˜ shape์„ ๊ฐ€์ง„ ์ž…๋ ฅ์„ ๋ฐ›๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.

๋ณ‘๋ ฌ๋กœ ํ˜ธ์ถœ๋œ๋‹ค๋Š” ์‚ฌ์‹ค์„ ์™„์ „ํžˆ ๋ฌด์‹œํ•˜๊ณ  _generate๋ฅผ ์ฝ”๋”ฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. batch(๋ฐฐ์น˜) ํฌ๊ธฐ(์ด ์˜ˆ์ œ์—์„œ๋Š” 1)์™€ ์ฝ”๋“œ์— ์ ํ•ฉํ•œ ์ฐจ์›๋งŒ ์‹ ๊ฒฝ ์“ฐ๋ฉด ๋˜๋ฉฐ, ๋ณ‘๋ ฌ๋กœ ์ž‘๋™ํ•˜๊ธฐ ์œ„ํ•ด ์•„๋ฌด๊ฒƒ๋„ ๋ณ€๊ฒฝํ•  ํ•„์š”๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค.

ํŒŒ์ดํ”„๋ผ์ธ ํ˜ธ์ถœ์„ ์‚ฌ์šฉํ•  ๋•Œ์™€ ๋งˆ์ฐฌ๊ฐ€์ง€๋กœ, ๋‹ค์Œ ์…€์„ ์ฒ˜์Œ ์‹คํ–‰ํ•  ๋•Œ๋Š” ์‹œ๊ฐ„์ด ๊ฑธ๋ฆฌ์ง€๋งŒ ๊ทธ ์ดํ›„์—๋Š” ํ›จ์”ฌ ๋นจ๋ผ์ง‘๋‹ˆ๋‹ค.

%%time
images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
images.shape
CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s
Wall time: 1min 15s
images.shape
(8, 1, 512, 512, 3)

JAX๋Š” ๋น„๋™๊ธฐ ๋””์ŠคํŒจ์น˜๋ฅผ ์‚ฌ์šฉํ•˜๊ณ  ๊ฐ€๋Šฅํ•œ ํ•œ ๋นจ๋ฆฌ ์ œ์–ด๊ถŒ์„ Python ๋ฃจํ”„์— ๋ฐ˜ํ™˜ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ์ถ”๋ก  ์‹œ๊ฐ„์„ ์ •ํ™•ํ•˜๊ฒŒ ์ธก์ •ํ•˜๊ธฐ ์œ„ํ•ด block_until_ready()๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. ์•„์ง ๊ตฌ์ฒดํ™”๋˜์ง€ ์•Š์€ ๊ณ„์‚ฐ ๊ฒฐ๊ณผ๋ฅผ ์‚ฌ์šฉํ•˜๋ ค๋Š” ๊ฒฝ์šฐ ์ž๋™์œผ๋กœ ์ฐจ๋‹จ์ด ์ˆ˜ํ–‰๋˜๋ฏ€๋กœ ์ฝ”๋“œ์—์„œ ์ด ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•  ํ•„์š”๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค.

< > Update on GitHub