| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from __future__ import annotations |
| | import os, json, math, random, re |
| | from dataclasses import dataclass, asdict |
| | from pathlib import Path |
| | from typing import Dict, List, Tuple, Optional |
| | import urllib.request |
| | import subprocess |
| | import shutil |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.utils.data import Dataset, DataLoader |
| | from torch.utils.tensorboard import SummaryWriter |
| | from tqdm import tqdm |
| |
|
| | |
| | from diffusers import StableDiffusionPipeline, DDPMScheduler |
| | from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel |
| |
|
| | |
| | from geovocab2.train.model.core.geo_david_collective import GeoDavidCollective |
| | from geovocab2.data.prompt.symbolic_tree import SynthesisSystem |
| |
|
| | |
| | from huggingface_hub import snapshot_download, HfApi, create_repo, hf_hub_download |
| | from safetensors.torch import load_file |
| |
|
| |
|
| | |
| | |
| | |
@dataclass
class BaseConfig:
    """Run configuration for SD1.5 flow-matching distillation with David.

    Output/checkpoint/cache directories are created eagerly in
    ``__post_init__``, and ``block_weights`` is filled with a hand-tuned
    per-UNet-block default when left as ``None``.
    """

    # --- run bookkeeping ---
    run_name: str = "sd15_flowmatch_david_hf"
    out_dir: str = "./runs/sd15_flowmatch_david_hf"      # TensorBoard logs
    ckpt_dir: str = "./checkpoints_sd15_flow_david_hf"   # local checkpoints
    save_every: int = 1                                  # epochs between saves

    # --- data ---
    num_samples: int = 200_000  # virtual dataset length (prompts synthesized on the fly)
    batch_size: int = 32
    num_workers: int = 2
    seed: int = 42

    # --- teacher model / feature taps ---
    model_id: str = "runwayml/stable-diffusion-v1-5"
    active_blocks: Tuple[str, ...] = ("down_0","down_1","down_2","down_3","mid","up_0","up_1","up_2","up_3")
    pooling: str = "mean"  # "mean" | "max" | "adaptive"

    # --- optimization ---
    epochs: int = 10
    lr: float = 1e-4
    weight_decay: float = 1e-3
    grad_clip: float = 1.0
    amp: bool = True

    # --- loss mixing ---
    global_flow_weight: float = 1.0
    block_penalty_weight: float = 0.2
    use_local_flow_heads: bool = False
    local_flow_weight: float = 1.0

    # --- knowledge distillation on pooled block features ---
    use_kd: bool = True
    kd_weight: float = 0.25

    # --- David (frozen assessor) source on the HF hub ---
    david_repo_id: str = "AbstractPhil/geo-david-collective-sd15-base-e40"
    david_cache_dir: str = "./_hf_david_cache"
    david_state_key: Optional[str] = None

    # --- penalty fusion coefficients and clamp range ---
    alpha_timestep: float = 0.5
    beta_pattern: float = 0.25
    delta_incoherence: float = 0.25
    lambda_min: float = 0.5
    lambda_max: float = 3.0

    # FIX: was annotated `Dict[str, float] = None`, which mis-types the None
    # sentinel; Optional documents the "fill in __post_init__" contract.
    block_weights: Optional[Dict[str, float]] = None

    # --- diffusion / sampling ---
    num_train_timesteps: int = 1000
    sample_steps: int = 30
    guidance_scale: float = 7.5

    # --- HuggingFace hub integration ---
    hf_repo_id: Optional[str] = "AbstractPhil/sd15-flow-matching"
    upload_every_epoch: bool = True
    continue_training: bool = True

    def __post_init__(self):
        # Create all output directories up front so later saves cannot fail
        # on a missing path.
        Path(self.out_dir).mkdir(parents=True, exist_ok=True)
        Path(self.ckpt_dir).mkdir(parents=True, exist_ok=True)
        Path(self.david_cache_dir).mkdir(parents=True, exist_ok=True)
        if self.block_weights is None:
            # Inner blocks weighted highest, tapering toward the outer blocks.
            self.block_weights = {'down_0':0.7,'down_1':0.9,'down_2':1.0,'down_3':1.1,'mid':1.2,'up_0':1.1,'up_1':1.0,'up_2':0.9,'up_3':0.7}
| |
|
| |
|
| | |
| | |
| | |
class SymbolicPromptDataset(Dataset):
    """Virtual dataset: each index yields a freshly synthesized symbolic
    prompt plus a uniformly random training timestep in [0, 999]."""

    def __init__(self, n: int, seed: int = 42):
        self.n = n
        # Seeds the *global* RNG, matching the original reproducibility setup.
        random.seed(seed)
        self.sys = SynthesisSystem(seed=seed)

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        complexity = random.choice([1, 2, 3, 4, 5])
        sample = self.sys.synthesize(complexity=complexity)
        return {"prompt": sample['text'], "t": random.randint(0, 999)}
