# qwenillustrious/arch/example_train.py
# Uploaded by lsmpp with the upload-large-folder tool (commit 4960ef6, verified).
"""
Example Training Script using Arch Components
使用架构组件的示例训练脚本
"""
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import os
import argparse
from tqdm import tqdm
import wandb
from typing import Optional
# Import arch components
from arch import (
QwenTextEncoder,
QwenEmbeddingAdapter,
load_unet_from_safetensors,
load_vae_from_safetensors,
create_scheduler,
DiffusionLoss,
AdapterTrainingStep,
get_cosine_schedule_with_warmup,
EMAModel,
ImageCaptionDataset,
MultiAspectDataset,
create_dataloader
)
def parse_args():
    """Declare the training command-line options and parse ``sys.argv``.

    Returns:
        argparse.Namespace: one attribute per option listed in the table
        below. ``--data_root`` and ``--annotations_file`` are mandatory;
        every other option falls back to its default.
    """
    cli = argparse.ArgumentParser(description="Train Qwen-SDXL Adapter")

    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_table = (
        # Model paths
        ("--qwen_model_path", {"type": str, "default": "models/Qwen3-Embedding-0.6B"}),
        ("--unet_path", {"type": str, "default": "models/extracted_components/waiNSFWIllustrious_v140_unet.safetensors"}),
        ("--unet_config_path", {"type": str, "default": "models/extracted_components/waiNSFWIllustrious_v140_unet_config.json"}),
        ("--vae_path", {"type": str, "default": "models/extracted_components/waiNSFWIllustrious_v140_vae.safetensors"}),
        ("--vae_config_path", {"type": str, "default": "models/extracted_components/waiNSFWIllustrious_v140_vae_config.json"}),
        # Data
        ("--data_root", {"type": str, "required": True, "help": "Root directory of training images"}),
        ("--annotations_file", {"type": str, "required": True, "help": "Path to annotations file (JSON/JSONL)"}),
        ("--caption_column", {"type": str, "default": "caption"}),
        ("--image_column", {"type": str, "default": "image"}),
        ("--use_multi_aspect", {"action": "store_true", "help": "Use multi-aspect ratio dataset"}),
        # Training
        ("--batch_size", {"type": int, "default": 4}),
        ("--learning_rate", {"type": float, "default": 1e-4}),
        ("--num_epochs", {"type": int, "default": 10}),
        ("--warmup_steps", {"type": int, "default": 500}),
        ("--gradient_accumulation_steps", {"type": int, "default": 1}),
        ("--max_grad_norm", {"type": float, "default": 1.0}),
        # Loss
        ("--loss_type", {"type": str, "default": "mse", "choices": ["mse", "l1", "huber"]}),
        ("--snr_gamma", {"type": float, "default": None, "help": "SNR gamma for loss weighting"}),
        ("--use_v_parameterization", {"action": "store_true"}),
        # Optimization
        ("--optimizer", {"type": str, "default": "adamw", "choices": ["adamw", "adam"]}),
        ("--weight_decay", {"type": float, "default": 0.01}),
        ("--use_ema", {"action": "store_true", "help": "Use EMA for adapter"}),
        ("--ema_decay", {"type": float, "default": 0.9999}),
        # Checkpointing
        ("--output_dir", {"type": str, "default": "./checkpoints"}),
        ("--save_steps", {"type": int, "default": 1000}),
        ("--resume_from_checkpoint", {"type": str, "default": None}),
        # Logging
        ("--logging_steps", {"type": int, "default": 50}),
        ("--use_wandb", {"action": "store_true"}),
        ("--wandb_project", {"type": str, "default": "qwen-sdxl-training"}),
        ("--wandb_run_name", {"type": str, "default": None}),
        # Hardware
        ("--device", {"type": str, "default": "cuda"}),
        ("--dtype", {"type": str, "default": "bfloat16", "choices": ["float32", "float16", "bfloat16"]}),
        ("--num_workers", {"type": int, "default": 4}),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
def setup_models(args):
    """Instantiate every model component used during adapter training.

    Args:
        args: parsed command-line namespace; reads ``dtype``, ``device``,
            ``qwen_model_path``, ``unet_path``, ``unet_config_path``,
            ``vae_path`` and ``vae_config_path``.

    Returns:
        tuple: ``(text_encoder, adapter, unet, vae, noise_scheduler, dtype)``
        where ``dtype`` is the ``torch.dtype`` matching ``args.dtype``.
    """
    print("🚀 设置模型组件...")

    # Translate the CLI dtype name into the corresponding torch dtype.
    torch_dtype = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }[args.dtype]
    target_device = args.device

    # Text encoder, constructed with freeze_encoder=True.
    print("📝 加载 Qwen 文本编码器...")
    encoder = QwenTextEncoder(
        model_path=args.qwen_model_path,
        device=target_device,
        freeze_encoder=True,
    )

    # The adapter is the module whose parameters are handed to the optimizer.
    print("🔧 初始化适配器...")
    embedding_adapter = QwenEmbeddingAdapter()
    embedding_adapter.to(target_device, torch_dtype)

    # UNet weights and config come from the extracted safetensors files.
    print("🏗️ 加载 UNet...")
    denoiser = load_unet_from_safetensors(
        args.unet_path, args.unet_config_path, target_device, torch_dtype
    )

    # VAE is loaded the same way as the UNet.
    print("🎨 加载 VAE...")
    autoencoder = load_vae_from_safetensors(
        args.vae_path, args.vae_config_path, target_device, torch_dtype
    )

    # Noise scheduler requested by name.
    print("⏰ 创建调度器...")
    scheduler = create_scheduler("DDPM")

    return encoder, embedding_adapter, denoiser, autoencoder, scheduler, torch_dtype
