# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import math
from typing import Callable, Optional, Tuple

import torch


def do_nothing(x, mode=None):
    """Identity merge/unmerge: return ``x`` unchanged and ignore ``mode``.

    Returned by the matching functions below when no tokens should be merged,
    so callers can use the result exactly like a real (merge, unmerge) pair.
    """
    return x


def bipartite_soft_matching(
    metric: torch.Tensor,
    r: int,
    class_token: bool = False,
    distill_token: bool = False,
) -> Tuple[Callable, Callable]:
    """
    Applies ToMe with a balanced matching set (50%, 50%).

    Input size is [batch, tokens, channels].
    r indicates the number of tokens to remove (max 50% of tokens).

    Extra args:
     - class_token: Whether or not there's a class token.
     - distill_token: Whether or not there's also a distillation token.

    When enabled, the class token and distillation tokens won't get merged.

    Returns:
        A ``(merge, unmerge)`` pair of closures.
        ``merge(x, mode)`` reduces ``x`` from ``t`` to ``t - r`` tokens,
        combining merged tokens with ``Tensor.scatter_reduce`` using ``mode``
        (default ``"mean"``).
        ``unmerge(x)`` maps a merged tensor back to ``t`` tokens by copying
        each destination token to every source position merged into it.
        If ``r`` clamps to 0, both are the identity ``do_nothing``.
    """
    protected = 0
    if class_token:
        protected += 1
    if distill_token:
        protected += 1

    # We can only reduce by a maximum of 50% tokens
    t = metric.shape[1]
    r = min(r, (t - protected) // 2)

    if r <= 0:
        return do_nothing, do_nothing

    with torch.no_grad():
        # Cosine similarity. NOTE: a metric row that is all zeros has norm 0
        # and becomes NaN after this division.
        metric = metric / metric.norm(dim=-1, keepdim=True)
        # Set A = even token positions, set B = odd token positions.
        a, b = metric[..., ::2, :], metric[..., 1::2, :]
        scores = a @ b.transpose(-1, -2)

        if class_token:
            # Token 0 is A[0]: give it the lowest score so it is never merged.
            scores[..., 0, :] = -math.inf
        if distill_token:
            # Token 1 is B[0]: forbid any A token from merging into it.
            scores[..., :, 0] = -math.inf

        # Best B match for every A token, then rank A tokens by that score.
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)

        if class_token:
            # Sort to ensure the class token is at the start
            unm_idx = unm_idx.sort(dim=1)[0]

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        """Merge the top-r A tokens into their matched B tokens.

        Output layout is [unmerged A tokens, B tokens], except with
        ``distill_token`` where it is [unm[0], dst[0], unm[1:], dst[1:]] so the
        class and distillation tokens stay at positions 0 and 1.
        """
        src, dst = x[..., ::2, :], x[..., 1::2, :]
        n, t1, c = src.shape
        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        if distill_token:
            return torch.cat([unm[:, :1], dst[:, :1], unm[:, 1:], dst[:, 1:]], dim=1)
        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        """Scatter a merged tensor back to the original ``t`` token positions."""
        unm_len = unm_idx.shape[1]
        if distill_token:
            # Undo merge()'s [unm[0], dst[0], unm[1:], dst[1:]] ordering so the
            # tensor is laid out as [unm, dst] again.
            x = torch.cat(
                [x[:, :1], x[:, 2 : unm_len + 1], x[:, 1:2], x[:, unm_len + 1 :]],
                dim=1,
            )
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        n, _, c = unm.shape

        # Each merged source token takes the value of its destination.
        src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c))

        out = torch.zeros(n, t, c, device=x.device, dtype=x.dtype)

        # B tokens go back to odd positions, A tokens to 2 * (their A index).
        out[..., 1::2, :] = dst
        out.scatter_(dim=-2, index=(2 * unm_idx).expand(n, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=(2 * src_idx).expand(n, r, c), src=src)

        return out

    return merge, unmerge


def kth_bipartite_soft_matching(
    metric: torch.Tensor, k: int
) -> Tuple[Callable, Callable]:
    """
    Applies ToMe with the two sets as (every kth element, the rest).
    If n is the number of tokens, resulting number of tokens will be n // k.

    Input size is [batch, tokens, channels].
    k indicates the stride for the first set.
    k = 2 is equivalent to regular bipartite_soft_matching with r = 0.5 * N

    Only the first (n // k) * k tokens are used; remainder tokens are dropped
    by both merge and unmerge. If k <= 1 or there are fewer than k tokens,
    the identity ``do_nothing`` pair is returned.
    """
    if k <= 1 or metric.shape[1] < k:
        return do_nothing, do_nothing

    def split(x):
        """Split x into (the k-1 leading tokens of each group, each kth token)."""
        t_rnd = (x.shape[1] // k) * k
        x = x[:, :t_rnd, :].view(x.shape[0], -1, k, x.shape[2])
        a, b = (
            x[:, :, : (k - 1), :].contiguous().view(x.shape[0], -1, x.shape[-1]),
            x[:, :, (k - 1), :],
        )
        return a, b

    with torch.no_grad():
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        r = a.shape[1]
        scores = a @ b.transpose(-1, -2)

        # Every token in set A merges into its most similar kth token.
        _, dst_idx = scores.max(dim=-1)
        dst_idx = dst_idx[..., None]

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        """Merge all A tokens into B; the output is the B tokens only."""
        src, dst = split(x)
        n, _, c = src.shape
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        return dst

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        """Expand merged B tokens back to (n // k) * k tokens in original order."""
        n, _, c = x.shape
        dst = x

        src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c))

        # Re-interleave: each group is (k - 1) A tokens followed by 1 B token.
        src = src.view(n, -1, (k - 1), c)
        dst = dst.view(n, -1, 1, c)

        out = torch.cat([src, dst], dim=-2)
        out = out.contiguous().view(n, -1, c)

        return out

    return merge, unmerge


