Spaces:
Running
on
Zero
A newer version of the Gradio SDK is available:
5.12.0
InstructPix2Pix
InstructPix2Pix๋ text-conditioned diffusion ๋ชจ๋ธ์ด ํ ์ด๋ฏธ์ง์ ํธ์ง์ ๋ฐ๋ฅผ ์ ์๋๋ก ํ์ธํ๋ํ๋ ๋ฐฉ๋ฒ์ ๋๋ค. ์ด ๋ฐฉ๋ฒ์ ์ฌ์ฉํ์ฌ ํ์ธํ๋๋ ๋ชจ๋ธ์ ๋ค์์ ์ ๋ ฅ์ผ๋ก ์ฌ์ฉํฉ๋๋ค:
์ถ๋ ฅ์ ์ ๋ ฅ ์ด๋ฏธ์ง์ ํธ์ง ์ง์๊ฐ ๋ฐ์๋ "์์ ๋" ์ด๋ฏธ์ง์ ๋๋ค:
train_instruct_pix2pix.py
์คํฌ๋ฆฝํธ(์ฌ๊ธฐ์์ ์ฐพ์ ์ ์์ต๋๋ค.)๋ ํ์ต ์ ์ฐจ๋ฅผ ์ค๋ช
ํ๊ณ Stable Diffusion์ ์ ์ฉํ ์ ์๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ค๋๋ค.
*** train_instruct_pix2pix.py
๋ ์๋ ๊ตฌํ์ ์ถฉ์คํ๋ฉด์ InstructPix2Pix ํ์ต ์ ์ฐจ๋ฅผ ๊ตฌํํ๊ณ ์์ง๋ง, ์๊ท๋ชจ ๋ฐ์ดํฐ์
์์๋ง ํ
์คํธ๋ฅผ ํ์ต๋๋ค. ์ด๋ ์ต์ข
๊ฒฐ๊ณผ์ ์ํฅ์ ๋ผ์น ์ ์์ต๋๋ค. ๋ ๋์ ๊ฒฐ๊ณผ๋ฅผ ์ํด, ๋ ํฐ ๋ฐ์ดํฐ์
์์ ๋ ๊ธธ๊ฒ ํ์ตํ๋ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค. ์ฌ๊ธฐ์์ InstructPix2Pix ํ์ต์ ์ํด ํฐ ๋ฐ์ดํฐ์
์ ์ฐพ์ ์ ์์ต๋๋ค.
PyTorch๋ก ๋ก์ปฌ์์ ์คํํ๊ธฐ
์ข ์์ฑ(dependencies) ์ค์นํ๊ธฐ
์ด ์คํฌ๋ฆฝํธ๋ฅผ ์คํํ๊ธฐ ์ ์, ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ํ์ต ์ข ์์ฑ์ ์ค์นํ์ธ์:
์ค์
์ต์ ๋ฒ์ ์ ์์ ์คํฌ๋ฆฝํธ๋ฅผ ์ฑ๊ณต์ ์ผ๋ก ์คํํ๊ธฐ ์ํด, ์๋ณธ์ผ๋ก๋ถํฐ ์ค์นํ๋ ๊ฒ๊ณผ ์์ ์คํฌ๋ฆฝํธ๋ฅผ ์์ฃผ ์ ๋ฐ์ดํธํ๊ณ ์์ ๋ณ ์๊ตฌ์ฌํญ์ ์ค์นํ๊ธฐ ๋๋ฌธ์ ์ต์ ์ํ๋ก ์ ์งํ๋ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค. ์ด๋ฅผ ์ํด, ์๋ก์ด ๊ฐ์ ํ๊ฒฝ์์ ๋ค์ ์คํ ์ ์คํํ์ธ์:
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
cd ๋ช ๋ น์ด๋ก ์์ ํด๋๋ก ์ด๋ํ์ธ์.
cd examples/instruct_pix2pix
์ด์ ์คํํ์ธ์.
pip install -r requirements.txt
๊ทธ๋ฆฌ๊ณ ๐คAccelerate ํ๊ฒฝ์์ ์ด๊ธฐํํ์ธ์:
accelerate config
ํน์ ํ๊ฒฝ์ ๋ํ ์ง๋ฌธ ์์ด ๊ธฐ๋ณธ์ ์ธ accelerate ๊ตฌ์ฑ์ ์ฌ์ฉํ๋ ค๋ฉด ๋ค์์ ์คํํ์ธ์.
accelerate config default
ํน์ ์ฌ์ฉ ์ค์ธ ํ๊ฒฝ์ด notebook๊ณผ ๊ฐ์ ๋ํํ ์์ ์ง์ํ์ง ์๋ ๊ฒฝ์ฐ๋ ๋ค์ ์ ์ฐจ๋ฅผ ๋ฐ๋ผ์ฃผ์ธ์.
from accelerate.utils import write_basic_config
write_basic_config()
์์
์ด์ ์ ์ธ๊ธํ๋ฏ์ด, ํ์ต์ ์ํด ์์ ๋ฐ์ดํฐ์ ์ ์ฌ์ฉํ ๊ฒ์ ๋๋ค. ๊ทธ ๋ฐ์ดํฐ์ ์ InstructPix2Pix ๋ ผ๋ฌธ์์ ์ฌ์ฉ๋ ์๋์ ๋ฐ์ดํฐ์ ๋ณด๋ค ์์ ๋ฒ์ ์ ๋๋ค. ์์ ์ ๋ฐ์ดํฐ์ ์ ์ฌ์ฉํ๊ธฐ ์ํด, ํ์ต์ ์ํ ๋ฐ์ดํฐ์ ๋ง๋ค๊ธฐ ๊ฐ์ด๋๋ฅผ ์ฐธ๊ณ ํ์ธ์.
MODEL_NAME
ํ๊ฒฝ ๋ณ์(ํ๋ธ ๋ชจ๋ธ ๋ ํฌ์งํ ๋ฆฌ ๋๋ ๋ชจ๋ธ ๊ฐ์ค์น๊ฐ ํฌํจ๋ ํด๋ ๊ฒฝ๋ก)๋ฅผ ์ง์ ํ๊ณ pretrained_model_name_or_path
์ธ์์ ์ ๋ฌํฉ๋๋ค. DATASET_ID
์ ๋ฐ์ดํฐ์
์ด๋ฆ์ ์ง์ ํด์ผ ํฉ๋๋ค:
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export DATASET_ID="fusing/instructpix2pix-1000-samples"
์ง๊ธ, ํ์ต์ ์คํํ ์ ์์ต๋๋ค. ์คํฌ๋ฆฝํธ๋ ๋ ํฌ์งํ ๋ฆฌ์ ํ์ ํด๋์ ๋ชจ๋ ๊ตฌ์ฑ์์(feature_extractor
, scheduler
, text_encoder
, unet
๋ฑ)๋ฅผ ์ ์ฅํฉ๋๋ค.
accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$DATASET_ID \
--enable_xformers_memory_efficient_attention \
--resolution=256 --random_flip \
--train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
--max_train_steps=15000 \
--checkpointing_steps=5000 --checkpoints_total_limit=1 \
--learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \
--conditioning_dropout_prob=0.05 \
--mixed_precision=fp16 \
--seed=42 \
--push_to_hub
์ถ๊ฐ์ ์ผ๋ก, ๊ฐ์ค์น์ ๋ฐ์ด์ด์ค๋ฅผ ํ์ต ๊ณผ์ ์ ๋ชจ๋ํฐ๋งํ์ฌ ๊ฒ์ฆ ์ถ๋ก ์ ์ํํ๋ ๊ฒ์ ์ง์ํฉ๋๋ค. report_to="wandb"
์ ์ด ๊ธฐ๋ฅ์ ์ฌ์ฉํ ์ ์์ต๋๋ค:
accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$DATASET_ID \
--enable_xformers_memory_efficient_attention \
--resolution=256 --random_flip \
--train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
--max_train_steps=15000 \
--checkpointing_steps=5000 --checkpoints_total_limit=1 \
--learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \
--conditioning_dropout_prob=0.05 \
--mixed_precision=fp16 \
--val_image_url="https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" \
--validation_prompt="make the mountains snowy" \
--seed=42 \
--report_to=wandb \
--push_to_hub
๋ชจ๋ธ ๋๋ฒ๊น
์ ์ ์ฉํ ์ด ํ๊ฐ ๋ฐฉ๋ฒ ๊ถ์ฅํฉ๋๋ค. ์ด๋ฅผ ์ฌ์ฉํ๊ธฐ ์ํด wandb
๋ฅผ ์ค์นํ๋ ๊ฒ์ ์ฃผ๋ชฉํด์ฃผ์ธ์. pip install wandb
๋ก ์คํํด wandb
๋ฅผ ์ค์นํ ์ ์์ต๋๋ค.
์ฌ๊ธฐ, ๋ช ๊ฐ์ง ํ๊ฐ ๋ฐฉ๋ฒ๊ณผ ํ์ต ํ๋ผ๋ฏธํฐ๋ฅผ ํฌํจํ๋ ์์๋ฅผ ๋ณผ ์ ์์ต๋๋ค.
์ฐธ๊ณ : ์๋ณธ ๋ ผ๋ฌธ์์, ์ ์๋ค์ 256x256 ์ด๋ฏธ์ง ํด์๋๋ก ํ์ตํ ๋ชจ๋ธ๋ก 512x512์ ๊ฐ์ ๋ ํฐ ํด์๋๋ก ์ ์ผ๋ฐํ๋๋ ๊ฒ์ ๋ณผ ์ ์์์ต๋๋ค. ์ด๋ ํ์ต์ ์ฌ์ฉํ ํฐ ๋ฐ์ดํฐ์ ์ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ ๋๋ค.
๋ค์์ GPU๋ก ํ์ตํ๊ธฐ
accelerate
๋ ์ํํ ๋ค์์ GPU๋ก ํ์ต์ ๊ฐ๋ฅํ๊ฒ ํฉ๋๋ค. accelerate
๋ก ๋ถ์ฐ ํ์ต์ ์คํํ๋ ์ฌ๊ธฐ ์ค๋ช
์ ๋ฐ๋ผ ํด ์ฃผ์๊ธฐ ๋ฐ๋๋๋ค. ์์์ ๋ช
๋ น์ด ์
๋๋ค:
accelerate launch --mixed_precision="fp16" --multi_gpu train_instruct_pix2pix.py \
--pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 \
--dataset_name=sayakpaul/instructpix2pix-1000-samples \
--use_ema \
--enable_xformers_memory_efficient_attention \
--resolution=512 --random_flip \
--train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
--max_train_steps=15000 \
--checkpointing_steps=5000 --checkpoints_total_limit=1 \
--learning_rate=5e-05 --lr_warmup_steps=0 \
--conditioning_dropout_prob=0.05 \
--mixed_precision=fp16 \
--seed=42 \
--push_to_hub
์ถ๋ก ํ๊ธฐ
์ผ๋จ ํ์ต์ด ์๋ฃ๋๋ฉด, ์ถ๋ก ํ ์ ์์ต๋๋ค:
import PIL
import requests
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
model_id = "your_model_id" # <- ์ด๋ฅผ ์์ ํ์ธ์.
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
generator = torch.Generator("cuda").manual_seed(0)
url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/test_pix2pix_4.png"
def download_image(url):
image = PIL.Image.open(requests.get(url, stream=True).raw)
image = PIL.ImageOps.exif_transpose(image)
image = image.convert("RGB")
return image
image = download_image(url)
prompt = "wipe out the lake"
num_inference_steps = 20
image_guidance_scale = 1.5
guidance_scale = 10
edited_image = pipe(
prompt,
image=image,
num_inference_steps=num_inference_steps,
image_guidance_scale=image_guidance_scale,
guidance_scale=guidance_scale,
generator=generator,
).images[0]
edited_image.save("edited_image.png")
ํ์ต ์คํฌ๋ฆฝํธ๋ฅผ ์ฌ์ฉํด ์ป์ ์์์ ๋ชจ๋ธ ๋ ํฌ์งํ ๋ฆฌ๋ ์ฌ๊ธฐ sayakpaul/instruct-pix2pix์์ ํ์ธํ ์ ์์ต๋๋ค.
์ฑ๋ฅ์ ์ํ ์๋์ ํ์ง์ ์ ์ดํ๊ธฐ ์ํด ์ธ ๊ฐ์ง ํ๋ผ๋ฏธํฐ๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ์ข์ต๋๋ค:
num_inference_steps
image_guidance_scale
guidance_scale
ํนํ, image_guidance_scale
์ guidance_scale
๋ ์์ฑ๋("์์ ๋") ์ด๋ฏธ์ง์์ ํฐ ์ํฅ์ ๋ฏธ์น ์ ์์ต๋๋ค.(์ฌ๊ธฐ์์๋ฅผ ์ฐธ๊ณ ํด์ฃผ์ธ์.)
๋ง์ฝ InstructPix2Pix ํ์ต ๋ฐฉ๋ฒ์ ์ฌ์ฉํด ๋ช ๊ฐ์ง ํฅ๋ฏธ๋ก์ด ๋ฐฉ๋ฒ์ ์ฐพ๊ณ ์๋ค๋ฉด, ์ด ๋ธ๋ก๊ทธ ๊ฒ์๋ฌผInstruction-tuning Stable Diffusion with InstructPix2Pix์ ํ์ธํด์ฃผ์ธ์.