Distributed inference with multiple GPUs

๋ถ„์‚ฐ ์„ค์ •์—์„œ๋Š” ์—ฌ๋Ÿฌ ๊ฐœ์˜ ํ”„๋กฌํ”„ํŠธ๋ฅผ ๋™์‹œ์— ์ƒ์„ฑํ•  ๋•Œ ์œ ์šฉํ•œ ๐Ÿค— Accelerate ๋˜๋Š” PyTorch Distributed๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์—ฌ๋Ÿฌ GPU์—์„œ ์ถ”๋ก ์„ ์‹คํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์ด ๊ฐ€์ด๋“œ์—์„œ๋Š” ๋ถ„์‚ฐ ์ถ”๋ก ์„ ์œ„ํ•ด ๐Ÿค— Accelerate์™€ PyTorch Distributed๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ๋ณด์—ฌ๋“œ๋ฆฝ๋‹ˆ๋‹ค.

๐Ÿค— Accelerate

๐Ÿค— Accelerate is a library designed to make it easy to train or run inference across distributed setups. It simplifies the process of setting up the distributed environment, allowing you to focus on your PyTorch code.

To begin, create a Python file and initialize [accelerate.PartialState] to create a distributed environment; your setup is automatically detected, so you don't need to explicitly define the rank or world_size. Move the [DiffusionPipeline] to distributed_state.device to assign a GPU to each process.

์ด์ œ ์ปจํ…์ŠคํŠธ ๊ด€๋ฆฌ์ž๋กœ [~accelerate.PartialState.split_between_processes] ์œ ํ‹ธ๋ฆฌํ‹ฐ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ํ”„๋กœ์„ธ์Šค ์ˆ˜์— ๋”ฐ๋ผ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ž๋™์œผ๋กœ ๋ถ„๋ฐฐํ•ฉ๋‹ˆ๋‹ค.

import torch
from accelerate import PartialState
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
distributed_state = PartialState()
pipeline.to(distributed_state.device)

with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
    result = pipeline(prompt).images[0]
    result.save(f"result_{distributed_state.process_index}.png")
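Under the hood, the prompts are split into near-even chunks, one per process. As an illustrative pure-Python sketch of that chunking behavior (not the actual Accelerate implementation), assuming four prompts over two processes:

```python
# Illustrative sketch of even chunking across processes; the real
# split_between_processes utility in Accelerate handles this (plus
# padding options) internally.
def split_prompts(prompts, num_processes):
    base, extra = divmod(len(prompts), num_processes)
    chunks, start = [], 0
    for i in range(num_processes):
        size = base + (1 if i < extra else 0)  # earlier processes take the remainder
        chunks.append(prompts[start:start + size])
        start += size
    return chunks

print(split_prompts(["a dog", "a cat", "a bird", "a fish"], 2))
# -> [['a dog', 'a cat'], ['a bird', 'a fish']]
```

With four prompts and two processes, each process receives a list of two prompts and generates its images independently.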

Use the --num_processes argument to specify the number of GPUs to use, and call accelerate launch to run the script:

accelerate launch --num_processes=2 run_distributed.py

์ž์„ธํ•œ ๋‚ด์šฉ์€ ๐Ÿค— Accelerate๋ฅผ ์‚ฌ์šฉํ•œ ๋ถ„์‚ฐ ์ถ”๋ก  ๊ฐ€์ด๋“œ๋ฅผ ์ฐธ์กฐํ•˜์„ธ์š”.

PyTorch Distributed

PyTorch๋Š” ๋ฐ์ดํ„ฐ ๋ณ‘๋ ฌ ์ฒ˜๋ฆฌ๋ฅผ ๊ฐ€๋Šฅํ•˜๊ฒŒ ํ•˜๋Š” DistributedDataParallel์„ ์ง€์›ํ•ฉ๋‹ˆ๋‹ค.

To begin, create a Python file and import torch.distributed and torch.multiprocessing to set up the distributed process group and to spawn a process for inference on each GPU. You should also initialize a [DiffusionPipeline]:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from diffusers import DiffusionPipeline

sd = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)

You'll want to create a function to run inference; [init_process_group] handles creating a distributed environment with the type of backend to use, the rank of the current process, and the world_size, or the number of participating processes. If you're running inference in parallel over 2 GPUs, then the world_size is 2.

Move the diffusion pipeline to rank and use get_rank to assign a GPU to each process, where each process handles a different prompt:

def run_inference(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    sd.to(rank)

    if torch.distributed.get_rank() == 0:
        prompt = "a dog"
    elif torch.distributed.get_rank() == 1:
        prompt = "a cat"

    image = sd(prompt).images[0]
    image.save(f"./{'_'.join(prompt.split())}.png")

๋ถ„์‚ฐ ์ถ”๋ก ์„ ์‹คํ–‰ํ•˜๋ ค๋ฉด mp.spawn์„ ํ˜ธ์ถœํ•˜์—ฌ world_size์— ์ •์˜๋œ GPU ์ˆ˜์— ๋Œ€ํ•ด run_inference ํ•จ์ˆ˜๋ฅผ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค:

def main():
    world_size = 2
    mp.spawn(run_inference, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main()
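The per-rank if/elif branching in run_inference can be generalized by indexing a prompt list with the rank. A minimal sketch (the wrap-around for extra ranks is an illustrative choice, not part of the original script):

```python
# Sketch: select a prompt by rank instead of per-rank branches.
prompts = ["a dog", "a cat"]

def prompt_for_rank(rank, prompts):
    # Wrap around if there are more ranks than prompts (illustrative choice).
    return prompts[rank % len(prompts)]

print(prompt_for_rank(0, prompts))  # -> "a dog"
print(prompt_for_rank(1, prompts))  # -> "a cat"
```

This keeps run_inference unchanged as the prompt list grows, instead of adding an elif branch per GPU.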

์ถ”๋ก  ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์™„๋ฃŒํ–ˆ์œผ๋ฉด --nproc_per_node ์ธ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์‚ฌ์šฉํ•  GPU ์ˆ˜๋ฅผ ์ง€์ •ํ•˜๊ณ  torchrun์„ ํ˜ธ์ถœํ•˜์—ฌ ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค:

torchrun --nproc_per_node=2 run_distributed.py