def random_bipartite_soft_matching(
    metric: torch.Tensor, r: int
) -> Tuple[Callable, Callable]:
    """
    Applies ToMe with the two sets as (r chosen randomly, the rest).
    Input size is [batch, tokens, channels].

    This will reduce the number of tokens by r. r is clamped to tokens - 1,
    since at least one destination token must remain; if it clamps to 0 the
    identity ``do_nothing`` pair is returned.
    """
    if r <= 0:
        return do_nothing, do_nothing

    B, N, _ = metric.shape
    # Set B must keep at least one token for set A to merge into.
    r = min(r, N - 1)
    if r <= 0:
        return do_nothing, do_nothing

    with torch.no_grad():
        # Random permutation per batch element: the first r positions form
        # set A (sources), the remaining N - r form set B (destinations).
        rand_idx = torch.rand(B, N, 1, device=metric.device).argsort(dim=1)

        a_idx = rand_idx[:, :r, :]
        b_idx = rand_idx[:, r:, :]

        def split(x):
            C = x.shape[-1]
            a = x.gather(dim=1, index=a_idx.expand(B, r, C))
            b = x.gather(dim=1, index=b_idx.expand(B, N - r, C))
            return a, b

        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        # Each source token merges into its most similar destination token.
        _, dst_idx = scores.max(dim=-1)
        dst_idx = dst_idx[..., None]

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        """Merge the r random tokens into set B; output is set B only."""
        src, dst = split(x)
        C = src.shape[-1]
        dst = dst.scatter_reduce(-2, dst_idx.expand(B, r, C), src, reduce=mode)

        return dst

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        """Restore all N positions, copying each destination to its sources."""
        C = x.shape[-1]
        dst = x
        src = dst.gather(dim=-2, index=dst_idx.expand(B, r, C))

        out = torch.zeros(B, N, C, device=x.device, dtype=x.dtype)

        out.scatter_(dim=-2, index=a_idx.expand(B, r, C), src=src)
        out.scatter_(dim=-2, index=b_idx.expand(B, N - r, C), src=dst)

        return out

    return merge, unmerge


def merge_wavg(
    merge: Callable, x: torch.Tensor, size: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Applies the merge function by taking a weighted average based on token size.
    Returns the merged tensor and the new token sizes.

    If ``size`` is None, every token starts with size 1 (shape [..., tokens, 1]).
    """
    if size is None:
        size = torch.ones_like(x[..., 0, None])

    # Sum size-weighted values and sizes separately, then divide: this is the
    # size-weighted mean of every merged group.
    x = merge(x * size, mode="sum")
    size = merge(size, mode="sum")

    x = x / size
    return x, size


def merge_source(
    merge: Callable, x: torch.Tensor, source: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """
    For source tracking. Source is an adjacency matrix between the initial tokens and final merged groups.
    x is used to find out how many tokens there are in case the source is None.
    """
    if source is None:
        n, t, _ = x.shape
        # Start from the identity: each token is its own group.
        source = torch.eye(t, device=x.device)[None, ...].expand(n, t, t)

    # amax of 0/1 rows yields the union of the merged groups' members.
    source = merge(source, mode="amax")
    return source