import torch
import transformer_engine as te
from megatron.core import parallel_state
from torch import nn

from cosmos_predict1.utils import log


class LoRALinearLayer(nn.Module):
    """
    LoRA (low-rank adaptation) branch: a rank-`rank` down-projection followed by an
    up-projection, applied as an additive update to a base layer.

    Ported from
    https://github.com/huggingface/diffusers/blob/7a32b6beeb0cfdefed645253dce23d9b0a78597f/src/diffusers/models/attention_processor.py#L470.
    """

    def __init__(self, in_features, out_features, rank=4, linear=False):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}")

        if linear:
            down = nn.Linear(in_features, rank, bias=False)
            up = nn.Linear(rank, out_features, bias=False)
        else:
            # 1x1 convolutions act as per-position linear maps on (batch, channels, length) inputs.
            down = nn.Conv1d(in_features, rank, 1, bias=False)
            up = nn.Conv1d(rank, out_features, 1, bias=False)

        # Standard LoRA initialization: random down-projection, zero up-projection,
        # so the branch contributes nothing at the start of training.
        nn.init.normal_(down.weight, std=1 / rank)
        nn.init.zeros_(up.weight)
        self.net = nn.Sequential(down, up)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.net[0].weight.dtype

        # Run the LoRA branch in the weights' dtype, then cast back to the input dtype.
        up_hidden_states = self.net(hidden_states.to(dtype))

        return up_hidden_states.to(orig_dtype)


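# Illustrative usage sketch for LoRALinearLayer: adding a trainable low-rank branch
# on top of a frozen nn.Linear. The base module, `lora_scale`, and shapes below are
# assumptions for illustration, not part of this module's API.
def _example_lora_linear_usage():
    base_linear = nn.Linear(1024, 1024)
    base_linear.requires_grad_(False)  # base weights stay frozen; only the LoRA branch trains
    lora = LoRALinearLayer(1024, 1024, rank=8, linear=True)
    lora_scale = 1.0  # assumed blending factor for the low-rank update

    x = torch.randn(2, 77, 1024)  # (batch, tokens, features); Conv1d mode expects (batch, channels, length)
    # `up` is zero-initialized, so the output equals base_linear(x) before any training.
    y = base_linear(x) + lora_scale * lora(x)
    return y

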
class TELoRALinearLayer(nn.Module):
    """
    Transformer Engine variant of LoRALinearLayer with tensor-parallel support.

    Ported from
    https://github.com/huggingface/diffusers/blob/7a32b6beeb0cfdefed645253dce23d9b0a78597f/src/diffusers/models/attention_processor.py#L470.
    """

    def __init__(self, in_features, out_features, rank, linear, tp_size, tp_group, sequence_parallel, parallel_mode):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}")

        if linear:
            # The down-projection stays replicated (tp_size=1, parallel_mode=None);
            # only the up-projection is sharded like the layer it adapts.
            down = te.pytorch.Linear(
                in_features,
                rank,
                bias=False,
                tp_size=1,
                tp_group=tp_group,
                sequence_parallel=sequence_parallel,
                parallel_mode=None,
            )
            up = te.pytorch.Linear(
                rank,
                out_features,
                bias=False,
                tp_size=tp_size,
                tp_group=tp_group,
                sequence_parallel=sequence_parallel,
                parallel_mode=parallel_mode,
            )
        else:
            down = te.pytorch.Conv1d(
                in_features,
                rank,
                1,
                bias=False,
                tp_size=1,
                tp_group=tp_group,
                sequence_parallel=sequence_parallel,
                parallel_mode=None,
            )
            up = te.pytorch.Conv1d(
                rank,
                out_features,
                1,
                bias=False,
                tp_size=tp_size,
                tp_group=tp_group,
                sequence_parallel=sequence_parallel,
                parallel_mode=parallel_mode,
            )
        tp_rank = parallel_state.get_tensor_model_parallel_rank()

        # Use a local generator so LoRA initialization does not perturb the global RNG state.
        gen = torch.Generator(device=down.weight.device)
        gen_state = gen.get_state()

        # The replicated down-projection must be initialized identically on every
        # tensor-parallel rank, so all ranks use the same seed.
        log.info(f"rank {tp_rank}: setting seed to 0")
        gen.manual_seed(0)
        nn.init.normal_(down.weight, std=1 / rank, generator=gen)

        # The sharded up-projection starts at zero, so the LoRA branch is a no-op at
        # initialization; the per-rank seed below only matters for non-zero inits.
        gen.manual_seed(tp_rank)
        log.info(f"rank {tp_rank}: setting seed to {tp_rank}")
        nn.init.zeros_(up.weight)

        gen.set_state(gen_state)

        self.net = nn.Sequential(down, up)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.net[0].weight.dtype
        up_hidden_states = self.net(hidden_states.to(dtype))

        return up_hidden_states.to(orig_dtype)


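# Illustrative usage sketch for TELoRALinearLayer: building a tensor-parallel LoRA
# branch for a column-parallel projection. The sizes, `parallel_mode="column"`, and
# shapes below are assumptions for illustration; running this requires an initialized
# distributed process group, Megatron tensor-parallel state, and GPUs.
def _example_te_lora_linear_usage():
    tp_group = parallel_state.get_tensor_model_parallel_group()
    tp_size = parallel_state.get_tensor_model_parallel_world_size()

    lora = TELoRALinearLayer(
        in_features=1024,
        out_features=1024,
        rank=8,
        linear=True,
        tp_size=tp_size,
        tp_group=tp_group,
        sequence_parallel=False,
        parallel_mode="column",  # shard the up-projection the same way as the adapted TE layer
    )

    x = torch.randn(77, 2, 1024, device="cuda")  # (sequence, batch, hidden)
    return lora(x)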