# Lyra / src/models/utils/token_pruning.py
# Muhammad Taqi Raza
# adding lyra files
# af758d1
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from typing import List, Tuple, Optional
# Sample Gumbel noise for Gumbel-Softmax trick
def sample_gumbel(shape: torch.Size, eps: float = 1e-6, device=None, dtype=None) -> torch.Tensor:
    """Draw i.i.d. standard Gumbel(0, 1) noise with the inverse-CDF transform.

    Computes ``-log(-log(U))`` with ``U ~ Uniform[0, 1)``. ``U`` is clamped to
    ``[eps, 1 - eps]`` so neither logarithm sees 0 or 1.

    Args:
        shape: Shape of the returned tensor.
        eps: Clamp margin that keeps ``U`` away from 0 and 1.
        device: Device of the returned tensor.
        dtype: Floating dtype of the returned tensor. ``None`` means the torch
            default dtype.

    Returns:
        Gumbel noise with the requested shape, device and dtype.
    """
    out_dtype = dtype if dtype is not None else torch.get_default_dtype()
    # In float16/bfloat16, ``1 - eps`` rounds to exactly 1.0 for small eps.
    # That disables the upper clamp and would let -log(-log(1)) = +inf through.
    # For sub-32-bit dtypes, sample and transform in float32, then cast the
    # (finite) result back.
    work_dtype = out_dtype if torch.finfo(out_dtype).bits >= 32 else torch.float32
    U = torch.rand(shape, device=device, dtype=work_dtype)
    noise = -torch.log(-torch.log(U.clamp(min=eps, max=1 - eps)))
    return noise.to(out_dtype)
# Select the top-k or softmax probabilities (with Gumbel-Softmax for differentiability)
def select_topk(
    logits: torch.Tensor,
    k: int,
    method: str,
    temperature: float,
    hard: bool,
    eps: float
) -> torch.Tensor:
    """Build a per-row selection mask of ``k`` entries from 2-D logits.

    Args:
        logits: Scores of shape ``(B, N)``. Unpacking the shape enforces 2-D.
        k: Number of entries to select per row.
        method: One of two modes.
            ``'topk'`` gives a deterministic 0/1 mask of the ``k`` largest
            logits.
            ``'softmax'`` gives a Gumbel-Softmax relaxation.
        temperature: Softmax temperature (``'softmax'`` only).
        hard: In ``'softmax'`` mode, return a straight-through k-hot mask
            instead of the soft probabilities. The forward value is the k-hot
            mask; the gradient is that of the softmax.
        eps: Clamp margin forwarded to ``sample_gumbel``.

    Returns:
        A mask with the same shape, dtype and device as ``logits``.

    Raises:
        ValueError: If ``method`` is not recognised, or ``logits`` is not 2-D.
    """
    _batch, _num = logits.shape  # rejects inputs that are not 2-D

    if method == 'topk':
        winners = torch.topk(logits, k, dim=-1).indices
        return torch.zeros_like(logits).scatter(-1, winners, 1.0)

    if method != 'softmax':
        raise ValueError(f"Unknown method: {method}")

    noise = sample_gumbel(logits.shape, eps=eps, device=logits.device, dtype=logits.dtype)
    probs = F.softmax((logits + noise) / temperature, dim=-1)
    if not hard:
        return probs

    winners = probs.topk(k, dim=-1).indices
    k_hot = torch.zeros_like(probs).scatter(-1, winners, 1.0)
    # Straight-through estimator: forward value equals k_hot, gradients flow through probs.
    return k_hot - probs.detach() + probs
# Perform global selection of k tokens across the entire tensor
def global_selection(
    mask_logits: torch.Tensor,  # (B, T, H, W)
    total_k: int,
    method: str,
    temperature: float,
    hard: bool,
    eps: float
) -> torch.Tensor:
    """Select ``total_k`` positions per sample jointly over time and space.

    The ``T * H * W`` grid of each sample is flattened into one row. The
    selection is made on that row with ``select_topk``, and the result is
    folded back into the original grid.

    Args:
        mask_logits: Selection logits of shape ``(B, T, H, W)``.
        total_k: Number of positions to keep per sample.
        method, temperature, hard, eps: Forwarded to ``select_topk``.

    Returns:
        Mask of shape ``(B, T, H, W)``.
    """
    batch, frames, height, width = mask_logits.shape
    flat_mask = select_topk(
        mask_logits.reshape(batch, frames * height * width),
        total_k, method, temperature, hard, eps,
    )
    return flat_mask.reshape(batch, frames, height, width)
