# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# //     http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
from functools import lru_cache
from typing import Optional, Tuple

import torch
from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
from torch import nn

from common.cache import Cache


class RotaryEmbeddingBase(nn.Module):
    def __init__(self, dim: int, rope_dim: int):
        super().__init__()
        self.rope = RotaryEmbedding(
            dim=dim // rope_dim,
            freqs_for="pixel",
            max_freq=256,
        )
        # 1. Calling model.requires_grad_(True) after model creation would
        #    override the `requires_grad=False` set on the rope freqs.
        # 2. Even without an explicit requires_grad_(True), FSDP is not
        #    memory efficient when an fsdp_wrap unit mixes parameters with
        #    requires_grad=True and requires_grad=False.
        # Given the above, it is simpler to move the freqs out of
        # nn.Parameter entirely when `learned_freq=False`.
        freqs = self.rope.freqs
        del self.rope.freqs
        self.rope.register_buffer("freqs", freqs.data)

    def get_axial_freqs(self, *dims):
        return self.rope.get_axial_freqs(*dims)


class RotaryEmbedding3d(RotaryEmbeddingBase):
    def __init__(self, dim: int):
        super().__init__(dim, rope_dim=3)
        self.mm = False

    def forward(
        self,
        q: torch.FloatTensor,  # b h l d
        k: torch.FloatTensor,  # b h l d
        size: Tuple[int, int, int],
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
    ]:
        T, H, W = size
        freqs = self.get_axial_freqs(T, H, W)
        q = rearrange(q, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W)
        k = rearrange(k, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W)
        q = apply_rotary_emb(freqs, q.float()).to(q.dtype)
        k = apply_rotary_emb(freqs, k.float()).to(k.dtype)
        q = rearrange(q, "b h T H W d -> b h (T H W) d")
        k = rearrange(k, "b h T H W d -> b h (T H W) d")
        return q, k
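
# Usage sketch (illustrative, not part of the original file): q and k are
# flattened over a (T, H, W) latent grid; dim is the per-head dim, and using
# a multiple of 3 keeps the axial split even.
#
#   rope = RotaryEmbedding3d(dim=192)
#   q = torch.randn(2, 8, 4 * 16 * 16, 192)  # b h (T H W) d
#   k = torch.randn(2, 8, 4 * 16 * 16, 192)
#   q, k = rope(q, k, size=(4, 16, 16))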


class MMRotaryEmbeddingBase(RotaryEmbeddingBase):
    def __init__(self, dim: int, rope_dim: int):
        super().__init__(dim, rope_dim)
        self.rope = RotaryEmbedding(
            dim=dim // rope_dim,
            freqs_for="lang",
            theta=10000,
        )
        freqs = self.rope.freqs
        del self.rope.freqs
        self.rope.register_buffer("freqs", freqs.data)
        self.mm = True


class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
    def __init__(self, dim: int):
        super().__init__(dim, rope_dim=3)

    def forward(
        self,
        vid_q: torch.FloatTensor,  # L h d
        vid_k: torch.FloatTensor,  # L h d
        vid_shape: torch.LongTensor,  # B 3
        txt_q: torch.FloatTensor,  # L h d
        txt_k: torch.FloatTensor,  # L h d
        txt_shape: torch.LongTensor,  # B 1
        cache: Cache,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
        torch.FloatTensor,
        torch.FloatTensor,
    ]:
        vid_freqs, txt_freqs = cache(
            "mmrope_freqs_3d",
            lambda: self.get_freqs(vid_shape, txt_shape),
        )

        vid_q = rearrange(vid_q, "L h d -> h L d")
        vid_k = rearrange(vid_k, "L h d -> h L d")
        vid_q = apply_rotary_emb(vid_freqs, vid_q.float()).to(vid_q.dtype)
        vid_k = apply_rotary_emb(vid_freqs, vid_k.float()).to(vid_k.dtype)
        vid_q = rearrange(vid_q, "h L d -> L h d")
        vid_k = rearrange(vid_k, "h L d -> L h d")

        txt_q = rearrange(txt_q, "L h d -> h L d")
        txt_k = rearrange(txt_k, "L h d -> h L d")
        txt_q = apply_rotary_emb(txt_freqs, txt_q.float()).to(txt_q.dtype)
        txt_k = apply_rotary_emb(txt_freqs, txt_k.float()).to(txt_k.dtype)
        txt_q = rearrange(txt_q, "h L d -> L h d")
        txt_k = rearrange(txt_k, "h L d -> L h d")
        return vid_q, vid_k, txt_q, txt_k

    def get_freqs(
        self,
        vid_shape: torch.LongTensor,
        txt_shape: torch.LongTensor,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
    ]:
        # Precompute axial freqs on a fixed maximum grid (1024 temporal
        # positions and a 128x128 spatial grid), then slice per sample.
        vid_freqs = self.get_axial_freqs(1024, 128, 128)
        txt_freqs = self.get_axial_freqs(1024)
        vid_freq_list, txt_freq_list = [], []
        for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
            # Text tokens take temporal positions [0, l); the video frames of
            # the same sample follow at [l, l + f). Text freqs are 1D and are
            # repeated across the 3 axes so the last dim matches the video freqs.
            vid_freq = vid_freqs[l : l + f, :h, :w].reshape(-1, vid_freqs.size(-1))
            txt_freq = txt_freqs[:l].repeat(1, 3).reshape(-1, vid_freqs.size(-1))
            vid_freq_list.append(vid_freq)
            txt_freq_list.append(txt_freq)
        return torch.cat(vid_freq_list, dim=0), torch.cat(txt_freq_list, dim=0)


def get_na_rope(rope_type: Optional[str], dim: int):
    if rope_type is None:
        return None
    if rope_type == "mmrope3d":
        return NaMMRotaryEmbedding3d(dim=dim)
    raise NotImplementedError(f"{rope_type} is not supported.")
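

# Usage sketch (illustrative, not part of the original file). `Cache` from
# common.cache is assumed to memoize the zero-arg lambda under the given key;
# shapes below are made up for the example.
#
#   rope = get_na_rope("mmrope3d", dim=192)
#   # vid_q / vid_k: (sum of T*H*W over the batch, heads, head_dim)
#   # txt_q / txt_k: (sum of text lengths over the batch, heads, head_dim)
#   # vid_shape: (B, 3) per-sample (T, H, W); txt_shape: (B, 1) text lengths
#   vid_q, vid_k, txt_q, txt_k = rope(
#       vid_q, vid_k, vid_shape,
#       txt_q, txt_k, txt_shape,
#       cache=cache,
#   )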