| |
|
def collate(batch: List[dict]):
    """Collate prompt/timestep samples; adds decade bins t//10 in [0, 99]."""
    t = torch.tensor([item["t"] for item in batch], dtype=torch.long)
    return {
        "prompts": [item["prompt"] for item in batch],
        "t": t,
        "t_bins": t // 10,
    }
| |
|
| |
|
| | |
| | |
| | |
class HookBank:
    """Captures per-block activations from a UNet via forward hooks.

    Only blocks named in ``active`` are hooked; each forward pass deposits
    its output into ``self.bank`` keyed by "down_i" / "mid" / "up_i".
    """

    def __init__(self, unet: UNet2DConditionModel, active: Tuple[str, ...]):
        self.active = set(active)
        self.bank: Dict[str, torch.Tensor] = {}
        self.hooks: List[torch.utils.hooks.RemovableHandle] = []
        self._register(unet)

    def _register(self, unet: UNet2DConditionModel):
        def make_hook(name):
            def hook(module, inputs, output):
                # Some blocks return tuples; keep only the primary tensor.
                self.bank[name] = output[0] if isinstance(output, (tuple, list)) else output
            return hook

        for idx, block in enumerate(unet.down_blocks):
            name = f"down_{idx}"
            if name in self.active:
                self.hooks.append(block.register_forward_hook(make_hook(name)))
        if "mid" in self.active:
            self.hooks.append(unet.mid_block.register_forward_hook(make_hook("mid")))
        for idx, block in enumerate(unet.up_blocks):
            name = f"up_{idx}"
            if name in self.active:
                self.hooks.append(block.register_forward_hook(make_hook(name)))

    def clear(self):
        self.bank.clear()

    def close(self):
        for handle in self.hooks:
            handle.remove()
        self.hooks.clear()
| |
|
def spatial_pool(x: torch.Tensor, name: str, policy: str) -> torch.Tensor:
    """Pool a (B, C, H, W) feature map to (B, C) under the given policy.

    "adaptive" mean-pools down/mid blocks and max-pools up blocks.
    Raises ValueError for an unrecognized policy.
    """
    if policy == "mean":
        return x.mean(dim=(2, 3))
    if policy == "max":
        return x.amax(dim=(2, 3))
    if policy == "adaptive":
        use_mean = name.startswith("down") or name == "mid"
        return x.mean(dim=(2, 3)) if use_mean else x.amax(dim=(2, 3))
    raise ValueError(f"Unknown pooling: {policy}")
