from math import prod from typing import Tuple import numpy as np import torch from timm.models.layers import to_2tuple def bchw_to_bhwc(x: torch.Tensor) -> torch.Tensor: """Permutes a tensor from the shape (B, C, H, W) to (B, H, W, C).""" return x.permute(0, 2, 3, 1) def bhwc_to_bchw(x: torch.Tensor) -> torch.Tensor: """Permutes a tensor from the shape (B, H, W, C) to (B, C, H, W).""" return x.permute(0, 3, 1, 2) def bchw_to_blc(x: torch.Tensor) -> torch.Tensor: """Rearrange a tensor from the shape (B, C, H, W) to (B, L, C).""" return x.flatten(2).transpose(1, 2) def blc_to_bchw(x: torch.Tensor, x_size: Tuple) -> torch.Tensor: """Rearrange a tensor from the shape (B, L, C) to (B, C, H, W).""" B, L, C = x.shape return x.transpose(1, 2).view(B, C, *x_size) def blc_to_bhwc(x: torch.Tensor, x_size: Tuple) -> torch.Tensor: """Rearrange a tensor from the shape (B, L, C) to (B, H, W, C).""" B, L, C = x.shape return x.view(B, *x_size, C) def window_partition(x, window_size: Tuple[int, int]): """ Args: x: (B, H, W, C) window_size (int): window size Returns: windows: (num_windows*B, window_size, window_size, C) """ B, H, W, C = x.shape x = x.view( B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C ) windows = ( x.permute(0, 1, 3, 2, 4, 5) .contiguous() .view(-1, window_size[0], window_size[1], C) ) return windows def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]): """ Args: windows: (num_windows * B, window_size[0], window_size[1], C) window_size (Tuple[int, int]): Window size img_size (Tuple[int, int]): Image size Returns: x: (B, H, W, C) """ H, W = img_size B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1])) x = windows.view( B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1 ) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x def _fill_window(input_resolution, window_size, shift_size=None): if shift_size is None: shift_size = [s // 2 for s in window_size] img_mask = torch.zeros((1, *input_resolution, 1)) # 1 H W 1 h_slices = ( slice(0, -window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None), ) w_slices = ( slice(0, -window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None), ) cnt = 0 for h in h_slices: for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1 mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, prod(window_size)) return mask_windows ##################################### # Different versions of the functions # 1) Swin Transformer, SwinIR, Square window attention in GRL; # 2) Early development of the decomposition-based efficient attention mechanism (efficient_win_attn.py); # 3) GRL. Window-anchor attention mechanism. # 1) & 3) are still useful ##################################### def calculate_mask(input_resolution, window_size, shift_size): """ Use case: 1) """ # calculate attention mask for SW-MSA if isinstance(shift_size, int): shift_size = to_2tuple(shift_size) mask_windows = _fill_window(input_resolution, window_size, shift_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( attn_mask == 0, float(0.0) ) # nW, window_size**2, window_size**2 return attn_mask def calculate_mask_all( input_resolution, window_size, shift_size, anchor_window_down_factor=1, window_to_anchor=True, ): """ Use case: 3) """ # calculate attention mask for SW-MSA anchor_resolution = [s // anchor_window_down_factor for s in input_resolution] aws = [s // anchor_window_down_factor for s in window_size] anchor_shift = [s // anchor_window_down_factor for s in shift_size] # mask of window1: nW, Wh**Ww mask_windows = _fill_window(input_resolution, window_size, shift_size) # mask of window2: nW, AWh*AWw mask_anchor = _fill_window(anchor_resolution, aws, anchor_shift) if window_to_anchor: attn_mask = mask_windows.unsqueeze(2) - mask_anchor.unsqueeze(1) else: attn_mask = mask_anchor.unsqueeze(2) - mask_windows.unsqueeze(1) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( attn_mask == 0, float(0.0) ) # nW, Wh**Ww, AWh*AWw return attn_mask def calculate_win_mask( input_resolution1, input_resolution2, window_size1, window_size2 ): """ Use case: 2) """ # calculate attention mask for SW-MSA # mask of window1: nW, Wh**Ww mask_windows1 = _fill_window(input_resolution1, window_size1) # mask of window2: nW, AWh*AWw mask_windows2 = _fill_window(input_resolution2, window_size2) attn_mask = mask_windows1.unsqueeze(2) - mask_windows2.unsqueeze(1) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( attn_mask == 0, float(0.0) ) # nW, Wh**Ww, AWh*AWw return attn_mask def _get_meshgrid_coords(start_coords, end_coords): coord_h = torch.arange(start_coords[0], end_coords[0]) coord_w = torch.arange(start_coords[1], end_coords[1]) coords = torch.stack(torch.meshgrid([coord_h, coord_w], indexing="ij")) # 2, Wh, Ww coords = torch.flatten(coords, 1) # 2, Wh*Ww return coords def get_relative_coords_table( window_size, pretrained_window_size=[0, 0], anchor_window_down_factor=1 ): """ Use case: 1) """ # get relative_coords_table ws = window_size aws = [w // anchor_window_down_factor for w in window_size] pws = pretrained_window_size paws = [w // anchor_window_down_factor for w in pretrained_window_size] ts = [(w1 + w2) // 2 for w1, w2 in zip(ws, aws)] pts = [(w1 + w2) // 2 for w1, w2 in zip(pws, paws)] # TODO: pretrained window size and pretrained anchor window size is only used here. # TODO: Investigate whether it is really important to use this setting when finetuning large window size # TODO: based on pretrained weights with small window size. coord_h = torch.arange(-(ts[0] - 1), ts[0], dtype=torch.float32) coord_w = torch.arange(-(ts[1] - 1), ts[1], dtype=torch.float32) table = torch.stack(torch.meshgrid([coord_h, coord_w], indexing="ij")).permute( 1, 2, 0 ) table = table.contiguous().unsqueeze(0) # 1, Wh+AWh-1, Ww+AWw-1, 2 if pts[0] > 0: table[:, :, :, 0] /= pts[0] - 1 table[:, :, :, 1] /= pts[1] - 1 else: table[:, :, :, 0] /= ts[0] - 1 table[:, :, :, 1] /= ts[1] - 1 table *= 8 # normalize to -8, 8 table = torch.sign(table) * torch.log2(torch.abs(table) + 1.0) / np.log2(8) return table def get_relative_coords_table_all( window_size, pretrained_window_size=[0, 0], anchor_window_down_factor=1 ): """ Use case: 3) Support all window shapes. Args: window_size: pretrained_window_size: anchor_window_down_factor: Returns: """ # get relative_coords_table ws = window_size aws = [w // anchor_window_down_factor for w in window_size] pws = pretrained_window_size paws = [w // anchor_window_down_factor for w in pretrained_window_size] # positive table size: (Ww - 1) - (Ww - AWw) // 2 ts_p = [w1 - 1 - (w1 - w2) // 2 for w1, w2 in zip(ws, aws)] # negative table size: -(AWw - 1) - (Ww - AWw) // 2 ts_n = [-(w2 - 1) - (w1 - w2) // 2 for w1, w2 in zip(ws, aws)] pts = [w1 - 1 - (w1 - w2) // 2 for w1, w2 in zip(pws, paws)] # TODO: pretrained window size and pretrained anchor window size is only used here. # TODO: Investigate whether it is really important to use this setting when finetuning large window size # TODO: based on pretrained weights with small window size. coord_h = torch.arange(ts_n[0], ts_p[0] + 1, dtype=torch.float32) coord_w = torch.arange(ts_n[1], ts_p[1] + 1, dtype=torch.float32) table = torch.stack(torch.meshgrid([coord_h, coord_w], indexing="ij")).permute( 1, 2, 0 ) table = table.contiguous().unsqueeze(0) # 1, Wh+AWh-1, Ww+AWw-1, 2 if pts[0] > 0: table[:, :, :, 0] /= pts[0] table[:, :, :, 1] /= pts[1] else: table[:, :, :, 0] /= ts_p[0] table[:, :, :, 1] /= ts_p[1] table *= 8 # normalize to -8, 8 table = torch.sign(table) * torch.log2(torch.abs(table) + 1.0) / np.log2(8) # 1, Wh+AWh-1, Ww+AWw-1, 2 return table def coords_diff(coords1, coords2, max_diff): # The coordinates starts from (-start_coord[0], -start_coord[1]) coords = coords1[:, :, None] - coords2[:, None, :] # 2, Wh*Ww, AWh*AWw coords = coords.permute(1, 2, 0).contiguous() # Wh*Ww, AWh*AWw, 2 coords[:, :, 0] += max_diff[0] - 1 # shift to start from 0 coords[:, :, 1] += max_diff[1] - 1 coords[:, :, 0] *= 2 * max_diff[1] - 1 idx = coords.sum(-1) # Wh*Ww, AWh*AWw return idx def get_relative_position_index( window_size, anchor_window_down_factor=1, window_to_anchor=True ): """ Use case: 1) """ # get pair-wise relative position index for each token inside the window ws = window_size aws = [w // anchor_window_down_factor for w in window_size] coords_anchor_end = [(w1 + w2) // 2 for w1, w2 in zip(ws, aws)] coords_anchor_start = [(w1 - w2) // 2 for w1, w2 in zip(ws, aws)] coords = _get_meshgrid_coords((0, 0), window_size) # 2, Wh*Ww coords_anchor = _get_meshgrid_coords(coords_anchor_start, coords_anchor_end) # 2, AWh*AWw if window_to_anchor: idx = coords_diff(coords, coords_anchor, max_diff=coords_anchor_end) else: idx = coords_diff(coords_anchor, coords, max_diff=coords_anchor_end) return idx # Wh*Ww, AWh*AWw or AWh*AWw, Wh*Ww def coords_diff_odd(coords1, coords2, start_coord, max_diff): # The coordinates starts from (-start_coord[0], -start_coord[1]) coords = coords1[:, :, None] - coords2[:, None, :] # 2, Wh*Ww, AWh*AWw coords = coords.permute(1, 2, 0).contiguous() # Wh*Ww, AWh*AWw, 2 coords[:, :, 0] += start_coord[0] # shift to start from 0 coords[:, :, 1] += start_coord[1] coords[:, :, 0] *= max_diff idx = coords.sum(-1) # Wh*Ww, AWh*AWw return idx def get_relative_position_index_all( window_size, anchor_window_down_factor=1, window_to_anchor=True ): """ Use case: 3) Support all window shapes: square window - square window rectangular window - rectangular window window - anchor anchor - window [8, 8] - [8, 8] [4, 86] - [2, 43] """ # get pair-wise relative position index for each token inside the window ws = window_size aws = [w // anchor_window_down_factor for w in window_size] coords_anchor_start = [(w1 - w2) // 2 for w1, w2 in zip(ws, aws)] coords_anchor_end = [s + w2 for s, w2 in zip(coords_anchor_start, aws)] coords = _get_meshgrid_coords((0, 0), window_size) # 2, Wh*Ww coords_anchor = _get_meshgrid_coords(coords_anchor_start, coords_anchor_end) # 2, AWh*AWw max_horizontal_diff = aws[1] + ws[1] - 1 if window_to_anchor: offset = [w2 + s - 1 for s, w2 in zip(coords_anchor_start, aws)] idx = coords_diff_odd(coords, coords_anchor, offset, max_horizontal_diff) else: offset = [w1 - s - 1 for s, w1 in zip(coords_anchor_start, ws)] idx = coords_diff_odd(coords_anchor, coords, offset, max_horizontal_diff) return idx # Wh*Ww, AWh*AWw or AWh*AWw, Wh*Ww def get_relative_position_index_simple( window_size, anchor_window_down_factor=1, window_to_anchor=True ): """ Use case: 3) This is a simplified version of get_relative_position_index_all The start coordinate of anchor window is also (0, 0) get pair-wise relative position index for each token inside the window """ ws = window_size aws = [w // anchor_window_down_factor for w in window_size] coords = _get_meshgrid_coords((0, 0), window_size) # 2, Wh*Ww coords_anchor = _get_meshgrid_coords((0, 0), aws) # 2, AWh*AWw max_horizontal_diff = aws[1] + ws[1] - 1 if window_to_anchor: offset = [w2 - 1 for w2 in aws] idx = coords_diff_odd(coords, coords_anchor, offset, max_horizontal_diff) else: offset = [w1 - 1 for w1 in ws] idx = coords_diff_odd(coords_anchor, coords, offset, max_horizontal_diff) return idx # Wh*Ww, AWh*AWw or AWh*AWw, Wh*Ww # def get_relative_position_index(window_size): # # This is a very early version # # get pair-wise relative position index for each token inside the window # coords = _get_meshgrid_coords(start_coords=(0, 0), end_coords=window_size) # coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww # coords = coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 # coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 # coords[:, :, 1] += window_size[1] - 1 # coords[:, :, 0] *= 2 * window_size[1] - 1 # idx = coords.sum(-1) # Wh*Ww, Wh*Ww # return idx def get_relative_win_position_index(window_size, anchor_window_size): """ Use case: 2) """ # get pair-wise relative position index for each token inside the window ws = window_size aws = anchor_window_size coords_anchor_end = [(w1 + w2) // 2 for w1, w2 in zip(ws, aws)] coords_anchor_start = [(w1 - w2) // 2 for w1, w2 in zip(ws, aws)] coords = _get_meshgrid_coords((0, 0), window_size) # 2, Wh*Ww coords_anchor = _get_meshgrid_coords(coords_anchor_start, coords_anchor_end) # 2, AWh*AWw coords = coords[:, :, None] - coords_anchor[:, None, :] # 2, Wh*Ww, AWh*AWw coords = coords.permute(1, 2, 0).contiguous() # Wh*Ww, AWh*AWw, 2 coords[:, :, 0] += coords_anchor_end[0] - 1 # shift to start from 0 coords[:, :, 1] += coords_anchor_end[1] - 1 coords[:, :, 0] *= 2 * coords_anchor_end[1] - 1 idx = coords.sum(-1) # Wh*Ww, AWh*AWw return idx # def get_relative_coords_table(window_size, pretrained_window_size): # # This is a very early version # # get relative_coords_table # ws = window_size # pws = pretrained_window_size # coord_h = torch.arange(-(ws[0] - 1), ws[0], dtype=torch.float32) # coord_w = torch.arange(-(ws[1] - 1), ws[1], dtype=torch.float32) # table = torch.stack(torch.meshgrid([coord_h, coord_w], indexing='ij')).permute(1, 2, 0) # table = table.contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 # if pws[0] > 0: # table[:, :, :, 0] /= pws[0] - 1 # table[:, :, :, 1] /= pws[1] - 1 # else: # table[:, :, :, 0] /= ws[0] - 1 # table[:, :, :, 1] /= ws[1] - 1 # table *= 8 # normalize to -8, 8 # table = torch.sign(table) * torch.log2(torch.abs(table) + 1.0) / np.log2(8) # return table def get_relative_win_coords_table( window_size, anchor_window_size, pretrained_window_size=[0, 0], pretrained_anchor_window_size=[0, 0], ): """ Use case: 2) """ # get relative_coords_table ws = window_size aws = anchor_window_size pws = pretrained_window_size paws = pretrained_anchor_window_size # TODO: pretrained window size and pretrained anchor window size is only used here. # TODO: Investigate whether it is really important to use this setting when finetuning large window size # TODO: based on pretrained weights with small window size. table_size = [(wsi + awsi) // 2 for wsi, awsi in zip(ws, aws)] table_size_pretrained = [(pwsi + pawsi) // 2 for pwsi, pawsi in zip(pws, paws)] coord_h = torch.arange(-(table_size[0] - 1), table_size[0], dtype=torch.float32) coord_w = torch.arange(-(table_size[1] - 1), table_size[1], dtype=torch.float32) table = torch.stack(torch.meshgrid([coord_h, coord_w], indexing="ij")).permute( 1, 2, 0 ) table = table.contiguous().unsqueeze(0) # 1, Wh+AWh-1, Ww+AWw-1, 2 if table_size_pretrained[0] > 0: table[:, :, :, 0] /= table_size_pretrained[0] - 1 table[:, :, :, 1] /= table_size_pretrained[1] - 1 else: table[:, :, :, 0] /= table_size[0] - 1 table[:, :, :, 1] /= table_size[1] - 1 table *= 8 # normalize to -8, 8 table = torch.sign(table) * torch.log2(torch.abs(table) + 1.0) / np.log2(8) return table if __name__ == "__main__": table = get_relative_coords_table_all((4, 86), anchor_window_down_factor=2) table = table.view(-1, 2) index1 = get_relative_position_index_all((4, 86), 2, False) index2 = get_relative_position_index_simple((4, 86), 2, False) print(index2) index3 = get_relative_position_index_all((4, 86), 2) index4 = get_relative_position_index_simple((4, 86), 2) print(index4) print( table.shape, index2.shape, index2.max(), index2.min(), index4.shape, index4.max(), index4.min(), torch.allclose(index1, index2), torch.allclose(index3, index4), ) table = get_relative_coords_table_all((4, 86), anchor_window_down_factor=1) table = table.view(-1, 2) index1 = get_relative_position_index_all((4, 86), 1, False) index2 = get_relative_position_index_simple((4, 86), 1, False) # print(index1) index3 = get_relative_position_index_all((4, 86), 1) index4 = get_relative_position_index_simple((4, 86), 1) # print(index2) print( table.shape, index2.shape, index2.max(), index2.min(), index4.shape, index4.max(), index4.min(), torch.allclose(index1, index2), torch.allclose(index3, index4), ) table = get_relative_coords_table_all((8, 8), anchor_window_down_factor=2) table = table.view(-1, 2) index1 = get_relative_position_index_all((8, 8), 2, False) index2 = get_relative_position_index_simple((8, 8), 2, False) # print(index1) index3 = get_relative_position_index_all((8, 8), 2) index4 = get_relative_position_index_simple((8, 8), 2) # print(index2) print( table.shape, index2.shape, index2.max(), index2.min(), index4.shape, index4.max(), index4.min(), torch.allclose(index1, index2), torch.allclose(index3, index4), ) table = get_relative_coords_table_all((8, 8), anchor_window_down_factor=1) table = table.view(-1, 2) index1 = get_relative_position_index_all((8, 8), 1, False) index2 = get_relative_position_index_simple((8, 8), 1, False) # print(index1) index3 = get_relative_position_index_all((8, 8), 1) index4 = get_relative_position_index_simple((8, 8), 1) # print(index2) print( table.shape, index2.shape, index2.max(), index2.min(), index4.shape, index4.max(), index4.min(), torch.allclose(index1, index2), torch.allclose(index3, index4), )