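"""Utilities for injecting LoRA (Low-Rank Adaptation) adapters into Cosmos
diffusion models and for restricting training to the LoRA parameters."""
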
from cosmos_predict1.diffusion.training.utils.peft.lora_attn import build_attn_lora
from cosmos_predict1.diffusion.utils.customization.customization_manager import CustomizationType
from cosmos_predict1.utils import log
from cosmos_predict1.utils.misc import count_params


def get_all_lora_params(model):
    """Get all LoRA weight parameters in the model."""
    # Each LoRA adapter is registered as a two-layer net, so its projection
    # modules show up in named_modules() as "lora.net.0" and "lora.net.1".
    lora_modules = [mod for name, mod in model.named_modules() if "lora.net.0" in name or "lora.net.1" in name]
    lora_params = [(name, param) for mod in lora_modules for name, param in mod.named_parameters()]
    log.info(f"Found {len(lora_params)} LoRA weight matrices")
    return lora_params


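# Usage sketch (illustrative, not part of the training pipeline): collect the
# LoRA parameters and hand only those to an optimizer. `model` is assumed to
# already contain LoRA layers (see add_lora_layers below); the optimizer
# choice and learning rate are placeholders.
#
#     lora_params = get_all_lora_params(model)
#     optimizer = torch.optim.AdamW([p for _, p in lora_params], lr=1e-4)

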
def setup_lora_requires_grad(model):
    """Freeze all model parameters except LoRA parameters.

    Returns the number of LoRA parameters left trainable.
    """
    num_param = count_params(model, verbose=True)
    log.critical(f"Model has {num_param * 1e-6:.2f}M parameters before freezing")
    lora_params = get_all_lora_params(model)
    num_lora_param = sum(p.numel() for _, p in lora_params)
    log.info(f"Total number of LoRA parameters: {num_lora_param * 1e-6:.2f}M")
    if num_lora_param > 0:
        log.info("Freezing all parameters")
        model.requires_grad_(False)
        log.info("Unfreezing LoRA parameters")
        for _, param in lora_params:
            param.requires_grad_(True)
    num_param = count_params(model, verbose=True)
    log.critical(f"Model has {num_param * 1e-6:.2f}M parameters after freezing")
    return num_lora_param


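# Usage sketch (illustrative): freeze everything except the LoRA weights, then
# verify that only LoRA parameters remain trainable. The variable names and
# the assertion are assumptions, not pipeline code.
#
#     num_trainable = setup_lora_requires_grad(model)
#     trainable = [p for p in model.parameters() if p.requires_grad]
#     assert sum(p.numel() for p in trainable) == num_trainable

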
def add_lora_layers(model, peft_control_config):
    """Attach LoRA adapters to the attention subblocks selected by peft_control_config."""
    for i, block_name in enumerate(model.net.blocks):
        block = model.net.blocks[block_name]
        # Per-block PEFT control, keyed by the block index.
        peft_control = peft_control_config.get(i, {})
        for subblock in block.blocks:
            block_type = subblock.block_type
            peft_control_subblock = peft_control.get(block_type.upper(), {})
            customization_type = peft_control_subblock.get("customization_type", None)
            if customization_type == CustomizationType.LORA:
                # LoRA is only built for attention subblocks: cross-attention
                # ("CA") and full self-attention ("FA").
                if block_type.upper() in ["CA", "FA"]:
                    build_attn_lora(subblock.block.attn, peft_control_subblock)
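

# Usage sketch (illustrative): peft_control_config maps block indices to
# per-subblock-type settings. The "rank" key shown here is an assumed LoRA
# hyperparameter; the exact schema is defined by build_attn_lora.
#
#     peft_control_config = {
#         0: {"CA": {"customization_type": CustomizationType.LORA, "rank": 8}},
#         1: {"FA": {"customization_type": CustomizationType.LORA, "rank": 8}},
#     }
#     add_lora_layers(model, peft_control_config)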