# Perform structured selection: k_t and k_hw tokens independently for time and space
def structured_selection(
    mask_logits: torch.Tensor,  # (B, T, H, W)
    k_t: int,
    k_hw: int,
    method: str,
    temperature: float,
    hard: bool,
    eps: float
) -> torch.Tensor:
    """Select ``k_t`` frames and, independently, ``k_hw`` spatial positions per frame.

    Frames are scored by the mean of their spatial logits. Spatial positions
    are chosen separately inside every frame of every sample. The final mask is
    the product of the two masks, so a position survives only if both its frame
    and its spatial slot are selected. In ``'topk'`` mode and hard
    ``'softmax'`` mode, each sample has exactly ``k_t * k_hw`` nonzero entries.

    Args:
        mask_logits: Selection logits of shape ``(B, T, H, W)``.
        k_t: Number of frames to keep per sample.
        k_hw: Number of spatial positions to keep per frame.
        method, temperature, hard, eps: Forwarded to ``select_topk``.

    Returns:
        Mask of shape ``(B, T, H, W)``.
    """
    B, T, H, W = mask_logits.shape

    # Temporal selection: one score per frame.
    logits_t = mask_logits.mean(dim=[2, 3])  # (B, T)
    mask_t = select_topk(logits_t, k_t, method, temperature, hard, eps)  # (B, T)

    # Spatial selection: treat every (sample, frame) pair as an independent row,
    # so one batched call handles all B * T frames instead of a Python loop.
    logits_hw = mask_logits.reshape(B * T, H * W)
    mask_hw = select_topk(logits_hw, k_hw, method, temperature, hard, eps)  # (B*T, H*W)
    mask_spatial = mask_hw.reshape(B, T, H, W)

    # Keep a position only when both its frame and its spatial slot are kept.
    return mask_spatial * mask_t[:, :, None, None]  # (B, T, H, W)
