File size: 10,787 Bytes
b6af722
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from einops import rearrange
from torch.utils.checkpoint import checkpoint
from transformer_engine.pytorch.attention import apply_rotary_pos_emb

from cosmos_predict1.diffusion.module.attention import Attention
from cosmos_predict1.diffusion.training.utils.peft.lora_net import LoRALinearLayer, TELoRALinearLayer
from cosmos_predict1.diffusion.utils.customization.customization_manager import CustomizationType

# Megatron is an optional dependency. USE_MEGATRON records whether
# `parallel_state` could be imported, so code below can guard its use of
# `parallel_state` (the name is only bound when the import succeeded).
try:
    from megatron.core import parallel_state

    USE_MEGATRON = True
except ImportError:
    USE_MEGATRON = False


def enable_attn_lora(attn: Attention, peft_control: dict) -> None:
    """
    Enable LoRA for the attention block based on the peft_control dictionary.

    ``attn.peft_lora_enabled`` is always reset to False first and is only set to
    True when ``peft_control`` is non-empty and requests ``CustomizationType.LORA``.

    Args:
        attn (Attention): The attention block to configure.
        peft_control (dict): Dictionary containing PEFT configuration. A falsy
            value (``None`` or an empty dict) leaves LoRA disabled.

    Raises:
        KeyError: If ``peft_control`` is non-empty but has no "customization_type" entry.
        ValueError: If the requested customization type is not ``CustomizationType.LORA``.
    """
    attn.peft_lora_enabled = False
    if not peft_control:
        return

    # Keep the try body to the single lookup that can raise KeyError so that
    # unrelated errors are never re-labelled as a missing-key problem.
    try:
        customization_type = peft_control["customization_type"]
    except KeyError as e:
        raise KeyError(f"peft_control dictionary expected to have attribute {e.args[0]}.") from e

    if customization_type == CustomizationType.LORA:
        attn.peft_lora_enabled = True
    else:
        # ValueError is a subclass of Exception, so handlers written against the
        # previous bare `Exception` still match.
        raise ValueError(f"Unsupported Customization type {customization_type}")


def configure_attn_lora(attn: Attention, peft_control: dict) -> None:
    """
    Configure LoRA for the attention block based on the peft_control dictionary.

    For each of the projections ``to_q``, ``to_k``, ``to_v`` and ``to_out`` this sets
    ``attn.<name>_lora_enabled`` from the entry's "activate" flag (False when the entry
    or the flag is absent). For every enabled projection it additionally sets
    ``attn.<name>_lora_rank`` and ``attn.<name>_lora_scale`` (the scale is converted
    with ``float``), where ``<name>`` is ``q``, ``k``, ``v`` or ``out``.

    Args:
        attn (Attention): The attention block to configure.
        peft_control (dict): Dictionary containing PEFT configuration. A falsy
            value (``None`` or an empty dict) disables LoRA on all four projections.

    Raises:
        KeyError: If an activated projection lacks "lora_rank" or "lora_scale".
        ValueError: If a "lora_scale" value cannot be converted to float.
    """
    # enable_attn_lora treats a falsy peft_control as "LoRA disabled"; accept the
    # same input here instead of failing on `None.get`.
    peft_control = peft_control or {}

    # (attribute prefix on attn, key in peft_control)
    layers = (("q", "to_q"), ("k", "to_k"), ("v", "to_v"), ("out", "to_out"))
    try:
        # First pass: every enabled flag is set before any rank/scale lookup can
        # raise, so all four flags exist even when a later entry is malformed.
        for prefix, key in layers:
            setattr(attn, f"{prefix}_lora_enabled", peft_control.get(key, {}).get("activate", False))
        # Second pass: rank/scale are only required for activated projections.
        for prefix, key in layers:
            if getattr(attn, f"{prefix}_lora_enabled"):
                setattr(attn, f"{prefix}_lora_rank", peft_control[key]["lora_rank"])
                setattr(attn, f"{prefix}_lora_scale", float(peft_control[key]["lora_scale"]))
    except KeyError as e:
        raise KeyError(f"All layers (to_q, etc) specified must have attribute {e.args[0]}.") from e
    except ValueError as e:
        raise ValueError(f"Could not convert string to float: {e}") from e


