import torch
from einops import rearrange
from torch.utils.checkpoint import checkpoint
from transformer_engine.pytorch.attention import apply_rotary_pos_emb

from cosmos_predict1.diffusion.module.attention import Attention
from cosmos_predict1.diffusion.training.utils.peft.lora_net import LoRALinearLayer, TELoRALinearLayer
from cosmos_predict1.diffusion.utils.customization.customization_manager import CustomizationType

try:
    from megatron.core import parallel_state

    USE_MEGATRON = True
except ImportError:
    USE_MEGATRON = False


def enable_attn_lora(attn: Attention, peft_control: dict) -> None:
    """
    Enable LoRA for the attention block based on the peft_control dictionary.

    Args:
        attn (Attention): The attention block to configure.
        peft_control (dict): Dictionary containing PEFT configuration.
    """
    attn.peft_lora_enabled = False
    if peft_control:
        try:
            if peft_control["customization_type"] == CustomizationType.LORA:
                attn.peft_lora_enabled = True
            else:
                raise Exception(f"Unsupported customization type {peft_control['customization_type']}")
        except KeyError as e:
            raise KeyError(f"peft_control dictionary expected to have key {e.args[0]}.")


def configure_attn_lora(attn: Attention, peft_control: dict) -> None:
    """
    Configure LoRA for the attention block based on the peft_control dictionary.

    Args:
        attn (Attention): The attention block to configure.
        peft_control (dict): Dictionary containing PEFT configuration.
    """
    try:
        attn.q_lora_enabled = peft_control.get("to_q", {}).get("activate", False)
        attn.k_lora_enabled = peft_control.get("to_k", {}).get("activate", False)
        attn.v_lora_enabled = peft_control.get("to_v", {}).get("activate", False)
        attn.out_lora_enabled = peft_control.get("to_out", {}).get("activate", False)
        if attn.q_lora_enabled:
            attn.q_lora_rank = peft_control["to_q"]["lora_rank"]
            attn.q_lora_scale = float(peft_control["to_q"]["lora_scale"])
        if attn.k_lora_enabled:
            attn.k_lora_rank = peft_control["to_k"]["lora_rank"]
            attn.k_lora_scale = float(peft_control["to_k"]["lora_scale"])
        if attn.v_lora_enabled:
            attn.v_lora_rank = peft_control["to_v"]["lora_rank"]
            attn.v_lora_scale = float(peft_control["to_v"]["lora_scale"])
        if attn.out_lora_enabled:
            attn.out_lora_rank = peft_control["to_out"]["lora_rank"]
            attn.out_lora_scale = float(peft_control["to_out"]["lora_scale"])
    except KeyError as e:
        raise KeyError(f"All activated layers (to_q, to_k, to_v, to_out) must specify key {e.args[0]}.")
    except ValueError as e:
        raise ValueError(f"Could not convert lora_scale to float: {e}")


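# Note: the helpers below (cal_qkv_lora / cal_attn_lora) apply the standard LoRA update to each
# activated projection. For a frozen base projection W and input x, the adjusted output is
#
#     y = W(x) + lora_scale * B(A(x))
#
# where A projects down to `lora_rank` dimensions and B projects back up. This note assumes
# LoRALinearLayer / TELoRALinearLayer implement that usual low-rank A/B pair; see lora_net.py
# for the actual parametrization.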
def cal_qkv_lora(
    self,
    x: torch.Tensor,
    context: torch.Tensor = None,
    mask: torch.Tensor = None,
    rope_emb: torch.Tensor = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Calculate the Q, K, V matrices with LoRA adjustments. Derived from
    cosmos_predict1/diffusion/module/attention.py cal_qkv.

    Args:
        x (torch.Tensor): Input tensor.
        context (torch.Tensor, optional): Context tensor.
        mask (torch.Tensor, optional): Mask tensor.
        rope_emb (torch.Tensor, optional): Rotary positional embedding.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The Q, K, V matrices.
    """
    del kwargs

    q = self.to_q[0](x)
    context = x if context is None else context
    k = self.to_k[0](context)
    v = self.to_v[0](context)

    if self.peft_lora_enabled:
        try:
            if self.q_lora_enabled:
                q_lora = self.to_q_lora(x)
                q = q + self.q_lora_scale * q_lora
            if self.k_lora_enabled:
                k_lora = self.to_k_lora(context)
                k = k + self.k_lora_scale * k_lora
            if self.v_lora_enabled:
                v_lora = self.to_v_lora(context)
                v = v + self.v_lora_scale * v_lora
        except AttributeError as e:
            raise AttributeError(f"LoRA enabled, but the Attention block is missing a required attribute: {e}")

    q, k, v = map(
        lambda t: rearrange(t, "b ... (n c) -> b ... n c", n=self.heads // self.tp_size, c=self.dim_head),
        (q, k, v),
    )

    def apply_norm_and_rotary_pos_emb(q, k, v, rope_emb):
        q = self.to_q[1](q)
        k = self.to_k[1](k)
        v = self.to_v[1](v)
        if self.is_selfattn and rope_emb is not None:
            q = apply_rotary_pos_emb(q, rope_emb, tensor_format=self.qkv_format, fused=True)
            k = apply_rotary_pos_emb(k, rope_emb, tensor_format=self.qkv_format, fused=True)
        return q, k, v

    q, k, v = checkpoint(apply_norm_and_rotary_pos_emb, q, k, v, rope_emb, use_reentrant=False)

    return q, k, v


def cal_attn_lora(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
    """
    Calculate the attention output with LoRA adjustments. Derived from
    cosmos_predict1/diffusion/module/attention.py cal_attn.

    Args:
        q (torch.Tensor): Query tensor.
        k (torch.Tensor): Key tensor.
        v (torch.Tensor): Value tensor.
        mask (torch.Tensor, optional): Mask tensor.

    Returns:
        torch.Tensor: The attention output.
    """
    if self.backend == "transformer_engine":
        seq_dim = self.qkv_format.index("s")
        assert (
            q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1
        ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version."
        attn_out = self.attn_op(q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None)
        out = self.to_out(attn_out)

        if self.peft_lora_enabled and self.out_lora_enabled:
            try:
                out_lora = self.to_out_lora(attn_out)
                out = out + self.out_lora_scale * out_lora
            except AttributeError as e:
                raise AttributeError(f"Output LoRA enabled, but the Attention block is missing a required attribute: {e}")

        return out
    elif self.backend == "torch":
        attn_out = self.attn_op(q, k, v, mask=mask)
        attn_out = rearrange(attn_out, "b ... n c -> b ... (n c)")
        out = self.to_out(attn_out)

        if self.peft_lora_enabled and self.out_lora_enabled:
            try:
                out_lora = self.to_out_lora(attn_out)
                out = out + self.out_lora_scale * out_lora
            except AttributeError as e:
                raise AttributeError(f"Output LoRA enabled, but the Attention block is missing a required attribute: {e}")

        return out
    else:
        raise ValueError(f"Backend {self.backend} not found")


def build_attn_lora(attn: Attention, peft_control: dict) -> None:
    """
    Configure, build and add LoRA layers to the attention block.

    Args:
        attn (Attention): The attention block to add LoRA layers to.
        peft_control (dict): Dictionary containing PEFT configuration.
    """
    enable_attn_lora(attn, peft_control)
    configure_attn_lora(attn, peft_control)
    if attn.peft_lora_enabled:
        query_dim = attn.query_dim
        inner_dim = attn.inner_dim
        context_dim = attn.context_dim
        tp_group = parallel_state.get_tensor_model_parallel_group(check_initialized=False) if USE_MEGATRON else None

        if attn.tp_size == 1:
            if attn.q_lora_enabled:
                attn.to_q_lora = LoRALinearLayer(query_dim, inner_dim, rank=attn.q_lora_rank, linear=True)
            if attn.k_lora_enabled:
                attn.to_k_lora = LoRALinearLayer(context_dim, inner_dim, rank=attn.k_lora_rank, linear=True)
            if attn.v_lora_enabled:
                attn.to_v_lora = LoRALinearLayer(context_dim, inner_dim, rank=attn.v_lora_rank, linear=True)
            if attn.out_lora_enabled:
                attn.to_out_lora = LoRALinearLayer(inner_dim, query_dim, rank=attn.out_lora_rank, linear=True)
        else:
            sequence_parallel = getattr(parallel_state, "sequence_parallel", False)
            if attn.q_lora_enabled:
                attn.to_q_lora = TELoRALinearLayer(
                    query_dim,
                    inner_dim,
                    rank=attn.q_lora_rank,
                    linear=True,
                    tp_size=attn.tp_size,
                    tp_group=tp_group,
                    sequence_parallel=sequence_parallel,
                    parallel_mode="column",
                )
            if attn.k_lora_enabled:
                attn.to_k_lora = TELoRALinearLayer(
                    context_dim,
                    inner_dim,
                    rank=attn.k_lora_rank,
                    linear=True,
                    tp_size=attn.tp_size,
                    tp_group=tp_group,
                    sequence_parallel=sequence_parallel,
                    parallel_mode="column",
                )
            if attn.v_lora_enabled:
                attn.to_v_lora = TELoRALinearLayer(
                    context_dim,
                    inner_dim,
                    rank=attn.v_lora_rank,
                    linear=True,
                    tp_size=attn.tp_size,
                    tp_group=tp_group,
                    sequence_parallel=sequence_parallel,
                    parallel_mode="column",
                )
            if attn.out_lora_enabled:
                attn.to_out_lora = TELoRALinearLayer(
                    inner_dim,
                    query_dim,
                    rank=attn.out_lora_rank,
                    linear=True,
                    tp_size=attn.tp_size,
                    tp_group=tp_group,
                    sequence_parallel=sequence_parallel,
                    parallel_mode="row",
                )
        # Bind the LoRA-aware implementations as methods on this Attention instance.
        attn.cal_qkv = cal_qkv_lora.__get__(attn, attn.__class__)
        attn.cal_attn = cal_attn_lora.__get__(attn, attn.__class__)
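

# Example usage (illustrative sketch): `attn` is assumed to be an Attention block constructed
# elsewhere with its usual arguments; the peft_control layout below mirrors the keys read by
# enable_attn_lora and configure_attn_lora above, and the rank/scale values are placeholders,
# not recommendations.
#
#     peft_control = {
#         "customization_type": CustomizationType.LORA,
#         "to_q": {"activate": True, "lora_rank": 8, "lora_scale": 1.0},
#         "to_k": {"activate": True, "lora_rank": 8, "lora_scale": 1.0},
#         "to_v": {"activate": True, "lora_rank": 8, "lora_scale": 1.0},
#         "to_out": {"activate": True, "lora_rank": 8, "lora_scale": 1.0},
#     }
#     build_attn_lora(attn, peft_control)
#     # attn.cal_qkv and attn.cal_attn now include the LoRA branches defined in this module.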