# Gather tokens and other tensors at the positions where the mask is nonzero (mask values are not multiplied in)
def apply_mask_and_select(
    tokens: torch.Tensor,  # (B, C, T, H, W)
    other_tensors: List[torch.Tensor],
    mask: torch.Tensor  # (B, T, H, W)
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    """Gather the nonzero-mask positions from ``tokens`` and from each extra tensor.

    Only the positions of nonzero mask entries are used. The mask values are
    not multiplied into the outputs, so no gradient reaches ``mask`` through
    this function. Within each sample, selected positions come back in
    ascending flattened ``(T, H, W)`` order.

    Args:
        tokens: Tensor of shape ``(B, C, T, H, W)``.
        other_tensors: Tensors that can be reshaped to ``(B, C_i, T*H*W)``.
            Each is gathered at the same positions as ``tokens``.
        mask: Selection mask with ``B * T * H * W`` elements, typically
            ``(B, T, H, W)``.

    Returns:
        ``(tokens_out, others_out)``.
        ``tokens_out`` has shape ``(B, C, k)``.
        ``others_out[i]`` has shape ``(B, C_i, k)``.
        ``k`` is the per-sample count of nonzero mask entries.

    Raises:
        ValueError: If ``tokens`` is not 5-D, or if samples keep different
            numbers of positions (the outputs could not be batched).
    """
    B, C, T, H, W = tokens.shape
    N = T * H * W
    keep = mask.reshape(B, N) != 0  # (B, N) boolean
    counts = keep.sum(dim=1)
    if B > 0 and not bool((counts == counts[0]).all()):
        raise ValueError(
            "Every sample must keep the same number of tokens; "
            f"got per-sample counts {counts.tolist()}."
        )
    k = int(counts[0]) if B > 0 else 0
    # nonzero() enumerates a 2-D tensor in row-major order. Column indices
    # therefore come out grouped by sample and ascending within each sample.
    idx = keep.nonzero(as_tuple=False)[:, 1].reshape(B, k)  # (B, k)

    def _gather(x_flat: torch.Tensor) -> torch.Tensor:
        # Pick the selected positions along the last axis: (B, C_x, N) -> (B, C_x, k).
        index = idx.to(x_flat.device).unsqueeze(1).expand(-1, x_flat.shape[1], -1)
        return x_flat.gather(2, index)

    tokens_out = _gather(tokens.reshape(B, C, N))
    others_out = [_gather(x.reshape(B, -1, N)) for x in other_tensors]
    return tokens_out, others_out
# Main process function to prune tokens and other tensors based on mask logits
def process_tensors(
    tokens: torch.Tensor,  # (B, C, T, H, W)
    mask_logits: torch.Tensor,  # (B, 1, T, H, W)
    other_tensors: List[torch.Tensor],
    total_k: Optional[int] = None,
    k_t: Optional[int] = None,
    k_hw: Optional[int] = None,
    temperature: float = 1.0,
    eps: float = 1e-6,
    training: bool = True,
    soft_inference: bool = True,
) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
    """Prune ``tokens`` and ``other_tensors`` to the positions chosen from ``mask_logits``.

    Selection mode:
      * ``training=True`` or ``soft_inference=True`` (the default): Gumbel-Softmax
        with a hard straight-through mask (``method='softmax'``, ``hard=True``).
        This mode is stochastic. Because ``soft_inference`` defaults to True, it
        is also used when ``training=False``.
      * ``training=False`` and ``soft_inference=False``: deterministic top-k on
        the raw logits (``method='topk'``).

    Selection shape:
      * If ``total_k`` is given, keep ``total_k`` positions per sample, chosen
        jointly over ``(T, H, W)``.
      * Otherwise ``k_t`` and ``k_hw`` must both be given: keep ``k_t`` frames,
        and ``k_hw`` spatial positions inside each frame.

    Returns:
        ``(tokens_out, others_out, mask)``.
        ``tokens_out`` has shape ``(B, C, k)``.
        ``others_out[i]`` has shape ``(B, C_i, k)``.
        ``mask`` is the ``(B, T, H, W)`` selection mask.
        Gathering only uses the positions of nonzero mask entries. For
        gradients to reach ``mask_logits``, a caller must use the returned
        ``mask`` itself. NOTE(review): confirm callers do.

    Raises:
        ValueError: If ``tokens`` is not 5-D, if a 5-D ``mask_logits`` has a
            non-singleton dimension 1, or if neither ``total_k`` nor both
            ``k_t`` and ``k_hw`` are provided.
    """
    if tokens.dim() != 5:
        raise ValueError(f"tokens must be (B, C, T, H, W); got shape {tuple(tokens.shape)}")
    if mask_logits.dim() == 5 and mask_logits.shape[1] != 1:
        raise ValueError(
            f"mask_logits must be (B, 1, T, H, W); got shape {tuple(mask_logits.shape)}"
        )
    mask_logits = mask_logits.squeeze(1)  # (B, T, H, W)

    if training or soft_inference:
        method = 'softmax'
        hard = True
    else:
        method = 'topk'
        hard = False  # ignored in topk mode

    if total_k is not None:
        mask = global_selection(mask_logits, total_k, method, temperature, hard, eps)
    elif k_t is not None and k_hw is not None:
        mask = structured_selection(mask_logits, k_t, k_hw, method, temperature, hard, eps)
    else:
        raise ValueError("Provide either total_k or both k_t and k_hw.")

    tokens_out, others_out = apply_mask_and_select(tokens, other_tensors, mask)
    return tokens_out, others_out, mask
if __name__ == '__main__':
    # Small demo of both pruning modes. process_tensors returns
    # (tokens_out, others_out, mask), so all three values are unpacked.
    torch.manual_seed(0)
    temperature = 1.0
    k_t = 9    # frames kept out of T = 17
    k_h = 90   # half of H = 180
    k_w = 160  # half of W = 320, so k_h * k_w is 1/4 of the H * W grid
    B, T, C, H, W = 2, 17, 3, 180, 320
    tokens = torch.randn(B, C, T, H, W)
    mask_logits = torch.randn(B, 1, T, H, W)
    # Extra tensors with different channel counts, pruned at the same positions.
    other1 = torch.randn(B, 6, T, H, W)
    other2 = torch.randn(B, 9, T, H, W)

    # Case 1: structured pruning in training mode (differentiable Gumbel-Softmax).
    # Keep k_t of the T frames, and 1/4 of the spatial tokens inside each kept frame.
    tokens_out, others_out, mask = process_tensors(
        tokens=tokens,
        mask_logits=mask_logits,
        other_tensors=[other1, other2],
        k_t=k_t,
        k_hw=k_h * k_w,
        temperature=temperature,
        training=True,
    )
    print(
        "structured:", tuple(tokens_out.shape),
        [tuple(o.shape) for o in others_out], tuple(mask.shape),
    )

    # Case 2: global pruning at inference with deterministic top-k.
    # Use the same token budget, chosen jointly across T, H and W.
    tokens_out, others_out, mask = process_tensors(
        tokens=tokens,
        mask_logits=mask_logits,
        other_tensors=[other1, other2],
        total_k=k_t * k_h * k_w,
        temperature=temperature,
        training=False,
        soft_inference=False,
    )
    print(
        "global:", tuple(tokens_out.shape),
        [tuple(o.shape) for o in others_out], tuple(mask.shape),
    )