import os import math import torch import numpy as np import matplotlib.pyplot as plt from torch.utils.data import DataLoader from torch.optim.lr_scheduler import LambdaLR from diffusers import UNet2DConditionModel, AutoencoderKL from accelerate import Accelerator from datasets import load_from_disk from tqdm import tqdm from PIL import Image import wandb import random # --------------------------- Параметры --------------------------- save_path = "datasets/mnist" batch_size = 320 base_learning_rate = 5e-5 num_epochs = 10 gradient_accumulation_steps = 1 project = "sdxs" use_wandb = True limit = 0 grad_clip = 0.1 # Параметры для диффузии n_diffusion_steps = 20 # число шагов в цепочке диффузии (при генерации) samples_to_generate = 6 # Папки для сохранения результатов generated_folder = "samples" checkpoints_folder = "" os.makedirs(generated_folder, exist_ok=True) seed = 42 torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) print("init") # --------------------------- Инициализация Accelerator --------------------------- dtype = torch.bfloat16 accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps,mixed_precision="bf16" if dtype == torch.bfloat16 else "no") device = accelerator.device#"cuda" if torch.cuda.is_available() else "cpu" # --------------------------- Инициализация WandB --------------------------- if use_wandb and accelerator.is_main_process: wandb.init(project=project, config={ "batch_size": batch_size, "base_learning_rate": base_learning_rate, "num_epochs": num_epochs, "gradient_accumulation_steps": gradient_accumulation_steps, "n_diffusion_steps": n_diffusion_steps, "samples_to_generate": samples_to_generate, "dtype": str(dtype) }) # --------------------------- Загрузка датасета --------------------------- if limit>0: dataset = load_from_disk(save_path).select(range(limit)) else: dataset = load_from_disk(save_path) def collate_fn(batch): # Преобразуем список в тензоры и перемещаем на девайс latents = torch.tensor(np.array([item["vae"] for item in batch]), dtype=dtype).to(device) embeddings = torch.tensor(np.array([item["embeddings"] for item in batch]), dtype=dtype).to(device) return latents, embeddings dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn) dataloader = accelerator.prepare(dataloader) # --------------------------- Загрузка моделей --------------------------- # VAE загружается на CPU для экономии GPU-памяти vae = AutoencoderKL.from_pretrained("AuraDiffusion/16ch-vae").to("cpu", dtype=dtype) # Загружаем UNet (на устройство) from diffusers import DDPMScheduler # DDPMScheduler с V_Prediction и Zero-SNR scheduler = DDPMScheduler(#DDPMScheduler( num_train_timesteps=1000, # Полный график шагов для обучения prediction_type="v_prediction", # V-Prediction rescale_betas_zero_snr=True # Включение Zero-SNR ) # Опция загрузки модели из последнего чекпоинта (если существует) latest_checkpoint = os.path.join(checkpoints_folder, project)#check if os.path.isdir(latest_checkpoint): print("Загружаем UNet из чекпоинта:", latest_checkpoint) unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device, dtype=dtype) # Подготовка через Accelerator unet = accelerator.prepare(unet) # Расчёт общего количества шагов (учитывая gradient_accumulation_steps) total_training_steps = len(dataloader) * num_epochs // gradient_accumulation_steps # Get the world size world_size = accelerator.state.num_processes print(f"World Size: {world_size}") # --------------------------- Оптимизатор и кастомный LR scheduler --------------------------- optimizer = torch.optim.AdamW(unet.parameters(), lr=base_learning_rate) def lr_schedule(step, max_steps, base_lr): # Прогрев до 10% от max_steps, затем косинусное затухание до 0.1 * base_lr x = step / max_steps if x < 0.1: return 0.1 * base_lr + 0.9 * base_lr * (x / 0.1) else: return 0.1 * base_lr + 0.9 * base_lr * (1 + math.cos(math.pi * (x - 0.1) / (1 - 0.1))) / 2 def custom_lr_lambda(step): return lr_schedule(step, total_training_steps*world_size, base_learning_rate) / base_learning_rate lr_scheduler = LambdaLR(optimizer, lr_lambda=custom_lr_lambda) optimizer, lr_scheduler = accelerator.prepare(optimizer, lr_scheduler) # --------------------------- Фиксированные семплы для генерации --------------------------- # Выбираем фиксированные индексы (приводим к int) sample_indices = np.random.choice(len(dataset), samples_to_generate, replace=False) sample_data = [dataset[int(i)] for i in sample_indices] def get_sample_inputs(sample_data): # Получаем latents и текстовые эмбеддинги для выбранных примеров (collate_fn переносит на device) return collate_fn(sample_data) @torch.no_grad() def generate_and_save_samples(step: int): """ Перемещает VAE на device, генерирует семплы с заданным числом diffusion шагов, декодирует их и отправляет на WandB по одному изображению. После генерации VAE возвращается на CPU. """ try: # Перемещаем VAE на device для семплирования vae.to(accelerator.device, dtype=dtype) # Устанавливаем количество diffusion шагов scheduler.set_timesteps(n_diffusion_steps) # Получаем фиксированные данные sample_latents, sample_text_embeddings = get_sample_inputs(sample_data) sample_latents = sample_latents.to(dtype) # Инициализируем латенты случайным шумом sample_latents = torch.randn_like(sample_latents) with torch.no_grad(): for t in scheduler.timesteps: latent_model_input = scheduler.scale_model_input(sample_latents, t) noise_pred = unet(latent_model_input, t, sample_text_embeddings).sample sample_latents = scheduler.step(noise_pred, t, sample_latents).prev_sample #decoded = vae.decode(sample_latents).sample latent = (sample_latents.detach() / vae.config.scaling_factor) + vae.config.shift_factor latent = latent.to(accelerator.device, dtype=dtype) decoded = vae.decode(latent).sample # Преобразуем тензоры в PIL-изображения generated_images = [] for img_idx, img_tensor in enumerate(decoded): img = (img_tensor.to(torch.float32) / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0) pil_img = Image.fromarray((img * 255).astype("uint8")) generated_images.append(pil_img) save_path = f"{generated_folder}/{project}{img_idx}.jpg" pil_img.save(save_path, "JPEG", quality=95) # Отправляем изображения на WandB if use_wandb: wandb_images = [wandb.Image(img, caption=f"Sample {i}") for i, img in enumerate(generated_images)] wandb.log({"generated_images": wandb_images, "global_step": step}) finally: # Гарантированное перемещение VAE обратно на CPU vae.to("cpu") torch.cuda.empty_cache() # Очищаем кэш CUDA # --------------------------- Генерация сэмплов перед обучением --------------------------- if accelerator.is_main_process: print("Генерация сэмплов до старта обучения...") generate_and_save_samples(step=0) # --------------------------- Тренировочный цикл --------------------------- global_step = 0 # Для логирования среднего лосса каждые 10% эпохи if accelerator.is_main_process: print(f"Total steps per GPU: {total_training_steps}") print(f"[GPU {accelerator.process_index}] Total steps: {total_training_steps}") epoch_loss_points = [] progress_bar = tqdm(total=total_training_steps,disable=not accelerator.is_local_main_process, desc="Training", unit="step") # Определяем интервал для сэмплирования и логирования в пределах эпохи (10% эпохи) steps_per_epoch = len(dataloader)// gradient_accumulation_steps sample_interval = max(1, steps_per_epoch // 10) for epoch in range(num_epochs): batch_losses = [] unet.train() for step, (latents, embeddings) in enumerate(dataloader): #optimizer.zero_grad() with accelerator.accumulate(unet): # Forward pass noise = torch.randn_like(latents) timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()#long? noisy_latents = scheduler.add_noise(latents, noise, timesteps) noise_pred = unet(noisy_latents, timesteps, embeddings).sample target = scheduler.get_velocity(latents, noise, timesteps) with torch.autocast(device_type="cuda", dtype=torch.bfloat16): loss = torch.nn.functional.mse_loss(noise_pred, target) accelerator.backward(loss) if accelerator.sync_gradients: grad_norm = torch.nn.utils.clip_grad_norm_(unet.parameters(), grad_clip) optimizer.step() lr_scheduler.step() optimizer.zero_grad() batch_losses.append(loss.item()) global_step += 1 progress_bar.update(1) # Логирование в W&B if accelerator.is_main_process: if use_wandb: wandb.log({"loss": loss.item(), "learning_rate": optimizer.param_groups[0]["lr"], "epoch": epoch}) # Логирование каждые sample_interval шагов (примерно 10% эпохи) if (step + 1) % sample_interval == 0: avg_loss = np.mean(batch_losses[-sample_interval:]) if use_wandb: wandb.log({"intermediate_loss": avg_loss, "grad_norm": grad_norm.item()}) # Также проводим генерацию сэмплов unet.eval() generate_and_save_samples(step=global_step) unet.train() # Средний лосс за эпоху accelerator.wait_for_everyone() # Дождаться завершения всех процессов if accelerator.is_main_process: print(f"[GPU {accelerator.process_index}] Max memory allocated: {torch.cuda.max_memory_allocated()/1e9:.2f} GB") avg_epoch_loss = np.mean(batch_losses) epoch_loss_points.append(avg_epoch_loss) print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_epoch_loss:.6f}") if use_wandb: wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1}) # Сохранение модели после каждой эпохи (перезаписываем последний чекпоинт) checkpoint_path = os.path.join(checkpoints_folder, project) # Получаем оригинальную модель из DistributedDataParallel original_model = unet.module if hasattr(unet, 'module') else unet original_model.save_pretrained(checkpoint_path) print(f"Чекпоинт сохранён: {checkpoint_path}") progress_bar.close() # Генерация сэмплов по окончании обучения accelerator.wait_for_everyone() if accelerator.is_main_process: print("Генерация сэмплов после окончания обучения...") generate_and_save_samples(step=global_step) # --------------------------- Построение графика лосса --------------------------- plt.figure() plt.plot(np.arange(1, num_epochs+1), epoch_loss_points, marker='o') plt.xlabel("Epoch") plt.ylabel("Average Loss") plt.title("Training Loss Curve") plt.grid() plt.savefig("training_loss.png", dpi=300) print("Training loss curve saved to training_loss.png") # --------------------------- Сохранение итоговой модели --------------------------- checkpoint_path = os.path.join(checkpoints_folder, project) original_model = unet.module if hasattr(unet, 'module') else unet original_model.save_pretrained(checkpoint_path) print("Итоговая модель сохранена в", checkpoint_path) if use_wandb: wandb.log({"training_loss_curve": wandb.Image("training_loss.png")}) wandb.save(checkpoint_path) wandb.finish() # --------------------------- Очистка памяти --------------------------- del unet, vae, optimizer, lr_scheduler, dataloader torch.cuda.empty_cache() print("Память очищена.")