def setup_data(args):
    """Build the training dataloader from the command-line options.

    Args:
        args: parsed command-line namespace; reads ``use_multi_aspect``,
            ``data_root``, ``annotations_file``, ``caption_column``,
            ``image_column``, ``batch_size`` and ``num_workers``.

    Returns:
        The object returned by ``create_dataloader`` (shuffled, pinned
        memory, incomplete final batch dropped).
    """
    print("📚 设置数据加载器...")

    # Both dataset classes are constructed with the same keyword arguments;
    # only the class itself depends on --use_multi_aspect.
    dataset_cls = MultiAspectDataset if args.use_multi_aspect else ImageCaptionDataset
    training_set = dataset_cls(
        data_root=args.data_root,
        annotations_file=args.annotations_file,
        caption_column=args.caption_column,
        image_column=args.image_column,
    )

    return create_dataloader(
        training_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )
def setup_training(args, adapter, noise_scheduler):
    """Create the loss, the optimizer and the optional EMA tracker.

    Args:
        args: parsed command-line namespace; reads ``loss_type``,
            ``snr_gamma``, ``use_v_parameterization``, ``optimizer``,
            ``learning_rate``, ``weight_decay``, ``use_ema`` and
            ``ema_decay``.
        adapter: module whose parameters are optimized.
        noise_scheduler: scheduler handed to the loss.

    Returns:
        tuple: ``(loss_fn, optimizer, ema)``; ``ema`` is ``None`` unless
        ``--use_ema`` was given.
    """
    print("🎯 设置训练组件...")

    criterion = DiffusionLoss(
        noise_scheduler=noise_scheduler,
        loss_type=args.loss_type,
        snr_gamma=args.snr_gamma,
        use_v_parameterization=args.use_v_parameterization,
    )

    # Only the adapter's parameters are passed to the optimizer.
    use_adamw = args.optimizer == "adamw"
    if use_adamw:
        opt = optim.AdamW(
            adapter.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
            betas=(0.9, 0.999),
        )
    else:
        opt = optim.Adam(
            adapter.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )

    # The EMA tracker is optional.
    ema_tracker = EMAModel(adapter, decay=args.ema_decay) if args.use_ema else None

    return criterion, opt, ema_tracker
