import os import math import torch import numpy as np import matplotlib.pyplot as plt from torch.utils.data import DataLoader from torch.optim.lr_scheduler import LambdaLR from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler from accelerate import Accelerator from datasets import load_from_disk from tqdm import tqdm from PIL import Image import wandb import random import gc from accelerate.state import DistributedType from torch.distributed import broadcast_object_list #from adamw_bfloat16 import LR, AdamW_BF16 # --------------------------- Параметры --------------------------- save_path = "datasets/siski384"#"datasets/siski64" #"datasets/mnist" batch_size = 5 base_learning_rate = 8e-5 min_learning_rate = 2e-5 num_epochs = 36 #18 project = "sdxs" use_wandb = True limit = 0 checkpoint_file = "full_checkpoint.pt" reset_lroptim = True saveckpt = False save_model = True # Параметры для диффузии n_diffusion_steps = 50 # Увеличиваем число шагов для лучшего качества samples_to_generate = 6 # Папки для сохранения результатов generated_folder = "samples" checkpoints_folder = "" os.makedirs(generated_folder, exist_ok=True) # Настройка seed для воспроизводимости seed = 42 torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) print("init") # --------------------------- Инициализация Accelerator --------------------------- dtype = torch.bfloat16 accelerator = Accelerator(mixed_precision="bf16") device = accelerator.device gen = torch.Generator(device=device) gen.manual_seed(42) # --------------------------- Инициализация WandB --------------------------- if use_wandb and accelerator.is_main_process: wandb.init(project=project, config={ "batch_size": batch_size, "base_learning_rate": base_learning_rate, "num_epochs": num_epochs, "n_diffusion_steps": n_diffusion_steps, "samples_to_generate": samples_to_generate, "dtype": str(dtype) }) # --------------------------- Загрузка датасета --------------------------- if limit > 0: dataset = load_from_disk(save_path).select(range(limit)) else: dataset = load_from_disk(save_path) def collate_fn(batch): # Преобразуем список в тензоры и перемещаем на девайс latents = torch.tensor(np.array([item["vae"] for item in batch]), dtype=dtype).to(device) embeddings = torch.tensor(np.array([item["embeddings"] for item in batch]), dtype=dtype).to(device) return latents, embeddings dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn) print("Total samples",len(dataloader)) dataloader = accelerator.prepare(dataloader) # --------------------------- Загрузка моделей --------------------------- # VAE загружается на CPU для экономии GPU-памяти vae = AutoencoderKL.from_pretrained("AuraDiffusion/16ch-vae").to("cpu", dtype=dtype) # DDPMScheduler с V_Prediction и Zero-SNR scheduler = DDPMScheduler( num_train_timesteps=1000, # Полный график шагов для обучения prediction_type="v_prediction", # V-Prediction rescale_betas_zero_snr=True, # Включение Zero-SNR timestep_spacing="leading", # Добавляем улучшенное распределение шагов steps_offset=1 # Избегаем проблем с нулевым timestep ) # Инициализация переменных для возобновления обучения start_epoch = 0 global_step = 0 # Расчёт общего количества шагов total_training_steps = (len(dataloader) * num_epochs) # Get the world size world_size = accelerator.state.num_processes print(f"World Size: {world_size}") # Опция загрузки модели из последнего чекпоинта (если существует) latest_checkpoint = os.path.join(checkpoints_folder, project) if os.path.isdir(latest_checkpoint): print("Загружаем UNet из чекпоинта:", latest_checkpoint) unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device, dtype=dtype) # --------------------------- Оптимизатор и кастомный LR scheduler --------------------------- # Улучшенный AdamW с правильными параметрами from optimi import AdamW optimizer = torch.optim.AdamW( unet.parameters(), lr=base_learning_rate, betas=(0.9, 0.999), weight_decay=1e-6, # Безопасное значение для SD eps=1e-8 ) #optimizer = AdamW_BF16(unet.parameters(), lr_function=LR(lr=base_learning_rate, preheat_steps=5000, decay_power=-0.25)) # Улучшенный планировщик с фиксированной скоростью обучения для случаев возобновления # и возможностью плавного затухания def lr_schedule(step, max_steps, base_lr, min_lr, use_decay=True): # Если не используем затухание, возвращаем базовый LR if not use_decay: return base_lr # Иначе используем линейный прогрев и косинусное затухание x = step / max_steps if x < 0.1: # Линейный прогрев до 10% шагов return min_lr + (base_lr - min_lr) * (x / 0.1) else: # Косинусное затухание decay_ratio = (x - 0.1) / (1 - 0.1) return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * decay_ratio)) # Флаг для контроля использования затухания LR, вы можете менять его при возобновлении обучения use_lr_decay = True # Устанавливайте в False, если хотите отключить затухание def custom_lr_lambda(step): return lr_schedule(step, total_training_steps*world_size, base_learning_rate, min_learning_rate, use_lr_decay) / base_learning_rate lr_scheduler = LambdaLR(optimizer, lr_lambda=custom_lr_lambda) # Подготовка через Accelerator unet, optimizer,lr_scheduler = accelerator.prepare(unet, optimizer,lr_scheduler) # Загрузка полного чекпоинта если существует checkpoint_path = os.path.join(checkpoints_folder, checkpoint_file) if os.path.isfile(checkpoint_path) and accelerator.is_main_process: print(f"Загружаем полный чекпоинт: {checkpoint_path}") checkpoint = torch.load(checkpoint_path, map_location="cpu") else: checkpoint = None # Если задействовано распределённое обучение, рассылаем объект чекпоинта всем процессам. if accelerator.state.distributed_type != DistributedType.NO: checkpoint_list = [checkpoint] broadcast_object_list(checkpoint_list, src=0) checkpoint = checkpoint_list[0] if checkpoint is not None: # Загружаем состояния accelerator.unwrap_model(unet).load_state_dict(checkpoint['model_state_dict']) #if not reset_lroptim: # Пересоздаем оптимизатор с текущим learning rate # optimizer.param_groups[0]['lr'] = base_learning_rate # optimizer = torch.optim.AdamW( # unet.parameters(), # lr=base_learning_rate, # Можно изменить при возобновлении # betas=(0.9, 0.999), # weight_decay=1e-6, # eps=1e-8 # ) # print(f"optimizer learning rate: {optimizer.param_groups[0]['lr']}") # optimizer.load_state_dict(checkpoint['optimizer_state_dict']) #lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict']) start_epoch = checkpoint['epoch'] + 1 global_step = checkpoint['global_step'] print(f"Продолжаем с эпохи {start_epoch}, шага {global_step}") # --------------------------- Фиксированные семплы для генерации --------------------------- # Выбираем фиксированные индексы (приводим к int) sample_indices = np.random.choice(len(dataset), samples_to_generate, replace=False) sample_data = [dataset[int(i)] for i in sample_indices] def get_sample_inputs(sample_data): # Получаем latents и текстовые эмбеддинги для выбранных примеров (collate_fn переносит на device) return collate_fn(sample_data) @torch.no_grad() def generate_and_save_samples(step: int): """ Перемещает VAE на device, генерирует семплы с заданным числом diffusion шагов, декодирует их и отправляет на WandB. После генерации VAE возвращается на CPU. """ try: original_model = accelerator.unwrap_model(unet) # Перемещаем VAE на device для семплирования vae.to(accelerator.device, dtype=dtype) # Устанавливаем количество diffusion шагов scheduler.set_timesteps(n_diffusion_steps) # Получаем фиксированные данные sample_latents, sample_text_embeddings = get_sample_inputs(sample_data) sample_latents = sample_latents.to(dtype) # Инициализируем латенты случайным шумом sample_latents = torch.randn( sample_latents.shape, generator=gen, device=sample_latents.device, dtype=sample_latents.dtype ) # Добавляем CFG для лучшего качества guidance_scale = 3.0 # Генерация изображений with torch.no_grad(): for t in scheduler.timesteps: latent_model_input = scheduler.scale_model_input(sample_latents, t) noise_pred = original_model(latent_model_input, t, sample_text_embeddings).sample # Шаг диффузии sample_latents = scheduler.step(noise_pred, t, sample_latents).prev_sample # Декодирование через VAE latent = (sample_latents.detach() / vae.config.scaling_factor) + vae.config.shift_factor latent = latent.to(accelerator.device, dtype=dtype) decoded = vae.decode(latent).sample # Преобразуем тензоры в PIL-изображения generated_images = [] for img_idx, img_tensor in enumerate(decoded): img = (img_tensor.to(torch.float32) / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0) pil_img = Image.fromarray((img * 255).astype("uint8")) generated_images.append(pil_img) save_path = f"{generated_folder}/{project}_{img_idx}.jpg" pil_img.save(save_path, "JPEG", quality=95) # Отправляем изображения на WandB if use_wandb and accelerator.is_main_process: wandb_images = [wandb.Image(img, caption=f"Sample {i}") for i, img in enumerate(generated_images)] wandb.log({"generated_images": wandb_images, "global_step": step}) finally: # Гарантированное перемещение VAE обратно на CPU vae.to("cpu") if original_model is not None: del original_model torch.cuda.empty_cache() # Очищаем кэш CUDA gc.collect() # --------------------------- Генерация сэмплов перед обучением --------------------------- if accelerator.is_main_process: print("Генерация сэмплов до старта обучения...") generate_and_save_samples(step=0) # --------------------------- Тренировочный цикл --------------------------- # Для логирования среднего лосса каждые 10% эпохи if accelerator.is_main_process: print(f"Total steps per GPU: {total_training_steps}") print(f"[GPU {accelerator.process_index}] Total steps: {total_training_steps}") epoch_loss_points = [] progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step") # Определяем интервал для сэмплирования и логирования в пределах эпохи (10% эпохи) steps_per_epoch = len(dataloader) sample_interval = max(1, steps_per_epoch // 4) def save_ckpt(): if saveckpt: # После сохранения модели в конце каждой эпохи checkpoint_path = os.path.join(checkpoints_folder, checkpoint_file) torch.save({ 'model_state_dict': accelerator.unwrap_model(unet).state_dict(), 'optimizer_state_dict': optimizer.state_dict(), #'lr_scheduler_state_dict': lr_scheduler.state_dict(), 'epoch': epoch, 'global_step': global_step, }, checkpoint_path) print(f"Полный чекпоинт сохранён: {checkpoint_path}") # Начинаем с указанной эпохи (полезно при возобновлении) for epoch in range(start_epoch, start_epoch + num_epochs): batch_losses = [] unet.train() for step, (latents, embeddings) in enumerate(dataloader): with accelerator.accumulate(unet): # Forward pass - используем timesteps из более широкого диапазона noise = torch.randn_like(latents) timesteps = torch.randint( 1, # Начинаем с 1, не с 0 scheduler.config.num_train_timesteps, (latents.shape[0],), device=device ).long() # Добавляем шум к латентам noisy_latents = scheduler.add_noise(latents, noise, timesteps) # Получаем предсказание шума noise_pred = unet(noisy_latents, timesteps, embeddings).sample#.to(dtype=torch.bfloat16) # Используем целевое значение v_prediction target = scheduler.get_velocity(latents, noise, timesteps) #print("noise_pred",noise_pred.dtype) # Должно быть torch.bfloat16 #print("target",target.dtype) # Должно быть torch.bfloat16 # Считаем лосс в BF16 для стабильности loss = torch.nn.functional.mse_loss( noise_pred.float(), target.float() ) #print("loss",loss.dtype) # Делаем backward через Accelerator accelerator.backward(loss) # так делал раньше # Используем ограничение нормы градиентов через Accelerator grad_norm = accelerator.clip_grad_norm_(unet.parameters(), 1.0) # Важно: шаг оптимизатора должен быть именно здесь optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) # Увеличиваем счетчик глобальных шагов global_step += 1 # Обновляем прогресс-бар progress_bar.update(1) # Логируем метрики if accelerator.is_main_process: current_lr = lr_scheduler.get_last_lr()[0] batch_losses.append(loss.detach().item()) # Логируем в Wandb if global_step % 2 == 0 and use_wandb: wandb.log({ "loss": loss.detach().item(), "learning_rate": current_lr, "epoch": epoch, "grad_norm": grad_norm.item(), "epoch": epoch, "global_step": global_step }) # Генерируем сэмплы с заданным интервалом if global_step % sample_interval == 0 and accelerator.is_main_process: generate_and_save_samples(global_step) # Выводим текущий лосс avg_loss = np.mean(batch_losses[-sample_interval:]) #print(f"Эпоха {epoch}, шаг {global_step}, средний лосс: {avg_loss:.6f}, LR: {current_lr:.8f}") if use_wandb: wandb.log({"intermediate_loss": avg_loss}) # По окончании эпохи accelerator.wait_for_everyone() # Сохраняем чекпоинт в конце каждой эпохи if accelerator.is_main_process: save_ckpt() # Сохраняем UNet отдельно для удобства использования if save_model: accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}")) avg_epoch_loss = np.mean(batch_losses) print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}") if use_wandb: wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1}) # Завершение обучения - сохраняем финальную модель if accelerator.is_main_process: print("Обучение завершено! Сохраняем финальную модель...") # Сохраняем основную модель if save_model: accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}")) print("Готово!")