| import argparse |
| import copy |
| from copy import deepcopy |
| import logging |
| import os |
| from pathlib import Path |
| from collections import OrderedDict |
| import json |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| import torch.utils.checkpoint |
| from tqdm.auto import tqdm |
| from torch.utils.data import DataLoader |
|
|
| from accelerate import Accelerator, DistributedDataParallelKwargs |
| from accelerate.logging import get_logger |
| from accelerate.utils import ProjectConfiguration, set_seed |
|
|
| from models.sit import SiT_models |
| from loss import SILoss |
| from utils import load_encoders |
|
|
| from dataset import CustomDataset |
| from diffusers.models import AutoencoderKL |
| |
| import wandb |
| import math |
| from torchvision.utils import make_grid |
| from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
| from torchvision.transforms import Normalize |
| from PIL import Image |
|
|
| logger = get_logger(__name__) |
|
|
|
|
def semantic_dim_from_enc_type(enc_type):
    """Infer the class-token feature dimension from an encoder-type string.

    Recognizes DINOv2-style size tags ("vit-g"/"vitg", "vit-l"/"vitl",
    "vit-s"/"vits"); anything else — including None and ViT-B — falls back
    to 768 so the result stays consistent with the preprocessed features.
    """
    if enc_type is None:
        return 768
    tag = str(enc_type).lower()
    # Checked in the same priority order as the original chain of ifs.
    for markers, dim in ((("vit-g", "vitg"), 1536),
                         (("vit-l", "vitl"), 1024),
                         (("vit-s", "vits"), 384)):
        if any(marker in tag for marker in markers):
            return dim
    return 768
|
|
|
|
# CLIP's per-channel RGB normalization statistics, used by
# preprocess_raw_image() when the encoder type contains "clip".
CLIP_DEFAULT_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_DEFAULT_STD = (0.26862954, 0.26130258, 0.27577711)
|
|
|
|
|
|
def preprocess_raw_image(x, enc_type):
    """Turn raw 0..255 images into the input format a pretrained encoder
    expects: rescale to [0, 1], apply the encoder's normalization stats,
    and (for 224px-native backbones) resize proportionally.

    Unrecognized encoder types are returned unchanged.
    """
    resolution = x.shape[-1]
    # 224 pixels per 256px of input resolution for 224px-native encoders.
    target_size = 224 * (resolution // 256)

    if 'clip' in enc_type:
        # CLIP: resize first, then normalize with CLIP statistics.
        x = x / 255.
        x = torch.nn.functional.interpolate(x, target_size, mode='bicubic')
        x = Normalize(CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD)(x)
    elif 'mocov3' in enc_type or 'mae' in enc_type:
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x / 255.)
    elif 'dinov2' in enc_type:
        # DINOv2: normalize first, then resize to the 224px-native grid.
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x / 255.)
        x = torch.nn.functional.interpolate(x, target_size, mode='bicubic')
    elif 'dinov1' in enc_type:
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x / 255.)
    elif 'jepa' in enc_type:
        # JEPA: same ordering as DINOv2 (normalize, then resize).
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x / 255.)
        x = torch.nn.functional.interpolate(x, target_size, mode='bicubic')

    return x
|
|
|
|
def array2grid(x):
    """Tile a batch of [0, 1] images into one roughly-square uint8 HWC grid."""
    per_row = round(math.sqrt(x.size(0)))
    grid = make_grid(x.clamp(0, 1), nrow=per_row, value_range=(0, 1))
    # Scale to 0..255 with rounding, put channels last, hand back numpy.
    grid = grid.mul(255).add_(0.5).clamp_(0, 255)
    return grid.permute(1, 2, 0).to('cpu', torch.uint8).numpy()