def train_step(training_step_fn, batch, optimizer, args, ema=None):
    """Run forward/backward on one batch and step the optimizer when due.

    Gradients are accumulated over ``args.gradient_accumulation_steps``
    consecutive calls. The loss is scaled by ``1 / accumulation_steps``
    before ``backward()``, and the optimizer (plus gradient clipping,
    ``zero_grad`` and the EMA update) only runs on the call that completes
    a group. The number of micro-batches seen since the last optimizer step
    is kept on the optimizer object as ``_pending_micro_batches``, so the
    caller does not need to track it. Gradients of an incomplete trailing
    group are never applied.

    Args:
        training_step_fn: object exposing ``training_step(images, captions)``
            returning a dict with a scalar ``"loss"`` tensor, and an
            ``adapter`` attribute whose parameters are clipped.
        batch: dict with ``"images"`` and ``"captions"``. ``"images"`` is
            either one batched tensor or a list of per-sample tensors
            (multi-size batch); in the list case each sample is run on its
            own with a leading batch dimension added.
        optimizer: optimizer to step.
        args: namespace providing ``gradient_accumulation_steps`` and
            ``max_grad_norm``.
        ema: optional object with an ``update()`` method, called after each
            optimizer step.

    Returns:
        float: the unscaled loss of this batch (mean over samples for a
        multi-size batch), so logged values do not depend on the
        accumulation setting. ``0.0`` for an empty multi-size batch.
    """
    accum_steps = max(1, int(args.gradient_accumulation_steps))
    images = batch["images"]
    captions = batch["captions"]

    if isinstance(images, list):
        # Multi-size batch: samples cannot be stacked, so run them one by one.
        num_samples = len(images)
        if num_samples == 0:
            # Nothing to learn from; leave optimizer and counters untouched.
            return 0.0
        loss_sum = 0.0
        for index, image in enumerate(images):
            step_output = training_step_fn.training_step(
                image.unsqueeze(0), [captions[index]]
            )
            sample_loss = step_output["loss"]
            # Divide by the sample count as well so the accumulated gradient
            # is the batch mean, not the batch sum.
            (sample_loss / (accum_steps * num_samples)).backward()
            loss_sum += sample_loss.item()
        batch_loss = loss_sum / num_samples
    else:
        # Regular batch: a single forward/backward pass.
        step_output = training_step_fn.training_step(images, captions)
        loss = step_output["loss"]
        (loss / accum_steps).backward()
        batch_loss = loss.item()

    # Only update the weights once a full accumulation group has been seen.
    pending = getattr(optimizer, "_pending_micro_batches", 0) + 1
    if pending < accum_steps:
        optimizer._pending_micro_batches = pending
        return batch_loss
    optimizer._pending_micro_batches = 0

    # Gradient clipping and optimization step
    torch.nn.utils.clip_grad_norm_(
        training_step_fn.adapter.parameters(), args.max_grad_norm
    )
    optimizer.step()
    optimizer.zero_grad()

    # Update EMA
    if ema is not None:
        ema.update()

    return batch_loss
def _write_adapter_weights(adapter, path):
    """Write the adapter's current weights to ``path``.

    Uses the adapter's own ``save_adapter`` method when it has one and
    otherwise falls back to dumping ``state_dict()`` with safetensors.
    """
    if hasattr(adapter, 'save_adapter'):
        adapter.save_adapter(path)
    else:
        import safetensors.torch
        safetensors.torch.save_file(adapter.state_dict(), path)


def save_checkpoint(adapter, optimizer, ema, epoch, step, args):
    """Save adapter weights, optional EMA weights and the optimizer state.

    Three files are written into ``args.output_dir`` (created if missing),
    all suffixed with ``epoch_{epoch}_step_{step}``:

    * ``adapter_*.safetensors`` - the live adapter weights,
    * ``adapter_ema_*.safetensors`` - the EMA weights, only when ``ema`` is
      given,
    * ``training_state_*.pt`` - epoch, step, optimizer state and ``args``.

    Args:
        adapter: module being trained.
        optimizer: optimizer whose ``state_dict()`` is stored.
        ema: optional object with ``apply_shadow()`` / ``restore()``; may be
            ``None``.
        epoch: epoch number used in file names and stored in the state file.
        step: step number used in file names and stored in the state file.
        args: namespace providing ``output_dir``; stored as-is in the state
            file.
    """
    os.makedirs(args.output_dir, exist_ok=True)

    # Save adapter
    adapter_path = os.path.join(args.output_dir, f"adapter_epoch_{epoch}_step_{step}.safetensors")
    _write_adapter_weights(adapter, adapter_path)

    # Save EMA adapter if available
    if ema is not None:
        ema_path = os.path.join(args.output_dir, f"adapter_ema_epoch_{epoch}_step_{step}.safetensors")
        ema.apply_shadow()
        try:
            # Same writer as the live weights so both files share one format.
            _write_adapter_weights(adapter, ema_path)
        finally:
            # Always put the live weights back, even if the write failed;
            # otherwise training would continue on the EMA weights.
            ema.restore()

    # Save training state
    # NOTE(review): ``args`` is pickled as a Namespace object, so reading this
    # file back with torch.load presumably needs weights_only=False - confirm
    # before implementing resume.
    state_path = os.path.join(args.output_dir, f"training_state_epoch_{epoch}_step_{step}.pt")
    torch.save({
        "epoch": epoch,
        "step": step,
        "optimizer_state_dict": optimizer.state_dict(),
        "args": args
    }, state_path)

    print(f"💾 检查点已保存: epoch {epoch}, step {step}")