| |
|
| |
|
| | |
| | |
| | |
class SD15Teacher(nn.Module):
    """Frozen fp16 SD1.5 teacher.

    Supplies CLIP text embeddings, epsilon predictions with hooked UNet block
    features, and the DDPM (alpha, sigma) terms used to derive v-targets.
    All parameters have requires_grad disabled.
    """
    def __init__(self, cfg: BaseConfig, device: str):
        super().__init__()
        # fp16 pipeline, safety checker disabled (training-only usage).
        self.pipe = StableDiffusionPipeline.from_pretrained(cfg.model_id, torch_dtype=torch.float16, safety_checker=None).to(device)
        self.unet: UNet2DConditionModel = self.pipe.unet
        self.text_encoder = self.pipe.text_encoder
        self.tokenizer = self.pipe.tokenizer
        # Forward hooks capture activations of the configured UNet blocks.
        self.hooks = HookBank(self.unet, cfg.active_blocks)
        self.sched = DDPMScheduler(num_train_timesteps=cfg.num_train_timesteps)
        self.device = device
        # Freeze everything: the teacher is a fixed reference model.
        for p in self.parameters(): p.requires_grad_(False)

    @torch.no_grad()
    def encode(self, prompts: List[str]) -> torch.Tensor:
        """Tokenize and encode prompts to CLIP hidden states, padded to the
        tokenizer's max length (returns the encoder's last hidden state)."""
        tok = self.tokenizer(prompts, padding="max_length", max_length=self.tokenizer.model_max_length,
                             truncation=True, return_tensors="pt")
        return self.text_encoder(tok.input_ids.to(self.device))[0]

    @torch.no_grad()
    def forward_eps_and_feats(self, x_t: torch.Tensor, t: torch.LongTensor, ehs: torch.Tensor):
        """Run the teacher UNet once.

        Returns (eps_hat as fp32, {block_name: detached fp32 feature map})
        where the features come from the forward hooks of this call.
        """
        self.hooks.clear()
        eps_hat = self.unet(x_t, t, encoder_hidden_states=ehs).sample
        feats = {k: v.detach().float() for k, v in self.hooks.bank.items()}
        return eps_hat.float(), feats

    def alpha_sigma(self, t: torch.LongTensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return broadcastable (alpha, sigma) = (sqrt(ac_t), sqrt(1 - ac_t))
        for timesteps t, shaped (-1, 1, 1, 1) in fp32."""
        ac = self.sched.alphas_cumprod.to(self.device)[t]
        alpha = ac.sqrt().view(-1,1,1,1).float()
        sigma = (1.0 - ac).sqrt().view(-1,1,1,1).float()
        return alpha, sigma
| |
|
| |
|
| | |
| | |
| | |
class StudentUNet(nn.Module):
    """Trainable copy of the teacher UNet with hooked block features and
    optional lazily-created 1x1 "local flow" heads (channels -> 4)."""

    def __init__(self, teacher_unet: UNet2DConditionModel, active_blocks: Tuple[str,...], use_local_heads: bool):
        super().__init__()
        # Fresh UNet from the teacher's config, initialized to its weights.
        self.unet = UNet2DConditionModel.from_config(teacher_unet.config)
        self.unet.load_state_dict(teacher_unet.state_dict(), strict=True)
        self.hooks = HookBank(self.unet, active_blocks)
        self.use_local_heads = use_local_heads
        self.local_heads = nn.ModuleDict()

    def _ensure_heads(self, feats: Dict[str, torch.Tensor]):
        """Create one 1x1 conv head per hooked block on first sight of its
        feature map (so channel counts never need to be hard-coded)."""
        if not self.use_local_heads:
            return
        if len(self.local_heads) == len(feats):
            return
        target_dtype = next(self.unet.parameters()).dtype
        for name, feat in feats.items():
            if name in self.local_heads:
                continue
            head = nn.Conv2d(feat.shape[1], 4, kernel_size=1)
            # Match the UNet's dtype and the feature's device.
            self.local_heads[name] = head.to(dtype=target_dtype, device=feat.device)

    def forward(self, x_t: torch.Tensor, t: torch.LongTensor, ehs: torch.Tensor):
        """Predict velocity; return (v_hat, {block_name: feature map})."""
        self.hooks.clear()
        v_hat = self.unet(x_t, t, encoder_hidden_states=ehs).sample
        feats = dict(self.hooks.bank)
        self._ensure_heads(feats)
        return v_hat, feats
| |
|
| |
|
| | |
| | |
| | |
class DavidLoader:
    """
    Downloads HF repo (config + safetensors), instantiates GeoDavidCollective with HF config,
    loads weights, returns a frozen model + the parsed HF config.
    """
    def __init__(self, cfg: BaseConfig, device: str):
        self.cfg = cfg
        self.device = device
        self.repo_dir = snapshot_download(
            repo_id=cfg.david_repo_id,
            local_dir=cfg.david_cache_dir,
            local_dir_use_symlinks=False,
        )
        self.config_path = os.path.join(self.repo_dir, "config.json")
        self.weights_path = os.path.join(self.repo_dir, "model.safetensors")
        with open(self.config_path, "r") as f:
            self.hf_config = json.load(f)

        block_configs = self.hf_config["block_configs"]
        uniform_weights = {name: 1.0 for name in block_configs.keys()}
        self.gdc = GeoDavidCollective(
            block_configs=block_configs,
            num_timestep_bins=int(self.hf_config["num_timestep_bins"]),
            num_patterns_per_bin=int(self.hf_config["num_patterns_per_bin"]),
            block_weights=self.hf_config.get("block_weights", uniform_weights),
            loss_config=self.hf_config.get("loss_config", {}),
        ).to(device).eval()

        # strict=False tolerates auxiliary keys in the published weights.
        self.gdc.load_state_dict(load_file(self.weights_path), strict=False)
        for param in self.gdc.parameters():
            param.requires_grad_(False)

        print(f"β David loaded from HF: {self.repo_dir}")
        print(f" blocks={len(self.hf_config['block_configs'])} bins={self.hf_config['num_timestep_bins']} patterns={self.hf_config['num_patterns_per_bin']}")

        # Trust David's own per-block weighting over the local defaults.
        if "block_weights" in self.hf_config:
            cfg.block_weights = self.hf_config["block_weights"]
| |
|
class DavidAssessor(nn.Module):
    """
    Uses David to score STUDENT pooled features (per block) and timesteps.
    Produces:
      e_t[name] : timestep CE error proxy (scalar)
      e_p[name] : pattern CE error proxy if logits present, else 0
      coh[name] : coherence proxy (avg Cantor alpha if provided, else 1)
    """
    def __init__(self, gdc: GeoDavidCollective, pooling: str):
        super().__init__()
        self.gdc = gdc
        self.pooling = pooling

    def _pool(self, feats: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Spatially pool every (B,C,H,W) feature map to (B,C)."""
        return {name: spatial_pool(f, name, self.pooling) for name, f in feats.items()}

    @staticmethod
    def _first_key(outs, candidates):
        """Return the first candidate key present in outs, else None."""
        for key in candidates:
            if key in outs:
                return key
        return None

    @torch.no_grad()
    def forward(self, feats_student: Dict[str, torch.Tensor], t: torch.LongTensor
                ) -> Tuple[Dict[str,float], Dict[str,float], Dict[str,float]]:
        pooled = self._pool(feats_student)
        outs = self.gdc(pooled, t.float())
        names = list(pooled.keys())

        # David's output key names vary; probe the known aliases.
        ts_key = self._first_key(outs, ("timestep_logits", "logits_timestep", "timestep_head_logits"))
        pt_key = self._first_key(outs, ("pattern_logits", "logits_pattern", "pattern_head_logits"))

        gdc_device = next(self.gdc.parameters()).device
        t_bins = (t // 10).to(gdc_device)

        # Timestep error proxy: CE against the decade bin of t.
        e_t: Dict[str, float] = {}
        if ts_key is None:
            e_t = {name: 0.0 for name in names}
        else:
            ts_logits = outs[ts_key]
            if isinstance(ts_logits, dict):
                # Per-block logits: one CE per block.
                for name, logits in ts_logits.items():
                    e_t[name] = float(F.cross_entropy(logits, t_bins, reduction="mean").item())
            else:
                # Single shared head: broadcast the same CE to all blocks.
                shared_ce = float(F.cross_entropy(ts_logits, t_bins, reduction="mean").item())
                e_t = {name: shared_ce for name in names}

        # Pattern error proxy: softmax entropy normalized to [0, 1].
        e_p: Dict[str, float] = {}
        if pt_key is None:
            e_p = {name: 0.0 for name in names}
        else:
            pt_logits = outs[pt_key]
            if isinstance(pt_logits, dict):
                for name, logits in pt_logits.items():
                    probs = logits.softmax(-1)
                    ent = -(probs * (probs.clamp_min(1e-9)).log()).sum(-1).mean()
                    e_p[name] = float(ent.item() / math.log(probs.shape[-1]))
            else:
                probs = pt_logits.softmax(-1)
                ent = -(probs * (probs.clamp_min(1e-9)).log()).sum(-1).mean()
                shared_ent = float(ent.item() / math.log(probs.shape[-1]))
                e_p = {name: shared_ent for name in names}

        # Coherence proxy: mean Cantor alpha when the collective exposes it.
        try:
            alphas = self.gdc.get_cantor_alphas()
        except Exception:
            alphas = {}
        avg_alpha = float(sum(alphas.values())/max(len(alphas),1)) if alphas else 1.0
        coh = {name: avg_alpha for name in names}

        return e_t, e_p, coh
| |
|
class BlockPenaltyFusion:
    """Fuses David's error signals into clamped per-block loss multipliers.

    Each block's base weight is scaled up by timestep error, pattern error,
    and incoherence, then clamped to [lambda_min, lambda_max].
    """

    def __init__(self, cfg: BaseConfig):
        self.cfg = cfg

    def lambdas(self, e_t:Dict[str,float], e_p:Dict[str,float], coh:Dict[str,float]) -> Dict[str,float]:
        cfg = self.cfg
        out: Dict[str, float] = {}
        for name, base in cfg.block_weights.items():
            scale = 1.0
            scale += cfg.alpha_timestep * float(e_t.get(name, 0.0))
            scale += cfg.beta_pattern * float(e_p.get(name, 0.0))
            # Missing coherence defaults to 1.0, i.e. no incoherence penalty.
            scale += cfg.delta_incoherence * (1.0 - float(coh.get(name, 1.0)))
            out[name] = float(max(cfg.lambda_min, min(cfg.lambda_max, base * scale)))
        return out
| |
|
| |
|
| | |
| | |
| | |
class FlowMatchDavidTrainer:
    """Distills the frozen SD1.5 teacher into a velocity-predicting student.

    Per step: sample random latents and synthetic prompts, take the teacher's
    eps prediction, convert it to a v-target, regress the student's output on
    it (global flow loss), and add per-block KD / local-flow penalties whose
    weights are modulated by David's assessment of the student's features.
    Checkpoints are converted to single-file ComfyUI format and uploaded to
    the HuggingFace hub; resuming pulls the newest *_e<N>.safetensors back.
    """

    def __init__(self, cfg: BaseConfig, device: str = "cuda"):
        self.cfg = cfg
        self.device = device
        self.start_epoch = 0   # updated by _load_latest_from_hf on resume
        self.start_gstep = 0

        # Data: prompts are synthesized on the fly; collate adds decade bins.
        self.dataset = SymbolicPromptDataset(cfg.num_samples, cfg.seed)
        self.loader = DataLoader(self.dataset, batch_size=cfg.batch_size, shuffle=True,
                                 num_workers=cfg.num_workers, pin_memory=True, collate_fn=collate)

        # Frozen teacher; student starts as an exact copy of the teacher UNet.
        self.teacher = SD15Teacher(cfg, device).eval()
        self.student = StudentUNet(self.teacher.unet, cfg.active_blocks, cfg.use_local_flow_heads).to(device)

        # Frozen David collective + assessor + per-block penalty fusion.
        self.david_loader = DavidLoader(cfg, device)
        self.david = self.david_loader.gdc

        self.assessor = DavidAssessor(self.david, cfg.pooling)
        self.fusion = BlockPenaltyFusion(cfg)

        # Optimizer, cosine schedule over all planned steps, AMP scaler.
        # NOTE(review): torch.cuda.amp.GradScaler/autocast are deprecated in
        # recent torch in favor of torch.amp — confirm the targeted version.
        self.opt = torch.optim.AdamW(self.student.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
        self.sched = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, T_max=cfg.epochs * len(self.loader))
        self.scaler = torch.cuda.amp.GradScaler(enabled=cfg.amp)

        # Optionally resume from the newest checkpoint on the HF hub.
        if cfg.continue_training:
            self._load_latest_from_hf()

        self.writer = SummaryWriter(log_dir=os.path.join(cfg.out_dir, cfg.run_name))

    def _load_latest_from_hf(self):
        """Download and load the latest checkpoint from HuggingFace."""
        if not self.cfg.hf_repo_id:
            print("β οΈ continue_training=True but no hf_repo_id specified")
            return

        try:
            api = HfApi()
            print(f"\nπ Searching for latest checkpoint in {self.cfg.hf_repo_id}...")

            # Existence/access probe only; the returned info object is unused.
            try:
                repo_info = api.repo_info(repo_id=self.cfg.hf_repo_id, repo_type="model")
            except Exception as e:
                print(f"β οΈ Could not access repo: {e}")
                print(" Starting training from scratch")
                return

            files = api.list_repo_files(repo_id=self.cfg.hf_repo_id, repo_type="model")

            if not files:
                print("βΉοΈ Repo is empty, starting from scratch")
                return

            print(f"π Found {len(files)} files in repo:")
            for f in files:
                print(f" - {f}")

            # Collect (epoch, filename) for every *_e<N>.safetensors artifact.
            epochs = []

            for f in files:
                if not f.endswith('.safetensors'):
                    continue

                match = re.search(r'_e(\d+)\.safetensors$', f)
                if match:
                    epoch_num = int(match.group(1))
                    epochs.append((epoch_num, f))
                    print(f"β Found checkpoint: {f} (epoch {epoch_num})")

            if not epochs:
                print("βΉοΈ No checkpoint files found (looking for *_e<num>.safetensors)")
                return

            # Resume from the highest epoch number found.
            latest_epoch, latest_file = max(epochs, key=lambda x: x[0])
            print(f"\nπ₯ Downloading latest checkpoint: {latest_file} (epoch {latest_epoch})")

            local_path = hf_hub_download(
                repo_id=self.cfg.hf_repo_id,
                filename=latest_file,
                repo_type="model",
                cache_dir=self.cfg.ckpt_dir
            )
            print(f"β Downloaded to: {local_path}")

            # The artifact is a single-file SD checkpoint; load it through a
            # throwaway pipeline just to recover the UNet weights.
            print("π¦ Loading checkpoint into pipeline...")
            pipe = StableDiffusionPipeline.from_single_file(
                local_path,
                torch_dtype=torch.float16,
                safety_checker=None,
                load_safety_checker=False
            )

            unet_state = pipe.unet.state_dict()

            # strict=False: tolerate key mismatches against the student UNet.
            missing, unexpected = self.student.unet.load_state_dict(unet_state, strict=False)
            print(f"β Loaded student UNet from epoch {latest_epoch}")
            if missing:
                print(f" Missing keys: {len(missing)}")
            if unexpected:
                print(f" Unexpected keys: {len(unexpected)}")

            # NOTE(review): only model weights resume — optimizer/scheduler
            # state is not in the safetensors artifact and restarts fresh.
            self.start_epoch = latest_epoch
            self.start_gstep = latest_epoch * len(self.loader)

            print(f"π― Resuming training from epoch {self.start_epoch + 1}")

            # Drop the throwaway pipeline and reclaim GPU memory.
            del pipe
            torch.cuda.empty_cache()

        except Exception as e:
            # Best-effort resume: any failure falls back to a fresh start.
            print(f"β οΈ Failed to load checkpoint from HF: {e}")
            print(" Starting training from scratch")
            import traceback
            traceback.print_exc()

    def _v_star(self, x_t, t, eps_hat):
        """Convert the teacher's eps prediction into a v-prediction target:
        x0_hat = (x_t - sigma*eps)/alpha, v* = alpha*eps - sigma*x0_hat."""
        alpha, sigma = self.teacher.alpha_sigma(t)
        x0_hat = (x_t - sigma * eps_hat) / (alpha + 1e-8)
        return alpha * eps_hat - sigma * x0_hat

    def _down_like(self, tgt: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        """Bilinearly resize tgt to ref's spatial size (for local flow heads)."""
        return F.interpolate(tgt, size=ref.shape[-2:], mode="bilinear", align_corners=False)

    def _kd_cos(self, s: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Cosine-distance KD loss between pooled student/teacher features."""
        s = F.normalize(s, dim=-1); t = F.normalize(t, dim=-1)
        return 1.0 - (s*t).sum(-1).mean()

    def train(self):
        """Run the full training loop.

        Logs scalars to TensorBoard every 50 iterations, prints epoch means,
        and saves/uploads a checkpoint every cfg.save_every epochs plus a
        final one at the end.
        """
        cfg = self.cfg
        gstep = self.start_gstep

        for ep in range(self.start_epoch, cfg.epochs):
            self.student.train()
            pbar = tqdm(self.loader, desc=f"Epoch {ep+1}/{cfg.epochs}",
                        dynamic_ncols=True, leave=True, position=0)
            acc = {"L":0.0, "Lf":0.0, "Lb":0.0}  # running sums for epoch means

            for it, batch in enumerate(pbar):
                prompts = batch["prompts"]
                t = batch["t"].to(self.device)

                # Text conditioning from the frozen teacher encoder.
                with torch.no_grad():
                    ehs = self.teacher.encode(prompts)

                # NOTE(review): x_t is pure noise (not a noised clean latent);
                # the student is distilled on noise-only inputs — confirm
                # that is the intended training distribution.
                x_t = torch.randn(len(prompts), 4, 64, 64, device=self.device, dtype=torch.float16)

                # Teacher pass: eps prediction + block features -> v-target.
                with torch.no_grad():
                    eps_hat, t_feats_spatial = self.teacher.forward_eps_and_feats(x_t.half(), t, ehs)
                    v_star = self._v_star(x_t, t, eps_hat)

                with torch.cuda.amp.autocast(enabled=cfg.amp):
                    # Global flow-matching loss on the student's v prediction.
                    v_hat, s_feats_spatial = self.student(x_t, t, ehs)
                    L_flow = F.mse_loss(v_hat, v_star)

                    # David scores the student's features; fuse into lambdas.
                    e_t, e_p, coh = self.assessor(s_feats_spatial, t)
                    lam = self.fusion.lambdas(e_t, e_p, coh)

                    # Per-block penalties: feature KD + optional local heads.
                    L_blocks = torch.zeros((), device=self.device)
                    for name, s_feat in s_feats_spatial.items():

                        L_kd = torch.zeros((), device=self.device)
                        if cfg.use_kd:
                            s_pool = spatial_pool(s_feat, name, cfg.pooling)
                            t_pool = spatial_pool(t_feats_spatial[name], name, cfg.pooling)
                            L_kd = self._kd_cos(s_pool, t_pool)

                        L_lf = torch.zeros((), device=self.device)
                        if cfg.use_local_flow_heads and name in self.student.local_heads:
                            v_loc = self.student.local_heads[name](s_feat)
                            v_ds = self._down_like(v_star, v_loc)
                            L_lf = F.mse_loss(v_loc, v_ds)
                        L_blocks = L_blocks + lam.get(name,1.0) * (cfg.kd_weight * L_kd + cfg.local_flow_weight * L_lf)

                    L_total = cfg.global_flow_weight*L_flow + cfg.block_penalty_weight*L_blocks

                self.opt.zero_grad(set_to_none=True)
                if cfg.amp:
                    # NOTE(review): grads are clipped while still scaled; the
                    # usual pattern calls scaler.unscale_(opt) before clipping.
                    self.scaler.scale(L_total).backward()
                    nn.utils.clip_grad_norm_(self.student.parameters(), cfg.grad_clip)
                    self.scaler.step(self.opt); self.scaler.update()
                else:
                    L_total.backward()
                    nn.utils.clip_grad_norm_(self.student.parameters(), cfg.grad_clip)
                    self.opt.step()
                self.sched.step(); gstep += 1

                acc["L"] += float(L_total.item())
                acc["Lf"] += float(L_flow.item())
                acc["Lb"] += float(L_blocks.item())

                # Scalar logging every 50 iterations.
                if it % 50 == 0:
                    self.writer.add_scalar("train/total", float(L_total.item()), gstep)
                    self.writer.add_scalar("train/flow", float(L_flow.item()), gstep)
                    self.writer.add_scalar("train/blocks",float(L_blocks.item()), gstep)
                    # Only the first few lambdas, to keep TensorBoard tidy.
                    for k in list(lam.keys())[:4]:
                        self.writer.add_scalar(f"lambda/{k}", lam[k], gstep)

                # Progress-bar refresh every 10 iterations (and on the last).
                if it % 10 == 0 or it == len(self.loader) - 1:
                    pbar.set_postfix({
                        "L": f"{float(L_total.item()):.4f}",
                        "Lf": f"{float(L_flow.item()):.4f}",
                        "Lb": f"{float(L_blocks.item()):.4f}"
                    }, refresh=False)

                # Free the large per-step tensors promptly to limit peak VRAM.
                del x_t, eps_hat, v_star, v_hat, s_feats_spatial, t_feats_spatial

            pbar.close()

            n = len(self.loader)
            print(f"\n[Epoch {ep+1}] L={acc['L']/n:.4f} | L_flow={acc['Lf']/n:.4f} | L_blocks={acc['Lb']/n:.4f}")
            self.writer.add_scalar("epoch/total", acc['L']/n, ep+1)
            self.writer.add_scalar("epoch/flow", acc['Lf']/n, ep+1)
            self.writer.add_scalar("epoch/blocks",acc['Lb']/n, ep+1)

            if (ep+1) % cfg.save_every == 0:
                self._save(ep+1, gstep)

        self._save("final", gstep)
        self.writer.close()

    def _save(self, tag, gstep):
        """Save and convert to ComfyUI format, then upload."""
        # Temporary .pt with full training state; deleted after conversion.
        pt_path = Path(self.cfg.ckpt_dir) / f"{self.cfg.run_name}_e{tag}.pt"
        torch.save({
            "cfg": asdict(self.cfg),
            "student": self.student.state_dict(),
            "opt": self.opt.state_dict(),
            "sched": self.sched.state_dict(),
            "gstep": gstep
        }, pt_path)
        print(f"β Saved temp .pt: {pt_path}")

        safetensors_path = self._convert_to_comfyui(pt_path, tag)

        if self.cfg.upload_every_epoch and self.cfg.hf_repo_id and safetensors_path:
            self._upload_to_hf(safetensors_path, tag)

        # NOTE(review): the .pt (incl. optimizer state) is always discarded,
        # so resuming restores weights only — confirm this is intended.
        pt_path.unlink()
        print(f"β Cleaned up temp .pt file")

    def _convert_to_comfyui(self, pt_path: Path, tag) -> Optional[Path]:
        """Convert .pt to ComfyUI-compatible safetensors."""
        try:
            temp_pipeline = Path(self.cfg.ckpt_dir) / f"temp_pipeline_e{tag}"
            output_safetensors = Path(self.cfg.ckpt_dir) / f"{self.cfg.run_name}_e{tag}.safetensors"

            # Fetch the official diffusers -> original-SD converter script
            # once and cache it next to the checkpoints.
            converter_path = Path(self.cfg.ckpt_dir) / "convert_diffusers_to_original_stable_diffusion.py"
            if not converter_path.exists():
                print("π₯ Downloading official converter...")
                url = "https://raw.githubusercontent.com/huggingface/diffusers/main/scripts/convert_diffusers_to_original_stable_diffusion.py"
                urllib.request.urlretrieve(url, str(converter_path))
                print("β Converter downloaded")

            # Rebuild a diffusers pipeline with the student weights swapped in.
            print(f"π¦ Creating diffusers pipeline from checkpoint...")
            checkpoint = torch.load(pt_path, map_location='cpu')
            student_state = checkpoint.get('student', checkpoint)

            print("π₯ Loading base UNet...")
            unet = UNet2DConditionModel.from_pretrained(
                "runwayml/stable-diffusion-v1-5",
                subfolder="unet",
                torch_dtype=torch.float16
            )
            # strict=False: the student state may carry extra keys (e.g. the
            # optional local heads) that the plain UNet does not have.
            unet.load_state_dict(student_state, strict=False)
            print("β Loaded student weights into UNet")

            print("π₯ Loading base SD1.5 pipeline...")
            pipe = StableDiffusionPipeline.from_pretrained(
                "runwayml/stable-diffusion-v1-5",
                torch_dtype=torch.float16,
                safety_checker=None
            )
            pipe.unet = unet
            print("β Replaced UNet with student")

            print(f"πΎ Saving diffusers pipeline...")
            pipe.save_pretrained(str(temp_pipeline), safe_serialization=True)
            print(f"β Pipeline saved to {temp_pipeline}")

            # Run the converter as a subprocess (argument list, no shell).
            print(f"π Converting to ComfyUI format...")
            cmd = [
                "python", str(converter_path),
                "--model_path", str(temp_pipeline),
                "--checkpoint_path", str(output_safetensors),
                "--half"
            ]

            result = subprocess.run(cmd, capture_output=True, text=True)
            if result.returncode != 0:
                print(f"β Conversion failed: {result.stderr}")
                return None

            if output_safetensors.exists():
                size_mb = output_safetensors.stat().st_size / 1e6
                print(f"β Converted: {output_safetensors.name} ({size_mb:.1f}MB)")

                # The intermediate diffusers layout is no longer needed.
                shutil.rmtree(temp_pipeline)
                print("β Cleaned up temp pipeline")

                return output_safetensors
            else:
                print(f"β Output file not created")
                return None

        except Exception as e:
            # Conversion is best-effort; training continues without an upload.
            print(f"β Conversion failed: {e}")
            import traceback
            traceback.print_exc()
            return None

    def _upload_to_hf(self, path: Path, tag):
        """Upload safetensors to HuggingFace."""
        try:
            api = HfApi()

            # Idempotent repo creation; failures here are non-fatal.
            try:
                create_repo(self.cfg.hf_repo_id, exist_ok=True, private=False, repo_type="model")
                print(f"β Repo ready: {self.cfg.hf_repo_id}")
            except Exception:
                pass

            print(f"π€ Uploading to {self.cfg.hf_repo_id}...")
            api.upload_file(
                path_or_fileobj=str(path),
                path_in_repo=path.name,
                repo_id=self.cfg.hf_repo_id,
                repo_type="model",
                commit_message=f"Epoch {tag}"
            )
            print(f"β Uploaded: https://huggingface.co/{self.cfg.hf_repo_id}/{path.name}")

        except Exception as e:
            # Upload is best-effort; a failure never aborts training.
            print(f"β οΈ Upload failed: {e}")

    @torch.no_grad()
    def sample(self, prompts: List[str], steps: Optional[int]=None, guidance: Optional[float]=None) -> torch.Tensor:
        """Classifier-free-guided sampling with the student.

        The student's v prediction is converted back to eps for the DDPM
        scheduler. Returns decoded images clamped to [-1, 1].
        """
        steps = steps or self.cfg.sample_steps
        guidance = guidance if guidance is not None else self.cfg.guidance_scale
        cond_e = self.teacher.encode(prompts)
        uncond_e = self.teacher.encode([""]*len(prompts))
        sched = self.teacher.sched
        sched.set_timesteps(steps, device=self.device)
        # NOTE(review): latents start fp32 while the UNet is fp16 — relies on
        # implicit dtype handling; confirm on the target torch/diffusers.
        x_t = torch.randn(len(prompts), 4, 64, 64, device=self.device)

        for t_scalar in sched.timesteps:
            t = torch.full((x_t.shape[0],), t_scalar, device=self.device, dtype=torch.long)
            # Classifier-free guidance applied in v-space.
            v_u, _ = self.student(x_t, t, uncond_e)
            v_c, _ = self.student(x_t, t, cond_e)
            v_hat = v_u + guidance*(v_c - v_u)

            # v -> x0 -> eps; alpha^2 + sigma^2 = 1 for this schedule, so the
            # denominator is effectively a numerical safeguard.
            alpha, sigma = self.teacher.alpha_sigma(t)
            denom = (alpha**2 + sigma**2)
            x0_hat = (alpha * x_t - sigma * v_hat) / (denom + 1e-8)
            eps_hat = (x_t - alpha * x0_hat) / (sigma + 1e-8)

            # The DDPM scheduler expects an eps-style model output here.
            step = sched.step(model_output=eps_hat, timestep=t_scalar, sample=x_t)
            x_t = step.prev_sample

        # 0.18215 is the SD1.x VAE latent scaling factor.
        imgs = self.teacher.pipe.vae.decode(x_t / 0.18215).sample
        return imgs.clamp(-1,1)
| |
|
| |
|
| | |
| | |
| | |
def main():
    """Entry point: dump the config, train, then run a short sampling sanity pass."""
    config = BaseConfig()
    print(json.dumps(asdict(config), indent=2))
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    if not use_cuda:
        print("β οΈ A100 strongly recommended.")
    runner = FlowMatchDavidTrainer(config, device=device)
    runner.train()

    _ = runner.sample(["a castle at sunset"], steps=10, guidance=7.0)
    print("β Inference sanity done.")


if __name__ == "__main__":
    main()