from pathlib import Path
import PIL
from tqdm import tqdm
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPMPipeline, UNet2DModel, DDPMScheduler
from diffusers.optimization import get_cosine_schedule_with_warmup
from diffusers.utils import make_image_grid
from huggingface_hub import create_repo, upload_folder
from peft import LoraConfig, get_peft_model
import torch
import torch.nn.functional as F
from torchvision import transforms
from config import TrainingConfig
"""
Or diffusion for simple images and explore subtly different
x_T's and what the output is.
Denoise each x_T multiple times to get a better picture of the distribution.
Maybe use a set sequence of seeds for every denoising run (torch.Generator(seed=__)).
Inter-concept space. Conciousness.
"""
def evaluate(config, epoch, pipeline):
    # Sample some images from random noise (this is the backward diffusion process).
    # The default pipeline output type is `List[PIL.Image]`.
    images = pipeline(
        batch_size=config.eval_batch_size,
        generator=torch.manual_seed(config.seed),
        num_inference_steps=50
    ).images

    # Arrange the samples in a grid (rows * cols must equal eval_batch_size).
    image_grid = make_image_grid(images, rows=2, cols=2)

    # Save the grid to disk.
    test_dir = Path(config.output_dir) / 'samples'
    test_dir.mkdir(parents=True, exist_ok=True)
    image_grid.save(test_dir / f'{epoch:04d}.png')
def print_trainable_parameters(model):
    """Print the number of trainable parameters vs. all parameters in the model."""
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f'trainable params: {trainable_params} || all params: {all_param} || '
        f'trainable%: {100 * trainable_params / all_param:.2f}'
    )
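# Note: once the model is wrapped with PEFT below, the library's built-in
# equivalent, lora_unet.print_trainable_parameters(), reports the same numbers.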
if __name__ == '__main__':
    config = TrainingConfig()
    config.dataset_name = 'keremberke/painting-style-classification'

    ds_dict = load_dataset(config.dataset_name, name='full')

    preprocess = transforms.Compose([
        transforms.Resize((config.image_size, config.image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

    def transform(examples):
        return {
            'images': [preprocess(img.convert('RGB')) for img in examples['image']]
        }

    # Automatically applies preprocessing to samples as we load them.
    ds_dict.set_transform(transform)

    train_dataloader = torch.utils.data.DataLoader(ds_dict['train'], batch_size=config.train_batch_size, shuffle=True)
    valid_dataloader = torch.utils.data.DataLoader(ds_dict['validation'], batch_size=config.eval_batch_size, shuffle=False)
    test_dataloader = torch.utils.data.DataLoader(ds_dict['test'], batch_size=config.eval_batch_size, shuffle=False)
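    # Optional sanity check (illustrative): confirm the transform yields
    # 3-channel tensors at the configured resolution.
    # batch = next(iter(train_dataloader))
    # assert batch['images'].shape[1:] == (3, config.image_size, config.image_size)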
"""
unet = UNet2DModel.from_pretrained(
'google/ddpm-celebahq-256'
).to('mps')
scheduler = DDPMScheduler.from_pretrained(
'google/ddpm-celebahq-256'
)
"""
"""
unet = UNet2DModel.from_pretrained(
'jmemon/ddpm-paintings-128-finetuned-celebahq',
use_safetensors=True
).to('mps')
scheduler = DDPMScheduler.from_pretrained(
'jmemon/ddpm-paintings-128-finetuned-celebahq'
)
"""
unet = UNet2DModel.from_pretrained(
str(Path(__file__).parent / 'unet'),
use_safetensors=True
).to('mps')
scheduler = DDPMScheduler.from_pretrained(
str(Path(__file__).parent / 'scheduler')
)
    lora_config = LoraConfig(
        r=8,
        lora_alpha=8,
        target_modules=['to_k', 'to_v'],  # adapt the attention key/value projections
        lora_dropout=0.1,
        bias='none')

    lora_unet = get_peft_model(unet, lora_config)
    print_trainable_parameters(lora_unet)

    optimizer = torch.optim.AdamW(lora_unet.parameters(), lr=config.learning_rate)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps,
        num_training_steps=(len(train_dataloader) * config.num_epochs)
    )
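    # get_cosine_schedule_with_warmup ramps the learning rate linearly from 0
    # to config.learning_rate over lr_warmup_steps, then decays it to 0 along
    # a cosine curve over the remaining training steps.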
    accelerator = Accelerator(
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        mixed_precision=config.mixed_precision,
        log_with='tensorboard',
        project_dir=Path(config.output_dir) / 'logs'
    )

    if accelerator.is_main_process:
        if config.push_to_hub:
            repo_id = create_repo(repo_id=config.hub_model_id, exist_ok=True).repo_id
        accelerator.init_trackers('ddpm-paintings-128-finetuned-celebahq')
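    # Wrap the model, optimizer, dataloader, and LR scheduler with accelerate
    # so gradient accumulation and mixed precision behave as configured.
    # (Added here as a likely-intended step; the script otherwise drives the
    # 'mps' device manually.)
    lora_unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        lora_unet, optimizer, train_dataloader, lr_scheduler
    )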
    global_step = 0

    # Epoch numbering starts at 6, apparently continuing the count from an
    # earlier finetuning run.
    for epoch in range(6, config.num_epochs + 6):
        pbar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        pbar.set_description(f'Epoch {epoch}')

        for idx, batch in enumerate(train_dataloader):
            clean_images = batch['images'].to(accelerator.device)
            noise = torch.randn(clean_images.shape, device=clean_images.device)
            bs = clean_images.shape[0]

            # Sample a random timestep per image and noise the images
            # accordingly (the forward diffusion process).
            ts = torch.randint(0, scheduler.config.num_train_timesteps, (bs,), device=clean_images.device, dtype=torch.int64)
            noisy_images = scheduler.add_noise(clean_images, noise, ts)

            with accelerator.accumulate(lora_unet):
                # Predict the noise residual and regress it against the true noise.
                noise_pred = lora_unet(noisy_images, ts, return_dict=False)[0]
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                accelerator.clip_grad_norm_(lora_unet.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            logs = {'loss': loss.detach().item(), 'lr': lr_scheduler.get_last_lr()[0], 'step': global_step}
            pbar.update(1)
            pbar.set_postfix(loss=logs['loss'], step=idx + 1)
            accelerator.log(logs, step=global_step)
            global_step += 1

        pbar.close()
        if accelerator.is_main_process:
            #pipeline = DDPMPipeline(unet=accelerator.unwrap_model(lora_unet).merge_and_unload(), scheduler=scheduler)
            pipeline = DDPMPipeline(unet=accelerator.unwrap_model(lora_unet), scheduler=scheduler)

            if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
                # Save some sample images from the model as trained to the end of this epoch.
                evaluate(config, epoch, pipeline)

            if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
                # merge_and_unload folds the LoRA weights into the base UNet so
                # the pipeline can be saved as a plain DDPM checkpoint.
                _pipeline = DDPMPipeline(
                    unet=accelerator.unwrap_model(lora_unet).merge_and_unload(),
                    scheduler=scheduler)

                if config.push_to_hub:
                    # Authenticate via `huggingface-cli login` or the HF_TOKEN
                    # environment variable; never hardcode tokens in source.
                    _pipeline.save_pretrained(
                        config.output_dir,
                        push_to_hub=True,
                        repo_id=repo_id
                    )
                    upload_folder(
                        repo_id=repo_id,
                        folder_path=config.output_dir,
                        commit_message=f'Epoch {epoch}',
                        ignore_patterns=['logs/*', '*/.DS_Store']
                    )
                    model_loc = 'jmemon/ddpm-paintings-128-finetuned-celebahq'
                else:
                    _pipeline.save_pretrained(config.output_dir)
                    # from_pretrained expects a model directory, not a single
                    # safetensors file, so point at the saved pipeline folder.
                    model_loc = config.output_dir

                # merge_and_unload mutated the base model, so reload the saved
                # UNet (from the pipeline's `unet/` subfolder) and re-wrap it
                # with fresh LoRA adapters before training continues. Note the
                # optimizer still references the previous parameters; a fully
                # correct resume would also rebuild the optimizer and re-run
                # accelerator.prepare.
                unet = UNet2DModel.from_pretrained(model_loc, subfolder='unet', use_safetensors=True).to(accelerator.device)
                lora_unet = get_peft_model(unet, lora_config)
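# Illustrative usage after training (not executed here): reload the pushed
# pipeline and sample from it.
#
#   pipeline = DDPMPipeline.from_pretrained('jmemon/ddpm-paintings-128-finetuned-celebahq')
#   images = pipeline(batch_size=4, num_inference_steps=50).images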