|
|
|
|
@torch.no_grad()
def sample_posterior(moments, latents_scale=1., latents_bias=0.):
    """Draw a latent sample from stored VAE posterior moments.

    `moments` packs the per-pixel mean and spread along dim=1
    (NOTE(review): the second half is used directly as a std, not as a
    log-variance — confirm this matches the preprocessing script).
    The reparameterized sample is then affinely rescaled by
    `latents_scale` / `latents_bias` for the diffusion model.
    """
    # (Removed an unused `device = moments.device` local.)
    mean, std = torch.chunk(moments, 2, dim=1)
    z = mean + std * torch.randn_like(mean)
    return z * latents_scale + latents_bias
|
|
|
|
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Blend the EMA weights toward the live model: ema = decay*ema + (1-decay)*w.
    """
    ema_weights = dict(ema_model.named_parameters())

    for raw_name, live_param in model.named_parameters():
        # DDP wraps parameters under a "module." prefix; strip it so names
        # line up with the unwrapped EMA copy.
        key = raw_name.replace("module.", "")
        ema_weights[key].mul_(decay).add_(live_param.data, alpha=1 - decay)
|
|
|
|
def create_logger(logging_dir):
    """
    Configure root logging to stream to stdout and append to
    <logging_dir>/log.txt, then return this module's logger.
    """
    handlers = [
        logging.StreamHandler(),
        logging.FileHandler(f"{logging_dir}/log.txt"),
    ]
    logging.basicConfig(
        level=logging.INFO,
        # ANSI-blue timestamp, then the plain message.
        format='[\033[34m%(asctime)s\033[0m] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=handlers,
    )
    return logging.getLogger(__name__)
|
|
|
|
def requires_grad(model, flag=True):
    """
    Enable or disable gradient tracking on every parameter of `model`.
    """
    for param in model.parameters():
        param.requires_grad_(flag)
|
|
|
|
| |
| |
| |
|
|
def main(args):
    """Full training loop: set up Accelerate, data, the SiT model + EMA copy,
    the semantic-alignment loss, then optimize with periodic checkpointing
    and EMA sample visualization."""
    # Resolve logging/output locations for the Accelerate project config.
    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(
        project_dir=args.output_dir, logging_dir=logging_dir
    )

    # find_unused_parameters=True — presumably some SiT parameters do not
    # receive gradients every step (TODO confirm which branch causes this).
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)]
    )

    # Only rank 0 creates directories, dumps the args, and owns a file logger.
    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
        save_dir = os.path.join(args.output_dir, args.exp_name)
        os.makedirs(save_dir, exist_ok=True)
        args_dict = vars(args)

        json_dir = os.path.join(save_dir, "args.json")
        with open(json_dir, 'w') as f:
            json.dump(args_dict, f, indent=4)
        checkpoint_dir = f"{save_dir}/checkpoints"
        os.makedirs(checkpoint_dir, exist_ok=True)
        logger = create_logger(save_dir)
        logger.info(f"Experiment directory created at {save_dir}")
    # NOTE(review): `logger` becomes a function-local assigned only on the
    # main process; every later use must stay behind an is_main_process
    # guard or non-main ranks would hit UnboundLocalError.
    device = accelerator.device
    if torch.backends.mps.is_available():
        # Disable native AMP on Apple MPS backends.
        accelerator.native_amp = False
    if args.seed is not None:
        # Offset the seed by rank so each process draws different noise.
        set_seed(args.seed + accelerator.process_index)

    assert args.resolution % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
    latent_size = args.resolution // 8  # SD-VAE downsamples by 8x.

    train_dataset = CustomDataset(
        args.data_dir, semantic_features_dir=args.semantic_features_dir
    )
    # When the dataset ships precomputed semantic features, no online
    # encoder is needed during training.
    use_preprocessed_semantic = train_dataset.use_preprocessed_semantic

    if use_preprocessed_semantic:
        encoders, encoder_types, architectures = [], [], []
        # Feature dimensionality is inferred from the encoder-type string.
        z_dims = [semantic_dim_from_enc_type(args.enc_type)]
        if accelerator.is_main_process:
            logger.info(
                f"Preprocessed semantic features: skip loading online encoder, z_dims={z_dims}"
            )
    elif args.enc_type is not None:
        encoders, encoder_types, architectures = load_encoders(
            args.enc_type, device, args.resolution
        )
        z_dims = [encoder.embed_dim for encoder in encoders]
    else:
        raise NotImplementedError()
    block_kwargs = {"fused_attn": args.fused_attn, "qk_norm": args.qk_norm}
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        use_cfg = (args.cfg_prob > 0),
        z_dims = z_dims,
        encoder_depth=args.encoder_depth,
        **block_kwargs
    )

    model = model.to(device)
    # EMA copy is never trained directly; it tracks `model` via update_ema().
    ema = deepcopy(model).to(device)
    requires_grad(ema, False)

    # 0.18215 is the standard sd-vae-ft latent scale factor; bias is zero.
    latents_scale = torch.tensor(
        [0.18215, 0.18215, 0.18215, 0.18215]
    ).view(1, 4, 1, 1).to(device)
    latents_bias = torch.tensor(
        [0., 0., 0., 0.]
    ).view(1, 4, 1, 1).to(device)

    # Best-effort VAE loading for intermediate sample decoding. Order:
    # (1) HF repo name against the local dnnlib cache (offline only),
    # (2) scan a few known cache roots for an sd-vae-ft-mse snapshot.
    # On total failure `vae` stays None and image decoding is skipped.
    try:
        from preprocessing import dnnlib
        cache_dir = dnnlib.make_cache_dir_path("diffusers")
        os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
        os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
        os.environ["HF_HOME"] = cache_dir
        try:
            vae = AutoencoderKL.from_pretrained(
                "stabilityai/sd-vae-ft-mse",
                cache_dir=cache_dir,
                local_files_only=True,
            ).to(device)
            vae.eval()
            if accelerator.is_main_process:
                logger.info(
                    "Loaded VAE 'stabilityai/sd-vae-ft-mse' from local diffusers cache "
                    f"at '{cache_dir}' for intermediate sampling."
                )
        except Exception as e_main:
            vae = None
            candidate_dir = None
            possible_roots = [
                cache_dir,
                os.path.join(os.path.expanduser("~"), ".cache", "dnnlib", "diffusers"),
                os.path.join(os.path.expanduser("~"), ".cache", "diffusers"),
                os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub"),
            ]
            checked_roots = []
            for root_dir in possible_roots:
                if not os.path.isdir(root_dir):
                    continue
                checked_roots.append(root_dir)
                # Any directory containing a diffusers config.json under an
                # sd-vae-ft-mse path counts as a usable snapshot.
                for root, dirs, files in os.walk(root_dir):
                    if "config.json" in files and "sd-vae-ft-mse" in root.replace("\\", "/"):
                        candidate_dir = root
                        break
                if candidate_dir is not None:
                    break
            if candidate_dir is not None:
                try:
                    vae = AutoencoderKL.from_pretrained(
                        candidate_dir,
                        local_files_only=True,
                    ).to(device)
                    vae.eval()
                    if accelerator.is_main_process:
                        logger.info(
                            "Loaded VAE 'stabilityai/sd-vae-ft-mse' from discovered local path "
                            f"'{candidate_dir}'. Searched roots: {checked_roots}"
                        )
                except Exception as e_fallback:
                    if accelerator.is_main_process:
                        logger.warning(
                            "Tried to load VAE from discovered local path "
                            f"'{candidate_dir}' but failed: {e_fallback}"
                        )
            if vae is None and accelerator.is_main_process:
                logger.warning(
                    "Could not load VAE 'stabilityai/sd-vae-ft-mse' via repo name or local search. "
                    f"Last repo-level error: {e_main}"
                )
    except Exception as e:
        vae = None
        if accelerator.is_main_process:
            logger.warning(
                f"Failed to initialize VAE loading logic (will skip image decoding): {e}"
            )

    loss_fn = SILoss(
        prediction=args.prediction,
        path_type=args.path_type,
        encoders=encoders,
        accelerator=accelerator,
        latents_scale=latents_scale,
        latents_bias=latents_bias,
        weighting=args.weighting,
        t_c=args.t_c,
        ot_cls=args.ot_cls,
    )
    if accelerator.is_main_process:
        logger.info(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")

    if args.allow_tf32:
        # Allow TF32 matmul/conv kernels for throughput.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Global batch size is split evenly across processes.
    local_batch_size = int(args.batch_size // accelerator.num_processes)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=local_batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
    if accelerator.is_main_process:
        logger.info(f"Dataset contains {len(train_dataset):,} images ({args.data_dir})")

    # decay=0 copies the current weights into the EMA model exactly once.
    update_ema(ema, model, decay=0)
    model.train()
    ema.eval()

    global_step = 0
    if args.resume_step > 0:
        # Resume model/EMA/optimizer state from a zero-padded step checkpoint.
        ckpt_name = str(args.resume_step).zfill(7) +'.pt'
        ckpt = torch.load(
            f'{os.path.join(args.output_dir, args.exp_name)}/checkpoints/{ckpt_name}',
            map_location='cpu',
        )
        model.load_state_dict(ckpt['model'])
        ema.load_state_dict(ckpt['ema'])
        optimizer.load_state_dict(ckpt['opt'])
        global_step = ckpt['steps']

    model, optimizer, train_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader
    )

    if accelerator.is_main_process:
        tracker_config = vars(copy.deepcopy(args))
        accelerator.init_trackers(
            project_name="REG",
            config=tracker_config,
            init_kwargs={
                "wandb": {"name": f"{args.exp_name}"}
            },
        )

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=global_step,
        desc="Steps",

        disable=not accelerator.is_local_main_process,
    )

    # Fixed fixtures drawn from the first batch: ground-truth latents,
    # random labels and a fixed noise tensor.
    # NOTE(review): gt_xs/ys/xT are prepared but never used again in this
    # function — confirm whether this is dead code.
    sample_batch_size = 64 // accelerator.num_processes
    first_batch = next(iter(train_dataloader))
    if len(first_batch) == 4:
        # 4-tuple batches additionally carry preprocessed semantic features.
        gt_raw_images, gt_xs, _, _ = first_batch
    else:
        gt_raw_images, gt_xs, _ = first_batch
    assert gt_raw_images.shape[-1] == args.resolution
    gt_xs = gt_xs[:sample_batch_size]
    gt_xs = sample_posterior(
        gt_xs.to(device), latents_scale=latents_scale, latents_bias=latents_bias
    )
    ys = torch.randint(1000, size=(sample_batch_size,), device=device)
    ys = ys.to(device)

    n = ys.size(0)
    xT = torch.randn((n, 4, latent_size, latent_size), device=device)

    for epoch in range(args.epochs):
        model.train()
        for batch in train_dataloader:
            if len(batch) == 4:
                raw_image, x, r_preprocessed, y = batch
                use_sem_file = True
            else:
                raw_image, x, y = batch
                r_preprocessed = None
                use_sem_file = False

            raw_image = raw_image.to(device)
            x = x.squeeze(dim=1).to(device).float()
            y = y.to(device)
            if args.legacy:

                # Legacy CFG label dropout: replace a cfg_prob fraction of
                # labels with the null class id (== num_classes).
                drop_ids = torch.rand(y.shape[0], device=y.device) < args.cfg_prob
                labels = torch.where(drop_ids, args.num_classes, y)
            else:
                labels = y
            with torch.no_grad():
                # Sample latents from the stored VAE posterior moments.
                x = sample_posterior(x, latents_scale=latents_scale, latents_bias=latents_bias)
                zs = []
                if use_sem_file and r_preprocessed is not None:
                    # Preprocessed class token: normalize shape to (B, D) and
                    # tile it across all patch positions (cls + num_patches).
                    cls_token = r_preprocessed.to(device).float()
                    if cls_token.dim() == 1:
                        cls_token = cls_token.unsqueeze(0)
                    while cls_token.dim() > 2:
                        cls_token = cls_token.squeeze(1)
                    base_m = model.module if hasattr(model, "module") else model
                    n_pad = base_m.x_embedder.num_patches
                    zs = [
                        torch.cat(
                            [
                                cls_token.unsqueeze(1),
                                cls_token.unsqueeze(1).expand(-1, n_pad, -1),
                            ],
                            dim=1,
                        )
                    ]
                else:
                    # Online path: run the frozen encoder(s) on raw pixels.
                    # NOTE(review): only DINOv2 outputs are handled; any other
                    # encoder type hard-exits the process. Also, if `encoders`
                    # is empty here, `cls_token` stays unbound and the loss_fn
                    # call below would raise NameError — confirm this path.
                    with accelerator.autocast():
                        for encoder, encoder_type, arch in zip(
                            encoders, encoder_types, architectures
                        ):
                            raw_image_ = preprocess_raw_image(raw_image, encoder_type)
                            z = encoder.forward_features(raw_image_)
                            if 'dinov2' in encoder_type:
                                dense_z = z['x_norm_patchtokens']
                                cls_token = z['x_norm_clstoken']
                                # Prepend the class token to the patch tokens.
                                dense_z = torch.cat([cls_token.unsqueeze(1), dense_z], dim=1)
                            else:
                                exit()
                            zs.append(dense_z)

            with accelerator.accumulate(model):
                model_kwargs = dict(y=labels)
                loss1, proj_loss1, time_input, noises, loss2 = loss_fn(model, x, model_kwargs, zs=zs,
                                                                       cls_token=cls_token,
                                                                       time_input=None, noises=None)
                # Total objective = denoising loss + weighted projection
                # (alignment) loss + weighted cls-path loss.
                loss_mean = loss1.mean()
                loss_mean_cls = loss2.mean() * args.cls
                proj_loss_mean = proj_loss1.mean() * args.proj_coeff
                loss = loss_mean + proj_loss_mean + loss_mean_cls

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    # Clip only on real optimizer steps (not accumulation
                    # micro-steps); grad_norm is reused for logging below.
                    params_to_clip = model.parameters()
                    grad_norm = accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

            if accelerator.sync_gradients:
                update_ema(ema, model)

            if accelerator.sync_gradients:
                # One full optimizer step completed across accumulation chunks.
                progress_bar.update(1)
                global_step += 1
                if global_step % args.checkpointing_steps == 0 and global_step > 0:
                    if accelerator.is_main_process:
                        checkpoint = {
                            "model": model.module.state_dict(),
                            "ema": ema.state_dict(),
                            "opt": optimizer.state_dict(),
                            "args": args,
                            "steps": global_step,
                        }
                        checkpoint_path = f"{checkpoint_dir}/{global_step:07d}.pt"
                        torch.save(checkpoint, checkpoint_path)
                        logger.info(f"Saved checkpoint to {checkpoint_path}")

                if (global_step == 1 or (global_step % args.sampling_steps == 0 and global_step > 0)):
                    # Periodic visualization: run the EMA model through an
                    # Euler-Maruyama sampler and decode both the final (t=0)
                    # and mid-trajectory (t=t_c) latents when a VAE is loaded.
                    # NOTE(review): this block runs on every process and logs
                    # via the root `logging` module instead of the rank-0
                    # `logger` — confirm both are intentional.
                    t_mid_vis = float(args.t_c)
                    tc_tag = f"{t_mid_vis:.4f}".rstrip("0").rstrip(".").replace(".", "_")
                    logging.info(
                        f"Generating EMA samples (Euler-Maruyama; t≈{t_mid_vis:g} → t=0)..."
                    )
                    ema.eval()
                    with torch.no_grad():
                        latent_size = args.resolution // 8
                        n_samples = min(16, args.batch_size)
                        base_model = model.module if hasattr(model, "module") else model
                        cls_dim = base_model.z_dims[0]
                        # Reuse one seed so latent and cls noise are drawn
                        # from the same generator state.
                        shared_seed = torch.randint(0, 2**32, (1,), device=device).item()
                        torch.manual_seed(shared_seed)
                        z_init = torch.randn(n_samples, base_model.in_channels, latent_size, latent_size, device=device)
                        torch.manual_seed(shared_seed)
                        cls_init = torch.randn(n_samples, cls_dim, device=device)
                        y_samples = torch.randint(0, args.num_classes, (n_samples,), device=device)

                        from samplers import euler_maruyama_sampler
                        z_0, z_mid, _ = euler_maruyama_sampler(
                            ema,
                            z_init,
                            y_samples,
                            num_steps=50,
                            cfg_scale=1.0,
                            guidance_low=0.0,
                            guidance_high=1.0,
                            path_type=args.path_type,
                            cls_latents=cls_init,
                            args=args,
                            return_mid_state=True,
                            t_mid=t_mid_vis,
                        )

                        samples_root = os.path.join(args.output_dir, args.exp_name, "samples")
                        t0_dir = os.path.join(samples_root, "t0")
                        t_mid_dir = os.path.join(samples_root, f"t0_{tc_tag}")
                        os.makedirs(t0_dir, exist_ok=True)
                        os.makedirs(t_mid_dir, exist_ok=True)

                        if vae is not None:
                            # Undo the latent scaling before decoding, then
                            # map the decoder output from [-1, 1] to [0, 1].
                            z_f = z_0.to(dtype=torch.float32)
                            samples_final = vae.decode((z_f - latents_bias) / latents_scale).sample
                            samples_final = (samples_final + 1) / 2.0
                            samples_final = samples_final.clamp(0, 1)
                            grid_final = array2grid(samples_final)
                            Image.fromarray(grid_final).save(
                                os.path.join(t0_dir, f"step_{global_step:07d}_t0.png")
                            )

                            if z_mid is not None:
                                z_m = z_mid.to(dtype=torch.float32)
                                samples_mid = vae.decode((z_m - latents_bias) / latents_scale).sample
                                samples_mid = (samples_mid + 1) / 2.0
                                samples_mid = samples_mid.clamp(0, 1)
                                grid_mid = array2grid(samples_mid)
                                Image.fromarray(grid_mid).save(
                                    os.path.join(t_mid_dir, f"step_{global_step:07d}_t0_{tc_tag}.png")
                                )
                            else:
                                logging.warning(
                                    f"Sampling time grid did not bracket t_mid={t_mid_vis:g}; "
                                    f"skip t0_{tc_tag} image this step."
                                )

                        # Free visualization tensors before training resumes.
                        del z_init, cls_init, y_samples, z_0
                        if z_mid is not None:
                            del z_mid
                        if vae is not None:
                            del samples_final, grid_final
                            if "samples_mid" in locals():
                                del samples_mid, grid_mid
                        torch.cuda.empty_cache()

                logs = {
                    "loss_final": accelerator.gather(loss).mean().detach().item(),
                    "loss_mean": accelerator.gather(loss_mean).mean().detach().item(),
                    "proj_loss": accelerator.gather(proj_loss_mean).mean().detach().item(),
                    "loss_mean_cls": accelerator.gather(loss_mean_cls).mean().detach().item(),
                    "grad_norm": accelerator.gather(grad_norm).mean().detach().item()
                }

                log_message = ", ".join(f"{key}: {value:.6f}" for key, value in logs.items())
                logging.info(f"Step: {global_step}, Training Logs: {log_message}")

                progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break
        if global_step >= args.max_train_steps:
            break

    model.eval()

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        logger.info("Done!")
    accelerator.end_training()
|
|
def parse_args(input_args=None):
    """Build and parse the training CLI.

    Args:
        input_args: Optional explicit argv list (useful for tests); falls
            back to ``sys.argv`` when None.

    Returns:
        argparse.Namespace holding all training hyperparameters.
    """
    parser = argparse.ArgumentParser(description="Training")

    # Logging / experiment bookkeeping.
    parser.add_argument("--output-dir", type=str, default="exps")
    parser.add_argument("--exp-name", type=str, required=True)
    parser.add_argument("--logging-dir", type=str, default="logs")
    parser.add_argument("--report-to", type=str, default="wandb")
    parser.add_argument("--sampling-steps", type=int, default=2000)
    parser.add_argument("--resume-step", type=int, default=0)

    # Model architecture.
    parser.add_argument("--model", type=str)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--encoder-depth", type=int, default=8)
    parser.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--ops-head", type=int, default=16)

    # Dataset.
    parser.add_argument("--data-dir", type=str, default="../data/imagenet256")
    parser.add_argument(
        "--semantic-features-dir",
        type=str,
        default=None,
        help="预处理 DINOv2 class token 等特征目录(含 dataset.json)。"
             "默认 None 时若存在 data-dir/imagenet_256_features/dinov2-vit-b_tmp/gpu0 则自动使用。",
    )
    parser.add_argument("--resolution", type=int, choices=[256, 512], default=256)
    parser.add_argument("--batch-size", type=int, default=256)

    # Precision.
    parser.add_argument("--allow-tf32", action="store_true")
    parser.add_argument("--mixed-precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])

    # Optimization.
    parser.add_argument("--epochs", type=int, default=1400)
    parser.add_argument("--max-train-steps", type=int, default=1000000)
    parser.add_argument("--checkpointing-steps", type=int, default=10000)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
    parser.add_argument("--learning-rate", type=float, default=1e-4)
    parser.add_argument("--adam-beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam-beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam-weight-decay", type=float, default=0., help="Weight decay to use.")
    parser.add_argument("--adam-epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max-grad-norm", default=1.0, type=float, help="Max gradient norm.")

    # Seed.
    parser.add_argument("--seed", type=int, default=0)

    # Data loading.
    parser.add_argument("--num-workers", type=int, default=4)

    # Loss / interpolant configuration.
    parser.add_argument("--path-type", type=str, default="linear", choices=["linear", "cosine"])
    parser.add_argument("--prediction", type=str, default="v", choices=["v"])
    parser.add_argument("--cfg-prob", type=float, default=0.1)
    parser.add_argument("--enc-type", type=str, default='dinov2-vit-b')
    parser.add_argument("--proj-coeff", type=float, default=0.5)
    # Fixed: help text was copy-pasted from --max-grad-norm ("Max gradient
    # norm.") — it actually selects the loss weighting scheme.
    parser.add_argument("--weighting", default="uniform", type=str, help="Loss weighting scheme for the diffusion objective.")
    parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--cls", type=float, default=0.03)
    parser.add_argument(
        "--t-c",
        type=float,
        default=0.5,
        help="语义分界时刻(与脚本内 t 约定一致:t=1 噪声→t=0 数据)。"
             "t∈(t_c,1]:cls 沿 OT 配对后的路径插值(CFM/OT-CFM 式 minibatch OT);"
             "t∈[0,t_c]:cls 固定为真实 encoder cls,目标 cls 速度为 0。",
    )
    parser.add_argument(
        "--ot-cls",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="在 t>t_c 段对 cls 噪声与 batch 内 cls_gt 做 minibatch 最优传输配对(需 scipy);关闭则退化为独立高斯噪声配对。",
    )
    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    return args
|
|
if __name__ == "__main__":
    # Script entry point: parse CLI arguments and launch training.
    args = parse_args()

    main(args)
|
|