import os
import math
import random

import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR
from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from accelerate import Accelerator
from datasets import load_from_disk
from tqdm import tqdm
from PIL import Image
import wandb

save_path = "datasets/mnist"
batch_size = 320
base_learning_rate = 5e-5
num_epochs = 10
gradient_accumulation_steps = 1
project = "sdxs"
use_wandb = True
limit = 0  # 0 means train on the full dataset
grad_clip = 0.1

n_diffusion_steps = 20
samples_to_generate = 6

generated_folder = "samples"
checkpoints_folder = ""
os.makedirs(generated_folder, exist_ok=True)

# Seed everything for reproducibility.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

print("init") |
|
|
|
dtype = torch.bfloat16 |
|
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps,mixed_precision="bf16" if dtype == torch.bfloat16 else "no") |
|
device = accelerator.device |
|
|
|
if use_wandb and accelerator.is_main_process:
    wandb.init(project=project, config={
        "batch_size": batch_size,
        "base_learning_rate": base_learning_rate,
        "num_epochs": num_epochs,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "n_diffusion_steps": n_diffusion_steps,
        "samples_to_generate": samples_to_generate,
        "dtype": str(dtype),
    })

if limit>0: |
|
dataset = load_from_disk(save_path).select(range(limit)) |
|
else: |
|
dataset = load_from_disk(save_path) |
|
|
|
def collate_fn(batch):
    latents = torch.tensor(np.array([item["vae"] for item in batch]), dtype=dtype).to(device)
    embeddings = torch.tensor(np.array([item["embeddings"] for item in batch]), dtype=dtype).to(device)
    return latents, embeddings

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
dataloader = accelerator.prepare(dataloader)

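# NOTE (assumption): each dataset record is expected to carry a precomputed
# "vae" field (latents from the 16-channel VAE loaded below) and an
# "embeddings" field (text-encoder hidden states for cross-attention), so no
# VAE encoder or text encoder has to run inside the training loop.
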
# The VAE is only needed to decode sample images, so it stays on the CPU
# except while generating samples.
vae = AutoencoderKL.from_pretrained("AuraDiffusion/16ch-vae").to("cpu", dtype=dtype)

scheduler = DDPMScheduler(
    num_train_timesteps=1000,
    prediction_type="v_prediction",
    rescale_betas_zero_snr=True,
)

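# Background for the training target used below: with
# prediction_type="v_prediction" the UNet learns
#     v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0,
# which is exactly what scheduler.get_velocity() returns in the loss.
# rescale_betas_zero_snr=True makes the final timestep pure noise (zero SNR),
# so inference can start from torch.randn without a train/test mismatch.
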
latest_checkpoint = os.path.join(checkpoints_folder, project)
if os.path.isdir(latest_checkpoint):
    print("Loading UNet from checkpoint:", latest_checkpoint)
    unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device, dtype=dtype)
else:
    # The script expects a pre-initialized UNet checkpoint; fail fast if it is missing.
    raise FileNotFoundError(f"No UNet checkpoint found at {latest_checkpoint}")

unet = accelerator.prepare(unet)

total_training_steps = len(dataloader) * num_epochs // gradient_accumulation_steps

world_size = accelerator.state.num_processes
print(f"World Size: {world_size}")

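# Rough scale check (assuming the standard 60,000-image MNIST train split and
# a single GPU): ceil(60000 / 320) = 188 batches per epoch, so
# total_training_steps = 188 * 10 // 1 = 1880 optimizer steps.
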
optimizer = torch.optim.AdamW(unet.parameters(), lr=base_learning_rate)

def lr_schedule(step, max_steps, base_lr):
    # Linear warmup from 0.1 * base_lr to base_lr over the first 10% of
    # training, then cosine decay back down to 0.1 * base_lr.
    x = step / max_steps
    if x < 0.1:
        return 0.1 * base_lr + 0.9 * base_lr * (x / 0.1)
    return 0.1 * base_lr + 0.9 * base_lr * (1 + math.cos(math.pi * (x - 0.1) / (1 - 0.1))) / 2

def custom_lr_lambda(step):
    # LambdaLR expects a multiplier on the optimizer's base LR. Accelerate
    # steps the prepared scheduler once per process per optimizer step,
    # hence max_steps is scaled by world_size.
    return lr_schedule(step, total_training_steps * world_size, base_learning_rate) / base_learning_rate

lr_scheduler = LambdaLR(optimizer, lr_lambda=custom_lr_lambda)

optimizer, lr_scheduler = accelerator.prepare(optimizer, lr_scheduler)

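# The schedule above, as a fraction of base_learning_rate:
#   x = 0.00 -> 0.10  (warmup start)
#   x = 0.10 -> 1.00  (warmup peak, cosine start)
#   x = 0.55 -> 0.55  (cosine midpoint)
#   x = 1.00 -> 0.10  (final floor, 10% of the base rate)
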
# Fix a small set of dataset entries to condition the periodic sample generations on.
sample_indices = np.random.choice(len(dataset), samples_to_generate, replace=False)
sample_data = [dataset[int(i)] for i in sample_indices]

def get_sample_inputs(sample_data):
    return collate_fn(sample_data)

@torch.no_grad()
def generate_and_save_samples(step: int):
    """
    Moves the VAE to the device, generates samples with the configured number
    of diffusion steps, decodes them, and logs each image to WandB.
    Afterwards the VAE is moved back to the CPU.
    """
    try:
        vae.to(accelerator.device, dtype=dtype)
        scheduler.set_timesteps(n_diffusion_steps)

        sample_latents, sample_text_embeddings = get_sample_inputs(sample_data)

        # Only the shape of the stored latents is used: sampling starts from
        # pure Gaussian noise (valid thanks to rescale_betas_zero_snr=True).
        sample_latents = torch.randn_like(sample_latents.to(dtype))

        for t in scheduler.timesteps:
            latent_model_input = scheduler.scale_model_input(sample_latents, t)
            noise_pred = unet(latent_model_input, t, sample_text_embeddings).sample
            sample_latents = scheduler.step(noise_pred, t, sample_latents).prev_sample

        # Undo the VAE latent normalization (x = z / scaling_factor + shift_factor)
        # before decoding.
        latent = (sample_latents.detach() / vae.config.scaling_factor) + vae.config.shift_factor
        latent = latent.to(accelerator.device, dtype=dtype)
        decoded = vae.decode(latent).sample

        generated_images = []
        for img_idx, img_tensor in enumerate(decoded):
            # Map from [-1, 1] to [0, 1], then to an 8-bit image.
            img = (img_tensor.to(torch.float32) / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
            pil_img = Image.fromarray((img * 255).astype("uint8"))
            generated_images.append(pil_img)
            sample_path = f"{generated_folder}/{project}{img_idx}.jpg"
            pil_img.save(sample_path, "JPEG", quality=95)

        if use_wandb:
            wandb_images = [wandb.Image(img, caption=f"Sample {i}") for i, img in enumerate(generated_images)]
            wandb.log({"generated_images": wandb_images, "global_step": step})
    finally:
        # Free GPU memory for training regardless of success or failure.
        vae.to("cpu")
        torch.cuda.empty_cache()

if accelerator.is_main_process:
    print("Generating samples before training starts...")
    generate_and_save_samples(step=0)

global_step = 0

if accelerator.is_main_process:
    print(f"Total optimizer steps per GPU: {total_training_steps}")

epoch_loss_points = []
progress_bar = tqdm(
    total=total_training_steps,
    disable=not accelerator.is_local_main_process,
    desc="Training",
    unit="step",
)

steps_per_epoch = len(dataloader) // gradient_accumulation_steps
sample_interval = max(1, steps_per_epoch // 10)

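# Samples are rendered roughly ten times per epoch; under the scale check
# above (188 batches per epoch, accumulation of 1) that is every 18 steps.
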
for epoch in range(num_epochs):
    batch_losses = []
    unet.train()
    for step, (latents, embeddings) in enumerate(dataloader):

        with accelerator.accumulate(unet):
            # Forward diffusion: noise each latent to a uniformly random timestep.
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, scheduler.config.num_train_timesteps, (latents.shape[0],), device=device
            ).long()
            noisy_latents = scheduler.add_noise(latents, noise, timesteps)

            noise_pred = unet(noisy_latents, timesteps, embeddings).sample
            # v-prediction target (see the scheduler note above).
            target = scheduler.get_velocity(latents, noise, timesteps)
            # Compute the loss in float32 for numerical stability under bf16.
            loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float())

            accelerator.backward(loss)

            if accelerator.sync_gradients:
                grad_norm = torch.nn.utils.clip_grad_norm_(unet.parameters(), grad_clip)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                batch_losses.append(loss.item())
                global_step += 1
                progress_bar.update(1)

        if accelerator.is_main_process:
            if use_wandb:
                wandb.log({"loss": loss.item(), "learning_rate": optimizer.param_groups[0]["lr"], "epoch": epoch})

            if (step + 1) % sample_interval == 0:
                avg_loss = np.mean(batch_losses[-sample_interval:])
                if use_wandb:
                    wandb.log({"intermediate_loss": avg_loss, "grad_norm": grad_norm.item()})

                # Periodically render and log sample images, then resume training.
                unet.eval()
                generate_and_save_samples(step=global_step)
                unet.train()

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        print(f"[GPU {accelerator.process_index}] Max memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
        avg_epoch_loss = np.mean(batch_losses)
        epoch_loss_points.append(avg_epoch_loss)
        print(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {avg_epoch_loss:.6f}")
        if use_wandb:
            wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch + 1})

        # Save an intermediate checkpoint at the end of every epoch.
        checkpoint_path = os.path.join(checkpoints_folder, project)
        original_model = unet.module if hasattr(unet, "module") else unet
        original_model.save_pretrained(checkpoint_path)
        print(f"Checkpoint saved: {checkpoint_path}")

progress_bar.close()

accelerator.wait_for_everyone()
if accelerator.is_main_process:
    print("Generating samples after training has finished...")
    generate_and_save_samples(step=global_step)

    # Plot the per-epoch average loss curve.
    plt.figure()
    plt.plot(np.arange(1, num_epochs + 1), epoch_loss_points, marker="o")
    plt.xlabel("Epoch")
    plt.ylabel("Average Loss")
    plt.title("Training Loss Curve")
    plt.grid()
    plt.savefig("training_loss.png", dpi=300)
    print("Training loss curve saved to training_loss.png")

    # Save the final model and finish logging.
    checkpoint_path = os.path.join(checkpoints_folder, project)
    original_model = unet.module if hasattr(unet, "module") else unet
    original_model.save_pretrained(checkpoint_path)
    print("Final model saved to", checkpoint_path)
    if use_wandb:
        wandb.log({"training_loss_curve": wandb.Image("training_loss.png")})
        wandb.save(checkpoint_path)
        wandb.finish()

del unet, vae, optimizer, lr_scheduler, dataloader
torch.cuda.empty_cache()
print("Memory cleared.")