import os
import math
import re
import torch
import numpy as np
import random
import gc
from datetime import datetime
from pathlib import Path

import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import LambdaLR
from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
from diffusers import AutoencoderKLQwenImage
from diffusers import AutoencoderKLWan

from accelerate import Accelerator
from PIL import Image, UnidentifiedImageError
from tqdm import tqdm
import bitsandbytes as bnb
import wandb
import lpips
from FDL_pytorch import FDL_loss
from collections import deque

# --- Training configuration ---
ds_path = "/workspace/d23"
project = "vae10"
batch_size = 1
base_learning_rate = 6e-6
min_learning_rate = 7e-7
num_epochs = 2
sample_interval_share = 25
use_wandb = True
save_model = True
use_decay = True
optimizer_type = "adam8bit"
dtype = torch.float32

model_resolution = 512   # resolution fed to the encoder
high_resolution = 1024   # resolution of the reconstruction target
limit = 0                # 0 = use the whole dataset
save_barrier = 1.3       # checkpoint while avg_loss < min_loss * save_barrier
warmup_percent = 0.005
percentile_clipping = 99
beta2 = 0.997
eps = 1e-8
clip_grad_norm = 1.0
mixed_precision = "no"
gradient_accumulation_steps = 1
generated_folder = "samples"
save_as = "vae10"
num_workers = 0
device = None
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Prefer the fused flash / memory-efficient SDPA kernels; disable the math fallback.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(False)

# What to train: decoder only (default), only the first decoder up-block, or everything.
train_decoder_only = True
train_up_only = False
full_training = False
kl_ratio = 0.00

# Desired relative contribution of each loss term; the MedianLossNormalizer below
# rescales the raw losses so they mix in these proportions.
loss_ratios = {
    "lpips": 0.70,
    "fdl": 0.10,
    "edge": 0.05,
    "mse": 0.10,
    "mae": 0.05,
    "kl": 0.00,
}
median_coeff_steps = 250  # window (in steps) for the running medians

# Source images are downscaled so their long side is at most this before cropping (0 disables).
resize_long_side = 1280

# Which VAE family to load: "kl" (AutoencoderKL / AsymmetricAutoencoderKL), "wan", or "qwen".
vae_kind = "kl"

Path(generated_folder).mkdir(parents=True, exist_ok=True)

accelerator = Accelerator(
    mixed_precision=mixed_precision,
    gradient_accumulation_steps=gradient_accumulation_steps,
)
device = accelerator.device
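
# Usage sketch (assuming this file is saved as train_vae.py, a hypothetical name):
# the script is written for the Accelerate launcher, e.g.
#   accelerate launch train_vae.py
# Plain `python train_vae.py` also works for a single-GPU run.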

# Seed derived from the current date, so runs started on the same day match.
seed = int(datetime.now().strftime("%Y%m%d")) + 13
torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
torch.backends.cudnn.benchmark = False

if use_wandb and accelerator.is_main_process:
    wandb.init(project=project, config={
        "batch_size": batch_size,
        "base_learning_rate": base_learning_rate,
        "num_epochs": num_epochs,
        "optimizer_type": optimizer_type,
        "model_resolution": model_resolution,
        "high_resolution": high_resolution,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "train_decoder_only": train_decoder_only,
        "full_training": full_training,
        "kl_ratio": kl_ratio,
        "vae_kind": vae_kind,
    })


def get_core_model(model):
    """Unwrap torch.compile's wrapper (if any) to reach the underlying module."""
    m = model
    if hasattr(m, "_orig_mod"):
        m = m._orig_mod
    return m


def is_video_vae(model) -> bool:
    """Heuristic: video VAEs (Wan/Qwen) use 3D convolutions, so the encoder's
    conv_in weight has 5 dimensions instead of 4."""
    if vae_kind in ("wan", "qwen"):
        return True
    try:
        core = get_core_model(model)
        enc = getattr(core, "encoder", None)
        conv_in = getattr(enc, "conv_in", None)
        w = getattr(conv_in, "weight", None)
        if isinstance(w, torch.nn.Parameter):
            return w.ndim == 5
    except Exception:
        pass
    return False
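
# Illustrative shape sketch (assumed conventions, not checked against every
# diffusers release): an image VAE maps (B, 3, H, W) -> latents -> (B, 3, H', W'),
# while a video VAE expects a time axis, (B, 3, T, H, W). This script feeds
# single frames as T=1:
#   x_in = imgs.unsqueeze(2)  # (B, 3, 1, H, W)
#   rec = vae.decode(vae.encode(x_in).latent_dist.mean).sample.squeeze(2)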


if vae_kind == "qwen":
    vae = AutoencoderKLQwenImage.from_pretrained("Qwen/Qwen-Image", subfolder="vae")
elif vae_kind == "wan":
    vae = AutoencoderKLWan.from_pretrained(project)
elif model_resolution == high_resolution:
    vae = AutoencoderKL.from_pretrained(project)
else:
    # Asymmetric VAE: encode at model_resolution, decode at high_resolution.
    vae = AsymmetricAutoencoderKL.from_pretrained(project)

vae = vae.to(dtype)

if hasattr(torch, "compile"):
    try:
        vae = torch.compile(vae)
    except Exception as e:
        print(f"[WARN] torch.compile failed: {e}")

core = get_core_model(vae)

# Freeze everything first, then unfreeze the parts selected above.
for p in core.parameters():
    p.requires_grad = False

unfrozen_param_names = []

if full_training and not train_decoder_only:
    for name, p in core.named_parameters():
        p.requires_grad = True
        unfrozen_param_names.append(name)
    loss_ratios["kl"] = float(kl_ratio)
    trainable_module = core
else:
    if hasattr(core, "decoder"):
        if train_up_only:
            for name, p in core.decoder.up_blocks[0].named_parameters():
                p.requires_grad = True
                unfrozen_param_names.append(f"decoder.up_blocks.0.{name}")
        else:
            print("Decoder - training the full decoder")
            for name, p in core.decoder.named_parameters():
                p.requires_grad = True
                unfrozen_param_names.append(f"decoder.{name}")
    if hasattr(core, "post_quant_conv"):
        for name, p in core.post_quant_conv.named_parameters():
            p.requires_grad = True
            unfrozen_param_names.append(f"post_quant_conv.{name}")
    trainable_module = core.decoder if hasattr(core, "decoder") else core

print(f"[INFO] Unfrozen parameters: {len(unfrozen_param_names)}. First 200 names:")
for nm in unfrozen_param_names[:200]:
    print(" ", nm)


class PngFolderDataset(Dataset):
    """Recursively collects PNG files, drops unreadable ones, and returns PIL
    images downscaled so their long side is at most resize_long_side."""

    def __init__(self, root_dir, min_exts=('.png',), resolution=1024, limit=0):
        self.root_dir = root_dir
        self.resolution = resolution
        self.paths = []
        for root, _, files in os.walk(root_dir):
            for fname in files:
                if fname.lower().endswith(tuple(ext.lower() for ext in min_exts)):
                    self.paths.append(os.path.join(root, fname))
        if limit:
            self.paths = self.paths[:limit]
        # Verify images up front so corrupt files don't crash training mid-epoch.
        valid = []
        for p in self.paths:
            try:
                with Image.open(p) as im:
                    im.verify()
                valid.append(p)
            except (OSError, UnidentifiedImageError):
                continue
        self.paths = valid
        if len(self.paths) == 0:
            raise RuntimeError(f"No valid PNG images found under {root_dir}")
        random.shuffle(self.paths)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx % len(self.paths)]
        with Image.open(p) as img:
            img = img.convert("RGB")
        if not resize_long_side or resize_long_side <= 0:
            return img
        w, h = img.size
        long_side = max(w, h)
        if long_side <= resize_long_side:
            return img
        scale = resize_long_side / float(long_side)
        new_w = int(round(w * scale))
        new_h = int(round(h * scale))
        return img.resize((new_w, new_h), Image.BICUBIC)


def random_crop(img, sz):
    w, h = img.size
    if w < sz or h < sz:
        img = img.resize((max(sz, w), max(sz, h)), Image.BICUBIC)
    # max(0, ...) keeps the crop window inside the image; the original
    # max(1, ...) could shift it one pixel past the right/bottom edge.
    x = random.randint(0, max(0, img.width - sz))
    y = random.randint(0, max(0, img.height - sz))
    return img.crop((x, y, x + sz, y + sz))


tfm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # [0, 1] -> [-1, 1]
])
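
# Worked example of the normalization above: a pure-white pixel loads as 1.0
# after ToTensor(), then (1.0 - 0.5) / 0.5 = 1.0; mid-gray 0.5 maps to 0.0.
# The VAE therefore sees inputs in [-1, 1], matching _to_pil_uint8 below,
# which inverts the mapping with (x + 1) * 127.5.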


dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)
print("len(dataset)", len(dataset))
if len(dataset) < batch_size:
    raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")


def collate_fn(batch):
    imgs = []
    for img in batch:
        img = random_crop(img, high_resolution)
        imgs.append(tfm(img))
    return torch.stack(imgs)


dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=True,
)


def get_param_groups(module, weight_decay=0.001):
    """Two optimizer groups: weight decay for plain weights, none for biases/norms."""
    no_decay_tokens = ("bias", "norm", "rms", "layernorm")
    decay_params, no_decay_params = [], []
    for n, p in module.named_parameters():
        if not p.requires_grad:
            continue
        n_l = n.lower()
        if any(t in n_l for t in no_decay_tokens):
            no_decay_params.append(p)
        else:
            decay_params.append(p)
    return [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
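
# Hypothetical sketch of the returned structure for a decoder-only run
# (actual parameter names depend on the loaded checkpoint):
#   [{"params": [conv weights, ...],          "weight_decay": 0.001},
#    {"params": [biases, norm weights, ...],  "weight_decay": 0.0}]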


def create_optimizer(name, param_groups):
    if name == "adam8bit":
        return bnb.optim.AdamW8bit(param_groups, lr=base_learning_rate, betas=(0.9, beta2), eps=eps)
    raise ValueError(name)


param_groups = get_param_groups(get_core_model(vae), weight_decay=0.001)
optimizer = create_optimizer(optimizer_type, param_groups)


batches_per_epoch = len(dataloader)
steps_per_epoch = int(math.ceil(batches_per_epoch / float(gradient_accumulation_steps)))
total_steps = steps_per_epoch * num_epochs


def lr_lambda(step):
    """Linear warmup from min LR to base LR, then cosine decay back down."""
    if not use_decay:
        return 1.0
    x = float(step) / float(max(1, total_steps))
    warmup = float(warmup_percent)
    min_ratio = float(min_learning_rate) / float(base_learning_rate)
    if x < warmup:
        return min_ratio + (1.0 - min_ratio) * (x / warmup)
    decay_ratio = (x - warmup) / (1.0 - warmup)
    return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * decay_ratio))
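
# Worked example with the defaults (base 6e-6, min 7e-7, so min_ratio ~= 0.117):
#   step 0                 -> factor = min_ratio           (lr ~= 7e-7)
#   step 0.005 * total     -> factor = 1.0                 (warmup done, lr = 6e-6)
#   halfway through decay  -> factor = (1 + min_ratio) / 2 (cosine midpoint)
#   final step             -> factor -> min_ratio          (lr -> 7e-7)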


scheduler = LambdaLR(optimizer, lr_lambda)

dataloader, vae, optimizer, scheduler = accelerator.prepare(dataloader, vae, optimizer, scheduler)
trainable_params = [p for p in vae.parameters() if p.requires_grad]

fdl_loss = FDL_loss()
fdl_loss = fdl_loss.to(accelerator.device)

_lpips_net = None
def _get_lpips():
    """Lazily build the LPIPS(VGG) network on first use."""
    global _lpips_net
    if _lpips_net is None:
        _lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device)
    return _lpips_net


_sobel_kx = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]], dtype=torch.float32)
_sobel_ky = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]], dtype=torch.float32)
def sobel_edges(x: torch.Tensor) -> torch.Tensor:
    """Per-channel Sobel gradient magnitude via a depthwise conv (groups=C)."""
    C = x.shape[1]
    kx = _sobel_kx.to(x.device, x.dtype).repeat(C, 1, 1, 1)
    ky = _sobel_ky.to(x.device, x.dtype).repeat(C, 1, 1, 1)
    gx = F.conv2d(x, kx, padding=1, groups=C)
    gy = F.conv2d(x, ky, padding=1, groups=C)
    return torch.sqrt(gx * gx + gy * gy + 1e-12)
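
# Shape sketch (assuming NCHW float input in [-1, 1]): sobel_edges preserves the
# input shape, e.g. (B, 3, 1024, 1024) -> (B, 3, 1024, 1024), so the "edge" loss
# below is simply an L1 distance between gradient-magnitude maps.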


class MedianLossNormalizer:
    """Keeps a rolling window of each loss's magnitude and divides the desired
    ratio by the window median, so every term contributes in the requested
    proportion regardless of its natural scale."""

    def __init__(self, desired_ratios: dict, window_steps: int):
        s = sum(desired_ratios.values())
        self.ratios = {k: (v / s) if s > 0 else 0.0 for k, v in desired_ratios.items()}
        self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
        self.window = window_steps

    def update_and_total(self, abs_losses: dict):
        for k, v in abs_losses.items():
            if k in self.buffers:
                self.buffers[k].append(float(v.detach().abs().cpu()))
        meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
        coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
        total = sum(coeffs[k] * abs_losses[k] for k in abs_losses if k in coeffs)
        return total, coeffs, meds
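
# Worked example (illustrative numbers): with ratios lpips=0.70 and mse=0.10,
# if the recent medians are lpips ~= 0.35 and mse ~= 0.02, the coefficients are
# 0.70 / 0.35 = 2.0 and 0.10 / 0.02 = 5.0; at the median each weighted term then
# contributes exactly its ratio (2.0 * 0.35 = 0.70, 5.0 * 0.02 = 0.10).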


if full_training and not train_decoder_only:
    loss_ratios["kl"] = float(kl_ratio)
normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)


@torch.no_grad()
def get_fixed_samples(n=3):
    """Fix a few random images once so sample logging is comparable across steps."""
    idx = random.sample(range(len(dataset)), min(n, len(dataset)))
    pil_imgs = [dataset[i] for i in idx]
    tensors = []
    for img in pil_imgs:
        img = random_crop(img, high_resolution)
        tensors.append(tfm(img))
    return torch.stack(tensors).to(accelerator.device, dtype)


fixed_samples = get_fixed_samples()


@torch.no_grad()
def _to_pil_uint8(img_tensor: torch.Tensor) -> Image.Image:
    """[-1, 1] float CHW tensor -> uint8 HWC PIL image."""
    arr = ((img_tensor.float().clamp(-1, 1) + 1.0) * 127.5).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)
    return Image.fromarray(arr)


@torch.no_grad()
def generate_and_save_samples(step=None):
    try:
        # Unwrap the DDP wrapper, then the torch.compile wrapper, to reach the
        # raw diffusers model.
        unwrapped_vae = vae.module if hasattr(vae, "module") else vae
        temp_vae = getattr(unwrapped_vae, "_orig_mod", unwrapped_vae)
        temp_vae = temp_vae.eval()
        lpips_net = _get_lpips()

        orig_high = fixed_samples
        orig_low = F.interpolate(
            orig_high,
            size=(model_resolution, model_resolution),
            mode="bilinear",
            align_corners=False,
        )
        model_dtype = next(temp_vae.parameters()).dtype
        orig_low = orig_low.to(dtype=model_dtype)

        # Reconstruct from the posterior mean (no sampling noise in previews).
        # Video VAEs need a singleton time axis added and stripped again.
        if is_video_vae(temp_vae):
            x_in = orig_low.unsqueeze(2)
            enc = temp_vae.encode(x_in)
            latents_mean = enc.latent_dist.mean
            dec = temp_vae.decode(latents_mean).sample
            rec = dec.squeeze(2)
        else:
            enc = temp_vae.encode(orig_low)
            latents_mean = enc.latent_dist.mean
            rec = temp_vae.decode(latents_mean).sample

        for i in range(rec.shape[0]):
            real_img = _to_pil_uint8(orig_high[i])
            dec_img = _to_pil_uint8(rec[i])
            real_img.save(f"{generated_folder}/sample_real_{i}.png")
            dec_img.save(f"{generated_folder}/sample_decoded_{i}.png")

        # LPIPS between the high-res original and the reconstruction.
        lpips_scores = []
        for i in range(rec.shape[0]):
            orig_full = orig_high[i:i + 1].to(torch.float32)
            rec_full = rec[i:i + 1].to(torch.float32)
            lpips_val = lpips_net(orig_full, rec_full).item()
            lpips_scores.append(lpips_val)
        avg_lpips = float(np.mean(lpips_scores))

        if use_wandb and accelerator.is_main_process:
            log_data = {"lpips_mean": avg_lpips}
            for i in range(rec.shape[0]):
                log_data[f"sample/real_{i}"] = wandb.Image(f"{generated_folder}/sample_real_{i}.png", caption=f"real_{i}")
                log_data[f"sample/decoded_{i}"] = wandb.Image(f"{generated_folder}/sample_decoded_{i}.png", caption=f"decoded_{i}")
            wandb.log(log_data, step=step)
    finally:
        gc.collect()
        torch.cuda.empty_cache()


if accelerator.is_main_process and save_model:
    print("Generating samples before training starts...")
    generate_and_save_samples(0)

accelerator.wait_for_everyone()

# --- Training loop ---
progress = tqdm(total=total_steps, disable=not accelerator.is_local_main_process)
global_step = 0
min_loss = float("inf")
sample_interval = max(1, total_steps // max(1, sample_interval_share * num_epochs))
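
# Worked example (hypothetical sizes): with 10_000 optimizer steps in total,
# sample_interval_share = 25 and num_epochs = 2, samples and checkpoints are
# produced every 10_000 // (25 * 2) = 200 steps.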


for epoch in range(num_epochs):
    vae.train()
    batch_losses, batch_grads = [], []
    track_losses = {k: [] for k in loss_ratios.keys()}

    for imgs in dataloader:
        with accelerator.accumulate(vae):
            imgs = imgs.to(accelerator.device)

            # High-res crops are the target; the model input is a downscaled copy.
            if high_resolution != model_resolution:
                imgs_low = F.interpolate(imgs, size=(model_resolution, model_resolution), mode="area")
            else:
                imgs_low = imgs

            model_dtype = next(vae.parameters()).dtype
            imgs_low_model = imgs_low.to(dtype=model_dtype) if imgs_low.dtype != model_dtype else imgs_low

            unwrapped = vae.module if hasattr(vae, "module") else vae
            current_vae = getattr(unwrapped, "_orig_mod", unwrapped)

            # Decoder-only training uses the posterior mean (the encoder is frozen,
            # so sampling noise adds nothing); full training samples the posterior.
            if is_video_vae(current_vae):
                x_in = imgs_low_model.unsqueeze(2)
                enc = current_vae.encode(x_in)
                latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
                dec = current_vae.decode(latents).sample
                rec = dec.squeeze(2)
            else:
                enc = current_vae.encode(imgs_low_model)
                latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
                rec = current_vae.decode(latents).sample

            rec_f32 = rec.to(torch.float32)
            imgs_f32 = imgs.to(torch.float32)

            abs_losses = {
                "mae": F.l1_loss(rec_f32, imgs_f32),
                "mse": F.mse_loss(rec_f32, imgs_f32),
                "lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
                "fdl": fdl_loss(rec_f32, imgs_f32),
                "edge": F.l1_loss(sobel_edges(rec_f32), sobel_edges(imgs_f32)),
            }

            if full_training and not train_decoder_only:
                mean = enc.latent_dist.mean
                logvar = enc.latent_dist.logvar
                kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
                abs_losses["kl"] = kl
            else:
                abs_losses["kl"] = torch.tensor(0.0, device=accelerator.device, dtype=torch.float32)

            total_loss, coeffs, meds = normalizer.update_and_total(abs_losses)

            if torch.isnan(total_loss) or torch.isinf(total_loss):
                raise RuntimeError("NaN/Inf loss")

            accelerator.backward(total_loss)

            grad_norm = torch.tensor(0.0, device=accelerator.device)
            if accelerator.sync_gradients:
                grad_norm = accelerator.clip_grad_norm_(trainable_params, clip_grad_norm)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                global_step += 1
                progress.update(1)

            if accelerator.is_main_process:
                try:
                    current_lr = optimizer.param_groups[0]["lr"]
                except Exception:
                    current_lr = scheduler.get_last_lr()[0]

                batch_losses.append(total_loss.detach().item())
                batch_grads.append(float(grad_norm.detach().cpu().item()) if isinstance(grad_norm, torch.Tensor) else float(grad_norm))
                for k, v in abs_losses.items():
                    track_losses[k].append(float(v.detach().item()))

                if use_wandb and accelerator.sync_gradients:
                    log_dict = {
                        "total_loss": float(total_loss.detach().item()),
                        "learning_rate": current_lr,
                        "epoch": epoch,
                        "grad_norm": batch_grads[-1],
                    }
                    for k, v in abs_losses.items():
                        log_dict[f"loss_{k}"] = float(v.detach().item())
                    for k in coeffs:
                        log_dict[f"coeff_{k}"] = float(coeffs[k])
                        log_dict[f"median_{k}"] = float(meds[k])
                    wandb.log(log_dict, step=global_step)

            if global_step > 0 and global_step % sample_interval == 0:
                if accelerator.is_main_process:
                    generate_and_save_samples(global_step)
                accelerator.wait_for_everyone()
                vae.train()  # sampling switches the model to eval; switch back

                n_micro = sample_interval * gradient_accumulation_steps
                avg_loss = float(np.mean(batch_losses[-n_micro:])) if len(batch_losses) >= n_micro else float(np.mean(batch_losses)) if batch_losses else float("nan")
                avg_grad = float(np.mean(batch_grads[-n_micro:])) if len(batch_grads) >= n_micro else float(np.mean(batch_grads)) if batch_grads else 0.0

                if accelerator.is_main_process:
                    print(f"Epoch {epoch} step {global_step} loss: {avg_loss:.6f}, grad_norm: {avg_grad:.6f}, lr: {current_lr:.9f}")
                    if save_model and avg_loss < min_loss * save_barrier:
                        min_loss = avg_loss
                        unwrapped = vae.module if hasattr(vae, "module") else vae
                        current_vae = getattr(unwrapped, "_orig_mod", unwrapped)
                        current_vae.save_pretrained(save_as)
                    if use_wandb:
                        wandb.log({"interm_loss": avg_loss, "interm_grad": avg_grad}, step=global_step)

    if accelerator.is_main_process:
        epoch_avg = float(np.mean(batch_losses)) if batch_losses else float("nan")
        print(f"Epoch {epoch} done, avg loss {epoch_avg:.6f}")
        if use_wandb:
            wandb.log({"epoch_loss": epoch_avg, "epoch": epoch + 1}, step=global_step)

# --- Final save and cleanup ---
if accelerator.is_main_process:
    print("Training finished - saving final model")
    if save_model:
        unwrapped = vae.module if hasattr(vae, "module") else vae
        current_vae = getattr(unwrapped, "_orig_mod", unwrapped)
        current_vae.save_pretrained(save_as)

accelerator.free_memory()
if torch.distributed.is_initialized():
    torch.distributed.destroy_process_group()
print("Done!")