|
|
""" |
|
|
Example Training Script using Arch Components |
|
|
使用架构组件的示例训练脚本 |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.optim as optim |
|
|
from torch.utils.data import DataLoader |
|
|
import os |
|
|
import argparse |
|
|
from tqdm import tqdm |
|
|
import wandb |
|
|
from typing import Optional |
|
|
|
|
|
|
|
|
from arch import ( |
|
|
QwenTextEncoder, |
|
|
QwenEmbeddingAdapter, |
|
|
load_unet_from_safetensors, |
|
|
load_vae_from_safetensors, |
|
|
create_scheduler, |
|
|
DiffusionLoss, |
|
|
AdapterTrainingStep, |
|
|
get_cosine_schedule_with_warmup, |
|
|
EMAModel, |
|
|
ImageCaptionDataset, |
|
|
MultiAspectDataset, |
|
|
create_dataloader |
|
|
) |
|
|
|
|
|
|
|
|
def parse_args():
    """Parse command-line options for Qwen-SDXL adapter training.

    Returns:
        argparse.Namespace with model paths, data locations, optimization
        hyperparameters, checkpointing, logging, and hardware settings.
    """
    parser = argparse.ArgumentParser(description="Train Qwen-SDXL Adapter")
    arg = parser.add_argument  # shorthand; every call below registers one option

    # Model component paths
    arg("--qwen_model_path", type=str, default="models/Qwen3-Embedding-0.6B")
    arg("--unet_path", type=str, default="models/extracted_components/waiNSFWIllustrious_v140_unet.safetensors")
    arg("--unet_config_path", type=str, default="models/extracted_components/waiNSFWIllustrious_v140_unet_config.json")
    arg("--vae_path", type=str, default="models/extracted_components/waiNSFWIllustrious_v140_vae.safetensors")
    arg("--vae_config_path", type=str, default="models/extracted_components/waiNSFWIllustrious_v140_vae_config.json")

    # Dataset
    arg("--data_root", type=str, required=True, help="Root directory of training images")
    arg("--annotations_file", type=str, required=True, help="Path to annotations file (JSON/JSONL)")
    arg("--caption_column", type=str, default="caption")
    arg("--image_column", type=str, default="image")
    arg("--use_multi_aspect", action="store_true", help="Use multi-aspect ratio dataset")

    # Core training hyperparameters
    arg("--batch_size", type=int, default=4)
    arg("--learning_rate", type=float, default=1e-4)
    arg("--num_epochs", type=int, default=10)
    arg("--warmup_steps", type=int, default=500)
    arg("--gradient_accumulation_steps", type=int, default=1)
    arg("--max_grad_norm", type=float, default=1.0)

    # Loss configuration
    arg("--loss_type", type=str, default="mse", choices=["mse", "l1", "huber"])
    arg("--snr_gamma", type=float, default=None, help="SNR gamma for loss weighting")
    arg("--use_v_parameterization", action="store_true")

    # Optimizer / EMA
    arg("--optimizer", type=str, default="adamw", choices=["adamw", "adam"])
    arg("--weight_decay", type=float, default=0.01)
    arg("--use_ema", action="store_true", help="Use EMA for adapter")
    arg("--ema_decay", type=float, default=0.9999)

    # Checkpointing
    arg("--output_dir", type=str, default="./checkpoints")
    arg("--save_steps", type=int, default=1000)
    arg("--resume_from_checkpoint", type=str, default=None)

    # Logging / experiment tracking
    arg("--logging_steps", type=int, default=50)
    arg("--use_wandb", action="store_true")
    arg("--wandb_project", type=str, default="qwen-sdxl-training")
    arg("--wandb_run_name", type=str, default=None)

    # Hardware
    arg("--device", type=str, default="cuda")
    arg("--dtype", type=str, default="bfloat16", choices=["float32", "float16", "bfloat16"])
    arg("--num_workers", type=int, default=4)

    return parser.parse_args()
|
|
|
|
|
|
|
|
def setup_models(args):
    """Instantiate all model components for training.

    The text encoder, UNet and VAE are loaded as frozen/pretrained pieces;
    only the freshly created adapter is meant to be trained.

    Returns:
        Tuple of (text_encoder, adapter, unet, vae, noise_scheduler, dtype).
    """
    print("🚀 设置模型组件...")

    # Resolve the CLI dtype string to the actual torch dtype.
    torch_dtype = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }[args.dtype]

    print("📝 加载 Qwen 文本编码器...")
    # Encoder weights stay frozen; only the adapter below is optimized.
    text_encoder = QwenTextEncoder(
        model_path=args.qwen_model_path,
        device=args.device,
        freeze_encoder=True,
    )

    print("🔧 初始化适配器...")
    adapter = QwenEmbeddingAdapter()
    adapter.to(args.device, torch_dtype)

    print("🏗️ 加载 UNet...")
    unet = load_unet_from_safetensors(args.unet_path, args.unet_config_path, args.device, torch_dtype)

    print("🎨 加载 VAE...")
    vae = load_vae_from_safetensors(args.vae_path, args.vae_config_path, args.device, torch_dtype)

    print("⏰ 创建调度器...")
    noise_scheduler = create_scheduler("DDPM")

    return text_encoder, adapter, unet, vae, noise_scheduler, torch_dtype
|
|
|
|
|
|
|
|
def setup_data(args):
    """Build the training DataLoader from the configured dataset flavor.

    Returns:
        A DataLoader produced by ``create_dataloader`` over either the
        multi-aspect or the plain image/caption dataset.
    """
    print("📚 设置数据加载器...")

    # Multi-aspect training uses a dataset that buckets images by aspect
    # ratio; otherwise fall back to the straightforward image/caption pairs.
    dataset_cls = MultiAspectDataset if args.use_multi_aspect else ImageCaptionDataset
    dataset = dataset_cls(
        data_root=args.data_root,
        annotations_file=args.annotations_file,
        caption_column=args.caption_column,
        image_column=args.image_column,
    )

    return create_dataloader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )
|
|
|
|
|
|
|
|
def setup_training(args, adapter, noise_scheduler):
    """Create the loss function, optimizer, and optional EMA tracker.

    Only the adapter's parameters are optimized; all other model
    components are treated as frozen.

    Returns:
        Tuple of (loss_fn, optimizer, ema) where ema is None unless
        ``--use_ema`` was passed.
    """
    print("🎯 设置训练组件...")

    loss_fn = DiffusionLoss(
        noise_scheduler=noise_scheduler,
        loss_type=args.loss_type,
        snr_gamma=args.snr_gamma,
        use_v_parameterization=args.use_v_parameterization,
    )

    # Both optimizer variants share lr/weight_decay; AdamW also pins betas.
    common_kwargs = dict(lr=args.learning_rate, weight_decay=args.weight_decay)
    if args.optimizer == "adamw":
        optimizer = optim.AdamW(adapter.parameters(), betas=(0.9, 0.999), **common_kwargs)
    else:
        optimizer = optim.Adam(adapter.parameters(), **common_kwargs)

    ema = EMAModel(adapter, decay=args.ema_decay) if args.use_ema else None

    return loss_fn, optimizer, ema
|
|
|
|
|
|
|
|
def train_step(training_step_fn, batch, optimizer, args, ema=None):
    """Run one micro-batch through the model and apply gradient accumulation.

    The loss is scaled by ``1 / args.gradient_accumulation_steps`` and
    back-propagated; the optimizer (plus gradient clipping and EMA update)
    is applied only once every ``gradient_accumulation_steps`` calls.

    Args:
        training_step_fn: Object exposing ``training_step(images, captions)``
            (returning a dict with a "loss" tensor) and an ``adapter`` module.
        batch: Dict with "images" (a tensor, or a list of per-sample tensors
            for multi-aspect batches) and "captions".
        optimizer: Optimizer over the adapter parameters.
        args: Namespace with gradient_accumulation_steps and max_grad_norm.
        ema: Optional EMA tracker, updated after each optimizer step.

    Returns:
        The (accumulation-scaled) average loss of this micro-batch as a float.
    """
    if isinstance(batch["images"], list):
        # Multi-aspect batches arrive as a list of differently-shaped
        # tensors, so each sample is forwarded and back-propagated alone.
        total_loss = 0.0
        num_samples = 0

        for i in range(len(batch["images"])):
            images = batch["images"][i].unsqueeze(0)
            captions = [batch["captions"][i]]

            step_output = training_step_fn.training_step(images, captions)
            loss = step_output["loss"] / args.gradient_accumulation_steps

            loss.backward()
            total_loss += loss.item()
            num_samples += 1

        avg_loss = total_loss / num_samples if num_samples > 0 else 0
    else:
        images = batch["images"]
        captions = batch["captions"]

        step_output = training_step_fn.training_step(images, captions)
        loss = step_output["loss"] / args.gradient_accumulation_steps

        loss.backward()
        avg_loss = loss.item()

    # Bug fix: previously the optimizer stepped (and gradients were zeroed)
    # after EVERY micro-batch, so gradient_accumulation_steps > 1 never
    # accumulated anything — it only scaled the loss, silently dividing the
    # effective learning rate by N. Track how many micro-batches have been
    # seen on the training-step object and only step once every N calls.
    # With the default N == 1 the behavior is unchanged.
    micro_batches = getattr(training_step_fn, "_accum_counter", 0) + 1
    training_step_fn._accum_counter = micro_batches

    if micro_batches % args.gradient_accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(training_step_fn.adapter.parameters(), args.max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()

        # EMA mirrors optimizer updates, so it only advances on real steps.
        if ema is not None:
            ema.update()

    return avg_loss
|
|
|
|
|
|
|
|
def save_checkpoint(adapter, optimizer, ema, epoch, step, args):
    """Persist the adapter weights, an optional EMA copy, and training state.

    Files are written under ``args.output_dir``: the adapter as safetensors
    (via its own ``save_adapter`` when available), the EMA shadow weights
    if EMA is enabled, and a ``.pt`` file with optimizer state for resuming.
    """
    os.makedirs(args.output_dir, exist_ok=True)

    def _dump_safetensors(module, path):
        # Imported lazily so the safetensors package is only required
        # when we actually serialize a raw state dict.
        import safetensors.torch
        safetensors.torch.save_file(module.state_dict(), path)

    suffix = f"epoch_{epoch}_step_{step}"

    # Adapter weights: prefer the adapter's own saving hook when it has one.
    adapter_path = os.path.join(args.output_dir, f"adapter_{suffix}.safetensors")
    if hasattr(adapter, 'save_adapter'):
        adapter.save_adapter(adapter_path)
    else:
        _dump_safetensors(adapter, adapter_path)

    # EMA weights: temporarily swap in the shadow params, save, then restore.
    if ema is not None:
        ema.apply_shadow()
        _dump_safetensors(adapter, os.path.join(args.output_dir, f"adapter_ema_{suffix}.safetensors"))
        ema.restore()

    # Optimizer/step state so training can be resumed later.
    torch.save(
        {
            "epoch": epoch,
            "step": step,
            "optimizer_state_dict": optimizer.state_dict(),
            "args": args,
        },
        os.path.join(args.output_dir, f"training_state_{suffix}.pt"),
    )

    print(f"💾 检查点已保存: epoch {epoch}, step {step}")
|
|
|
|
|
|
|
|
def main():
    """Entry point: wire up models, data, and optimization, then train.

    Orchestrates argument parsing, optional wandb tracking, component
    construction, the epoch/step loop with periodic logging and
    checkpointing, and a final checkpoint save.
    """
    args = parse_args()

    # Optional experiment tracking.
    if args.use_wandb:
        wandb.init(
            project=args.wandb_project,
            name=args.wandb_run_name,
            config=vars(args),
        )

    text_encoder, adapter, unet, vae, noise_scheduler, dtype = setup_models(args)
    dataloader = setup_data(args)
    loss_fn, optimizer, ema = setup_training(args, adapter, noise_scheduler)

    # Bundles the frozen models + trainable adapter into a single callable step.
    training_step_fn = AdapterTrainingStep(
        unet=unet,
        vae=vae,
        text_encoder=text_encoder,
        adapter=adapter,
        noise_scheduler=noise_scheduler,
        loss_fn=loss_fn,
        device=args.device,
        dtype=dtype,
    )

    # Bug fix: lr_scheduler.step() is called once per batch in the loop
    # below, so the schedule must span every batch. The previous computation
    # divided by gradient_accumulation_steps, which made the warmup/cosine
    # schedule finish early (LR decayed to ~0 before the end of training)
    # whenever accumulation > 1.
    total_steps = len(dataloader) * args.num_epochs
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=total_steps,
    )

    print(f"🎓 开始训练: {args.num_epochs} epochs, {len(dataloader)} steps/epoch")
    print(f"📊 总训练步数: {total_steps}")

    global_step = 0

    for epoch in range(args.num_epochs):
        adapter.train()
        epoch_loss = 0

        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{args.num_epochs}")

        for step, batch in enumerate(progress_bar):
            step_loss = train_step(training_step_fn, batch, optimizer, args, ema)
            epoch_loss += step_loss

            lr_scheduler.step()
            global_step += 1

            # Periodic logging to the progress bar (and wandb when enabled).
            if global_step % args.logging_steps == 0:
                avg_loss = epoch_loss / (step + 1)
                current_lr = lr_scheduler.get_last_lr()[0]

                progress_bar.set_postfix({
                    "loss": f"{step_loss:.4f}",
                    "avg_loss": f"{avg_loss:.4f}",
                    "lr": f"{current_lr:.2e}",
                })

                if args.use_wandb:
                    wandb.log({
                        "train/loss": step_loss,
                        "train/avg_loss": avg_loss,
                        "train/learning_rate": current_lr,
                        "train/epoch": epoch,
                        "train/step": global_step,
                    })

            # Periodic mid-epoch checkpointing.
            if global_step % args.save_steps == 0:
                save_checkpoint(adapter, optimizer, ema, epoch, global_step, args)

        avg_epoch_loss = epoch_loss / len(dataloader)
        print(f"📈 Epoch {epoch+1} 完成,平均损失: {avg_epoch_loss:.4f}")

    # Final checkpoint after the last epoch.
    save_checkpoint(adapter, optimizer, ema, epoch+1, global_step, args)

    print("🎉 训练完成!")

    if args.use_wandb:
        wandb.finish()
|
|
|
|
|
|
|
|
# Run training only when executed as a script (not on import).
if __name__ == "__main__":
    main()
|
|
|