def cal_qkv_lora(
    self,
    x: torch.Tensor,
    context: torch.Tensor = None,
    mask: torch.Tensor = None,
    rope_emb: torch.Tensor = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Calculate the Q, K, V matrices with LoRA adjustments. Derived from cosmos_predict1/diffusion/module/attention.py cal_qkv.

    The base projections ``to_q[0]``, ``to_k[0]`` and ``to_v[0]`` are applied first; when
    ``self.peft_lora_enabled`` is set, each enabled LoRA branch adds
    ``<scale> * to_<name>_lora(input)`` to its projection. The results are split into
    heads and passed through ``to_q[1]`` / ``to_k[1]`` / ``to_v[1]`` and, for
    self-attention with a rotary embedding, through ``apply_rotary_pos_emb``.

    Args:
        x (torch.Tensor): Input tensor.
        context (torch.Tensor, optional): Context tensor. When None, ``x`` is used for K and V.
        mask (torch.Tensor, optional): Mask tensor. Accepted for signature parity; not used here.
        rope_emb (torch.Tensor, optional): Rotary positional embedding
        **kwargs: Ignored.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The Q, K, V matrices.

    Raises:
        AttributeError: If LoRA is enabled but a LoRA layer or scale attribute is missing.
    """
    # Extra keyword arguments are accepted for call compatibility only.
    del kwargs

    q = self.to_q[0](x)
    context = x if context is None else context
    k = self.to_k[0](context)
    v = self.to_v[0](context)

    if self.peft_lora_enabled:
        try:
            if self.q_lora_enabled:
                q_lora = self.to_q_lora(x)
                q = q + self.q_lora_scale * q_lora
            if self.k_lora_enabled:
                k_lora = self.to_k_lora(context)
                k = k + self.k_lora_scale * k_lora
            if self.v_lora_enabled:
                v_lora = self.to_v_lora(context)
                v = v + self.v_lora_scale * v_lora
        except AttributeError as e:
            raise AttributeError(f"lora enabled, but missing class attribute {e.args[0]} of Attention block") from e

    # Split the last dimension into (heads on this tensor-parallel rank, head dim).
    heads_per_rank = self.heads // self.tp_size
    q, k, v = (
        rearrange(t, "b ... (n c) -> b ... n c", n=heads_per_rank, c=self.dim_head) for t in (q, k, v)
    )

    def apply_norm_and_rotary_pos_emb(q, k, v, rope_emb):
        # Index 1 of each projection is applied after the head split
        # (presumably the norm layer, going by this helper's name).
        q = self.to_q[1](q)
        k = self.to_k[1](k)
        v = self.to_v[1](v)
        if self.is_selfattn and rope_emb is not None:  # only apply to self-attention!
            q = apply_rotary_pos_emb(q, rope_emb, tensor_format=self.qkv_format, fused=True)
            k = apply_rotary_pos_emb(k, rope_emb, tensor_format=self.qkv_format, fused=True)
        return q, k, v

    # Run the post-projection step through torch's non-reentrant checkpoint wrapper.
    q, k, v = checkpoint(apply_norm_and_rotary_pos_emb, q, k, v, rope_emb, use_reentrant=False)

    return q, k, v


def cal_attn_lora(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
    """
    Calculate the attention output with LoRA adjustments. Derived from cosmos_predict1/diffusion/module/attention.py cal_attn.

    The attention result is passed through ``self.to_out``; when both
    ``self.peft_lora_enabled`` and ``self.out_lora_enabled`` are set,
    ``self.out_lora_scale * self.to_out_lora(attn_out)`` is added to it.

    Args:
        q (torch.Tensor): Query tensor.
        k (torch.Tensor): Key tensor.
        v (torch.Tensor): Value tensor.
        mask (torch.Tensor, optional): Mask tensor. Only forwarded to the "torch" backend.

    Returns:
        torch.Tensor: The attention output.

    Raises:
        ValueError: If ``self.backend`` is neither "transformer_engine" nor "torch".
        AttributeError: If output LoRA is enabled but ``to_out_lora`` or
            ``out_lora_scale`` is missing on the attention block.
    """
    if self.backend == "transformer_engine":
        seq_dim = self.qkv_format.index("s")
        assert (
            q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1
        ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version."
        attn_out = self.attn_op(q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None)  # [B, Mq, H, V]
    elif self.backend == "torch":
        attn_out = self.attn_op(q, k, v, mask=mask)  # [B, Mq, H, V]
        # Merge the head and channel dimensions back into one before the output projection.
        attn_out = rearrange(attn_out, " b ... n c -> b ... (n c)")
    else:
        raise ValueError(f"Backend {self.backend} not found")

    # Output projection and optional LoRA correction are identical for both backends.
    out = self.to_out(attn_out)

    if self.peft_lora_enabled and self.out_lora_enabled:
        # Only the attribute lookups are guarded, so an AttributeError raised
        # inside the LoRA layer's own forward pass is not mislabelled as a
        # missing attribute of the attention block.
        try:
            out_lora_layer = self.to_out_lora
            out_lora_scale = self.out_lora_scale
        except AttributeError as e:
            raise AttributeError(
                f"out lora enabled, but missing class attribute {e.args[0]} of Attention block"
            ) from e
        out = out + out_lora_scale * out_lora_layer(attn_out)

    return out


def build_attn_lora(attn: Attention, peft_control: dict) -> None:
    """
    Configure, build and add LoRA layers to the attention block.

    Runs ``enable_attn_lora`` and ``configure_attn_lora`` on ``attn`` and, when LoRA is
    enabled, attaches one LoRA layer per activated projection as ``attn.to_q_lora``,
    ``attn.to_k_lora``, ``attn.to_v_lora`` and ``attn.to_out_lora``. ``LoRALinearLayer``
    is used when ``attn.tp_size == 1``, otherwise ``TELoRALinearLayer`` ("column"
    parallel mode for q/k/v, "row" for the output projection). Finally
    ``attn.cal_qkv`` and ``attn.cal_attn`` are rebound to the LoRA-aware
    implementations; this happens whether or not LoRA ended up enabled.

    Args:
        attn (Attention): The attention block to add LoRA layers to.
        peft_control (dict): Dictionary containing PEFT configuration. A falsy
            value (``None`` or an empty dict) results in no LoRA layers being built.
    """
    enable_attn_lora(attn, peft_control)
    # enable_attn_lora accepts a falsy peft_control; normalise it so the `.get`
    # lookups in configure_attn_lora do not fail on None.
    configure_attn_lora(attn, peft_control or {})

    if attn.peft_lora_enabled:
        query_dim = attn.query_dim
        inner_dim = attn.inner_dim
        context_dim = attn.context_dim
        tp_group = parallel_state.get_tensor_model_parallel_group(check_initialized=False) if USE_MEGATRON else None

        use_tensor_parallel = attn.tp_size != 1
        if use_tensor_parallel:
            # `parallel_state` is only bound when the megatron import succeeded;
            # reading it unguarded raised NameError on installs without megatron.
            # NOTE(review): this reads a module-level `sequence_parallel` attribute of
            # parallel_state and falls back to False when absent - confirm megatron
            # actually exposes it.
            sequence_parallel = getattr(parallel_state, "sequence_parallel", False) if USE_MEGATRON else False

        # (attribute prefix, input features, output features, TE parallel mode)
        layer_specs = (
            ("q", query_dim, inner_dim, "column"),
            ("k", context_dim, inner_dim, "column"),
            ("v", context_dim, inner_dim, "column"),
            ("out", inner_dim, query_dim, "row"),
        )
        for prefix, in_features, out_features, parallel_mode in layer_specs:
            if not getattr(attn, f"{prefix}_lora_enabled"):
                continue
            rank = getattr(attn, f"{prefix}_lora_rank")
            if use_tensor_parallel:
                lora_layer = TELoRALinearLayer(
                    in_features,
                    out_features,
                    rank=rank,
                    linear=True,
                    tp_size=attn.tp_size,
                    tp_group=tp_group,
                    sequence_parallel=sequence_parallel,
                    parallel_mode=parallel_mode,
                )
            else:
                lora_layer = LoRALinearLayer(in_features, out_features, rank=rank, linear=True)
            setattr(attn, f"to_{prefix}_lora", lora_layer)

    # Bind the module-level functions as methods of this particular instance.
    attn.cal_qkv = cal_qkv_lora.__get__(attn, attn.__class__)
    attn.cal_attn = cal_attn_lora.__get__(attn, attn.__class__)