def main():
    """Entry point: parse options, build all components and run training.

    The loop iterates over the dataloader for ``args.num_epochs`` epochs,
    calling ``train_step`` once per batch. Metrics are logged every
    ``args.logging_steps`` batches, a checkpoint is written every
    ``args.save_steps`` batches, and another one at the end of each epoch.

    The learning-rate schedule is sized in accumulation groups
    (``batches // gradient_accumulation_steps``), so the scheduler is
    advanced once per group rather than once per batch.

    NOTE: ``--resume_from_checkpoint`` is accepted on the command line but
    is not acted on here; training always starts from scratch.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    args = parse_args()

    # Setup wandb
    if args.use_wandb:
        wandb.init(
            project=args.wandb_project,
            name=args.wandb_run_name,
            config=vars(args)
        )

    # Setup models
    text_encoder, adapter, unet, vae, noise_scheduler, dtype = setup_models(args)

    # Setup data
    dataloader = setup_data(args)
    batches_per_epoch = len(dataloader)
    if batches_per_epoch == 0:
        # Fail early with a clear message instead of dividing by zero when
        # the epoch average is computed.
        raise ValueError(
            "The dataloader produced no batches; check the dataset size "
            "against --batch_size (incomplete batches are dropped)."
        )

    # Setup training
    loss_fn, optimizer, ema = setup_training(args, adapter, noise_scheduler)

    # Create training step function
    training_step_fn = AdapterTrainingStep(
        unet=unet,
        vae=vae,
        text_encoder=text_encoder,
        adapter=adapter,
        noise_scheduler=noise_scheduler,
        loss_fn=loss_fn,
        device=args.device,
        dtype=dtype
    )

    # Setup learning rate scheduler
    accum_steps = max(1, args.gradient_accumulation_steps)
    total_steps = batches_per_epoch * args.num_epochs // accum_steps
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=total_steps
    )

    print(f"🎓 开始训练: {args.num_epochs} epochs, {batches_per_epoch} steps/epoch")
    print(f"📊 总训练步数: {total_steps}")

    # Training loop
    global_step = 0  # counts batches across all epochs
    for epoch in range(args.num_epochs):
        adapter.train()
        epoch_loss = 0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{args.num_epochs}")
        for step, batch in enumerate(progress_bar):
            step_loss = train_step(training_step_fn, batch, optimizer, args, ema)
            epoch_loss += step_loss
            global_step += 1

            # Update learning rate. total_steps is expressed in accumulation
            # groups, so the schedule must advance once per group; stepping
            # it on every batch would use up the schedule after only
            # 1/accum_steps of the run.
            if global_step % accum_steps == 0:
                lr_scheduler.step()

            # Logging
            if global_step % args.logging_steps == 0:
                avg_loss = epoch_loss / (step + 1)
                current_lr = lr_scheduler.get_last_lr()[0]
                progress_bar.set_postfix({
                    "loss": f"{step_loss:.4f}",
                    "avg_loss": f"{avg_loss:.4f}",
                    "lr": f"{current_lr:.2e}"
                })
                if args.use_wandb:
                    wandb.log({
                        "train/loss": step_loss,
                        "train/avg_loss": avg_loss,
                        "train/learning_rate": current_lr,
                        "train/epoch": epoch,
                        "train/step": global_step
                    })

            # Save checkpoint (labelled with the number of completed epochs)
            if global_step % args.save_steps == 0:
                save_checkpoint(adapter, optimizer, ema, epoch, global_step, args)

        # End of epoch
        avg_epoch_loss = epoch_loss / batches_per_epoch
        print(f"📈 Epoch {epoch+1} 完成,平均损失: {avg_epoch_loss:.4f}")

        # Save epoch checkpoint
        save_checkpoint(adapter, optimizer, ema, epoch+1, global_step, args)

    print("🎉 训练完成!")
    if args.use_wandb:
        wandb.finish()
# Run training only when executed as a script, not when imported.
if "__main__" == __name__:
    main()