import torch
import torch.nn as nn
import numpy as np


def approx_gelu():
    """Return a new tanh-approximated GELU activation module.

    Zero-argument factory: every call constructs a fresh ``nn.GELU``.
    """
    return nn.GELU(approximate="tanh")


def get_layernorm(hidden_size: int, eps: float, affine: bool, use_kernel: bool) -> nn.Module:
    """Build a layer-normalization module, optionally apex's fused variant.

    Args:
        hidden_size: Normalized shape forwarded to the LayerNorm constructor
            (an int, or a shape sequence).
        eps: Epsilon forwarded to the LayerNorm constructor.
        affine: Forwarded as ``elementwise_affine`` (learnable weight/bias).
        use_kernel: If True, build ``apex.normalization.FusedLayerNorm``;
            otherwise build ``torch.nn.LayerNorm``.

    Returns:
        The constructed normalization module.

    Raises:
        RuntimeError: If ``use_kernel`` is True and an ImportError occurs while
            importing or constructing ``FusedLayerNorm``. The original
            ImportError is attached as ``__cause__``.
    """
    if use_kernel:
        # The constructor call stays inside the ``try`` on purpose, so that
        # ImportErrors raised during construction are converted as well.
        try:
            from apex.normalization import FusedLayerNorm

            return FusedLayerNorm(hidden_size, elementwise_affine=affine, eps=eps)
        except ImportError as err:
            raise RuntimeError("FusedLayerNorm not available. Please install apex.") from err
    return nn.LayerNorm(hidden_size, eps, elementwise_affine=affine)


def t2i_modulate(x, shift, scale):
    """Apply the affine modulation ``x * (1 + scale) + shift``.

    Works for any operands supporting broadcasting arithmetic (tensors,
    arrays, scalars). Making ``shift``/``scale`` broadcastable against ``x``
    is the caller's responsibility.
    """
    return x * (1 + scale) + shift


# ===============================================
# Sine/Cosine Positional Embedding Functions
# ===============================================
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py


def get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0, scale=1.0, base_size=None
):
    """Build a 2D sine/cosine positional-embedding table.

    Args:
        embed_dim: Output width D. Must be divisible by 4: each of the two
            coordinate axes gets D/2 columns, which are split evenly into sin
            and cos terms.
        grid_size: An int (square grid), or a ``(height, width)`` tuple/list.
        cls_token: If True *and* ``extra_tokens > 0``, prepend
            ``extra_tokens`` all-zero rows to the table.
        extra_tokens: Number of zero rows to prepend (see ``cls_token``).
        scale: Grid coordinates are divided by this value.
        base_size: If given, coordinates along each axis are additionally
            multiplied by ``base_size / axis_length``.

    Returns:
        Array of shape ``(H*W, D)``, or ``(extra_tokens + H*W, D)`` when zero
        rows are prepended. Grid rows are in row-major (h, w) order, i.e.
        row index ``h * W + w``.
    """
    # Accept a list as well as a tuple for (height, width).
    if not isinstance(grid_size, (tuple, list)):
        grid_size = (grid_size, grid_size)
    grid_h = np.arange(grid_size[0], dtype=np.float32) / scale
    grid_w = np.arange(grid_size[1], dtype=np.float32) / scale
    if base_size is not None:
        grid_h *= base_size / grid_size[0]
        grid_w *= base_size / grid_size[1]
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)
    # With meshgrid's default "xy" indexing each array is (H, W): grid[0]
    # holds the w coordinates and grid[1] the h coordinates. The consumer
    # flattens each plane, so only this memory order matters.
    grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Build a 2D sin/cos embedding from a stacked coordinate grid.

    Args:
        embed_dim: Output width D. Must be even, and D/2 must be even too.
        grid: Array indexable as ``grid[0]`` / ``grid[1]``. Each plane is
            flattened to the same number M of coordinates.

    Returns:
        Array of shape (M, D). The first D/2 columns encode ``grid[0]`` and
        the last D/2 columns encode ``grid[1]``.
    """
    assert embed_dim % 2 == 0

    # get_2d_sincos_pos_embed stacks meshgrid(grid_w, grid_h), so grid[0]
    # carries the w (column) coordinates there. The column order is kept
    # unchanged because swapping it would change the embedding values.
    emb_first = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (M, D/2)
    emb_second = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (M, D/2)
    return np.concatenate([emb_first, emb_second], axis=1)  # (M, D)


def get_1d_sincos_pos_embed(embed_dim, length, scale=1.0):
    """Return a (length, embed_dim) sin/cos embedding.

    The encoded positions are ``0, 1, ..., length - 1``, each divided by
    ``scale``.
    """
    pos = np.arange(0, length)[..., None] / scale
    return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Encode positions with sine/cosine features.

    Args:
        embed_dim: Output dimension D for each position. Must be even.
        pos: Positions to encode. Any array-like (list, ndarray, ...) that
            flattens to M values.

    Returns:
        Array of shape (M, D); float64 for ordinary int/float positions.
        Columns ``[0, D/2)`` are ``sin(pos * omega_i)`` and columns
        ``[D/2, D)`` are ``cos(pos * omega_i)``, with
        ``omega_i = 1 / 10000**(2i / D)``.
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    # np.asarray lets plain Python sequences through, as the docstring
    # promises. ndarray inputs are returned unchanged (no copy/cast).
    pos = np.asarray(pos).reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)
    return np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)