#!/usr/bin/env python3
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import gzip
import html
import io
import math
from functools import lru_cache
from typing import Callable, List, Optional

import ftfy
import numpy as np
import regex as re
import torch
import torch.nn as nn
from iopath.common.file_io import g_pathmgr
from timm.models.layers import trunc_normal_

from .helpers import cast_if_src_dtype, VerboseNNModule


def get_sinusoid_encoding_table(n_position, d_hid):
    """Build a fixed sinusoidal position encoding table.

    Args:
        n_position: number of positions (rows of the table).
        d_hid: size of each encoding vector (columns of the table).

    Returns:
        A ``torch.FloatTensor`` of shape ``(1, n_position, d_hid)`` where even
        columns hold sines and odd columns hold cosines of the position angles.
    """

    # TODO: make it with torch instead of numpy
    def get_position_angle_vec(position):
        # Columns 2i and 2i+1 share the same frequency (hid_j // 2).
        return [
            position / np.power(10000, 2 * (hid_j // 2) / d_hid)
            for hid_j in range(d_hid)
        ]

    sinusoid_table = np.array(
        [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
    )
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    return torch.FloatTensor(sinusoid_table).unsqueeze(0)


def interpolate_pos_encoding_2d(target_spatial_size, pos_embed):
    """Bicubically resize a square grid of position embeddings.

    Args:
        target_spatial_size: desired number of spatial tokens.
        pos_embed: tensor indexed as ``(1, N, dim)``; ``N`` is treated as a
            square grid of side ``int(sqrt(N))``.

    Returns:
        ``pos_embed`` unchanged when ``N == target_spatial_size``, otherwise the
        interpolated embeddings reshaped back to ``(1, -1, dim)``.
    """
    N = pos_embed.shape[1]
    if N == target_spatial_size:
        return pos_embed
    dim = pos_embed.shape[-1]
    side = int(math.sqrt(N))
    # nn.functional.interpolate doesn't work with bfloat16 so we cast to float32
    pos_embed, updated = cast_if_src_dtype(pos_embed, torch.bfloat16, torch.float32)
    pos_embed = nn.functional.interpolate(
        pos_embed.reshape(1, side, side, dim).permute(0, 3, 1, 2),
        scale_factor=math.sqrt(target_spatial_size / N),
        mode="bicubic",
    )
    if updated:
        # Restore the caller's original bfloat16 dtype.
        pos_embed, _ = cast_if_src_dtype(pos_embed, torch.float32, torch.bfloat16)
    pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
    return pos_embed


def interpolate_pos_encoding(
    npatch_per_img,
    pos_embed,
    patches_layout,
    input_shape=None,
    first_patch_idx=1,
):
    """Adapt position embeddings to a different number of patches.

    Args:
        npatch_per_img: number of (non-CLS) patch tokens in the current input.
        pos_embed: stored position embeddings, CLS slots (if any) first.
        patches_layout: patch grid layout; its last two entries must be equal.
        input_shape: shape of the raw input; ``None`` selects plain 2D
            interpolation.
        first_patch_idx: index of the first patch token, i.e. the number of
            leading CLS slots (0 or 1).

    Returns:
        ``pos_embed`` unchanged when the patch count already matches, otherwise
        the CLS slots concatenated with the interpolated patch embeddings.

    Raises:
        ValueError: if ``patches_layout[0]`` is neither 1 nor greater than 1
            while ``input_shape`` is given.
    """
    assert first_patch_idx == 0 or first_patch_idx == 1, "there is 1 CLS token or none"
    N = pos_embed.shape[1] - first_patch_idx  # since it's 1 if cls_token exists
    if npatch_per_img == N:
        return pos_embed

    assert (
        patches_layout[-1] == patches_layout[-2]
    ), "Interpolation of pos embed not supported for non-square layouts"

    class_emb = pos_embed[:, :first_patch_idx]
    pos_embed = pos_embed[:, first_patch_idx:]

    if input_shape is None or patches_layout[0] == 1:
        # simple 2D pos embedding, no temporal component
        pos_embed = interpolate_pos_encoding_2d(npatch_per_img, pos_embed)
    elif patches_layout[0] > 1:
        # pos embed has a temporal component
        assert len(input_shape) == 4, "temporal interpolation not supported"
        # we only support 2D interpolation in this case
        num_frames = patches_layout[0]
        num_spatial_tokens = patches_layout[1] * patches_layout[2]
        pos_embed = pos_embed.view(1, num_frames, num_spatial_tokens, -1)
        # interpolate embedding for zeroth frame
        pos_embed = interpolate_pos_encoding_2d(
            npatch_per_img, pos_embed[0, 0, ...].unsqueeze(0)
        )
    else:
        raise ValueError("This type of interpolation isn't implemented")

    return torch.cat((class_emb, pos_embed), dim=1)


def _get_pos_embedding(
    npatch_per_img,
    pos_embed,
    patches_layout,
    input_shape,
    first_patch_idx=1,
):
    """Thin wrapper that forwards all arguments to ``interpolate_pos_encoding``."""
    pos_embed = interpolate_pos_encoding(
        npatch_per_img,
        pos_embed,
        patches_layout,
        input_shape=input_shape,
        first_patch_idx=first_patch_idx,
    )
    return pos_embed


class PatchEmbedGeneric(nn.Module):
    """
    PatchEmbed from Hydra

    Wraps a projection stem (one module, or several chained sequentially) and
    an optional normalization layer applied to the flattened tokens.
    """

    def __init__(self, proj_stem, norm_layer: Optional[nn.Module] = None):
        super().__init__()

        if len(proj_stem) > 1:
            self.proj = nn.Sequential(*proj_stem)
        else:
            # Special case to be able to load pre-trained models that were
            # trained with a standard stem
            self.proj = proj_stem[0]
        self.norm_layer = norm_layer

    def get_patch_layout(self, img_size):
        """Probe the stem with a zero input to discover the patch layout.

        Args:
            img_size: per-sample input shape (without the batch dimension);
                any sequence (list or tuple) is accepted.

        Returns:
            ``(patches_layout, num_patches, embed_dim)`` where
            ``patches_layout`` is the stem output shape after the channel
            dimension and ``num_patches`` is its product.
        """
        with torch.no_grad():
            # Unpack instead of ``[1] + img_size`` so tuple sizes (the default
            # used by the preprocessors) do not raise a TypeError.
            dummy_img = torch.zeros([1, *img_size])
            dummy_out = self.proj(dummy_img)
        embed_dim = dummy_out.shape[1]
        patches_layout = tuple(dummy_out.shape[2:])
        num_patches = np.prod(patches_layout)
        return patches_layout, num_patches, embed_dim

    def forward(self, x):
        x = self.proj(x)
        # B C (T) H W -> B (T)HW C
        x = x.flatten(2).transpose(1, 2)
        if self.norm_layer is not None:
            x = self.norm_layer(x)
        return x


class SpatioTemporalPosEmbeddingHelper(VerboseNNModule):
    """Holds a learnable or sinusoidal position embedding for patch tokens."""

    def __init__(
        self,
        patches_layout: List,
        num_patches: int,
        num_cls_tokens: int,
        embed_dim: int,
        learnable: bool,
    ) -> None:
        super().__init__()
        self.num_cls_tokens = num_cls_tokens
        self.patches_layout = patches_layout
        self.num_patches = num_patches
        self.num_tokens = num_cls_tokens + num_patches
        self.learnable = learnable
        if self.learnable:
            self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
            trunc_normal_(self.pos_embed, std=0.02)
        else:
            # Fixed table: stored as a buffer so it follows the module's device
            # without being trained.
            self.register_buffer(
                "pos_embed", get_sinusoid_encoding_table(self.num_tokens, embed_dim)
            )

    def get_pos_embedding(self, vision_input, all_vision_tokens):
        """Return position embeddings matching the token count of the input.

        Args:
            vision_input: raw input tensor; only its shape is used.
            all_vision_tokens: token tensor (CLS tokens included) whose
                ``size(1)`` determines how many patch embeddings are needed.
        """
        input_shape = vision_input.shape
        pos_embed = _get_pos_embedding(
            all_vision_tokens.size(1) - self.num_cls_tokens,
            pos_embed=self.pos_embed,
            patches_layout=self.patches_layout,
            input_shape=input_shape,
            first_patch_idx=self.num_cls_tokens,
        )
        return pos_embed


class RGBDTPreprocessor(VerboseNNModule):
    """Tokenizes vision and/or depth inputs and adds CLS/position/type embeddings."""

    def __init__(
        self,
        rgbt_stem: PatchEmbedGeneric,
        depth_stem: PatchEmbedGeneric,
        img_size: List = (3, 224, 224),
        num_cls_tokens: int = 1,
        pos_embed_fn: Optional[Callable] = None,
        use_type_embed: bool = False,
        init_param_style: str = "openclip",
    ) -> None:
        super().__init__()
        # Either stem can be None; use whichever exists to probe the layout.
        stem = rgbt_stem if rgbt_stem is not None else depth_stem
        (
            self.patches_layout,
            self.num_patches,
            self.embed_dim,
        ) = stem.get_patch_layout(img_size)
        self.rgbt_stem = rgbt_stem
        self.depth_stem = depth_stem
        self.use_pos_embed = pos_embed_fn is not None
        self.use_type_embed = use_type_embed
        self.num_cls_tokens = num_cls_tokens

        if self.use_pos_embed:
            self.pos_embedding_helper = pos_embed_fn(
                patches_layout=self.patches_layout,
                num_cls_tokens=num_cls_tokens,
                num_patches=self.num_patches,
                embed_dim=self.embed_dim,
            )
        if self.num_cls_tokens > 0:
            self.cls_token = nn.Parameter(
                torch.zeros(1, self.num_cls_tokens, self.embed_dim)
            )
        if self.use_type_embed:
            self.type_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim))

        self.init_parameters(init_param_style)

    @torch.no_grad()
    def init_parameters(self, init_param_style):
        """Initialize CLS/position/type parameters.

        Args:
            init_param_style: ``"openclip"`` or ``"vit"``.

        Raises:
            ValueError: for any other style name.
        """
        if init_param_style == "openclip":
            # OpenCLIP style initialization
            scale = self.embed_dim**-0.5
            if self.use_pos_embed:
                nn.init.normal_(self.pos_embedding_helper.pos_embed)
                self.pos_embedding_helper.pos_embed *= scale

            if self.num_cls_tokens > 0:
                nn.init.normal_(self.cls_token)
                self.cls_token *= scale
        elif init_param_style == "vit":
            # cls_token only exists when num_cls_tokens > 0.
            if self.num_cls_tokens > 0:
                self.cls_token.data.fill_(0)
        else:
            raise ValueError(f"Unknown init {init_param_style}")

        if self.use_type_embed:
            nn.init.normal_(self.type_embed)

    def tokenize_input_and_cls_pos(self, input, stem, mask):
        """Run ``stem`` on ``input`` and add CLS, position and type embeddings.

        ``mask`` is accepted for interface symmetry but is not used here.
        """
        # tokens is of shape B x L x D
        tokens = stem(input)
        assert tokens.ndim == 3
        assert tokens.shape[2] == self.embed_dim
        B = tokens.shape[0]
        if self.num_cls_tokens > 0:
            class_tokens = self.cls_token.expand(
                B, -1, -1
            )  # stole class_tokens impl from Phil Wang, thanks
            tokens = torch.cat((class_tokens, tokens), dim=1)
        if self.use_pos_embed:
            pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens)
            tokens = tokens + pos_embed
        if self.use_type_embed:
            tokens = tokens + self.type_embed.expand(B, -1, -1)
        return tokens

    def forward(self, vision=None, depth=None, patch_mask=None):
        """Tokenize the given modalities and sum them when both are present.

        Returns:
            ``{"trunk": {"tokens": ...}, "head": {}}``.

        Raises:
            NotImplementedError: if ``patch_mask`` is provided.
        """
        if patch_mask is not None:
            raise NotImplementedError()

        if vision is not None:
            vision_tokens = self.tokenize_input_and_cls_pos(
                vision, self.rgbt_stem, patch_mask
            )

        if depth is not None:
            depth_tokens = self.tokenize_input_and_cls_pos(
                depth, self.depth_stem, patch_mask
            )

        # aggregate tokens
        if vision is not None and depth is not None:
            final_tokens = vision_tokens + depth_tokens
        else:
            final_tokens = vision_tokens if vision is not None else depth_tokens
        return_dict = {
            "trunk": {
                "tokens": final_tokens,
            },
            "head": {},
        }
        return return_dict


class AudioPreprocessor(RGBDTPreprocessor):
    """``RGBDTPreprocessor`` with the audio stem in the vision slot and no depth stem."""

    def __init__(self, audio_stem: PatchEmbedGeneric, **kwargs) -> None:
        super().__init__(rgbt_stem=audio_stem, depth_stem=None, **kwargs)

    def forward(self, audio=None):
        return super().forward(vision=audio)


class ThermalPreprocessor(RGBDTPreprocessor):
    """``RGBDTPreprocessor`` with the thermal stem in the vision slot and no depth stem."""

    def __init__(self, thermal_stem: PatchEmbedGeneric, **kwargs) -> None:
        super().__init__(rgbt_stem=thermal_stem, depth_stem=None, **kwargs)

    def forward(self, thermal=None):
        return super().forward(vision=thermal)


def build_causal_attention_mask(context_length):
    """Return a ``(context_length, context_length)`` additive causal mask.

    Entries strictly above the diagonal are ``-inf``; the diagonal and
    everything below it are zero.
    """
    # lazily create causal attention mask, with full attention between the vision tokens
    # pytorch uses additive attention mask; fill with -inf
    mask = torch.empty(context_length, context_length, requires_grad=False)
    mask.fill_(float("-inf"))
    mask.triu_(1)  # zero out the lower diagonal
    return mask


class TextPreprocessor(VerboseNNModule):
    """Embeds token ids and adds position embeddings (plus optional CLS tokens)."""

    def __init__(
        self,
        vocab_size: int,
        context_length: int,
        embed_dim: int,
        causal_masking: bool,
        supply_seq_len_to_head: bool = True,
        num_cls_tokens: int = 0,
        init_param_style: str = "openclip",
    ) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Parameter(
            torch.empty(1, self.context_length + num_cls_tokens, embed_dim)
        )
        self.causal_masking = causal_masking
        if self.causal_masking:
            mask = build_causal_attention_mask(self.context_length)
            # register the mask as a buffer so it can be moved to the right device
            self.register_buffer("mask", mask)

        self.supply_seq_len_to_head = supply_seq_len_to_head
        self.num_cls_tokens = num_cls_tokens
        self.embed_dim = embed_dim
        if num_cls_tokens > 0:
            assert self.causal_masking is False, "Masking + CLS token isn't implemented"
            self.cls_token = nn.Parameter(
                torch.zeros(1, self.num_cls_tokens, embed_dim)
            )

        self.init_parameters(init_param_style)

    @torch.no_grad()
    def init_parameters(self, init_param_style="openclip"):
        """Initialize embeddings and the CLS token.

        Args:
            init_param_style: ``"openclip"`` or ``"vit"``.

        Raises:
            ValueError: for any other style name.
        """
        # OpenCLIP style initialization
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.01)

        if init_param_style == "openclip":
            # OpenCLIP style initialization
            scale = self.embed_dim**-0.5
            if self.num_cls_tokens > 0:
                nn.init.normal_(self.cls_token)
                self.cls_token *= scale
        elif init_param_style == "vit":
            # cls_token only exists when num_cls_tokens > 0.
            if self.num_cls_tokens > 0:
                self.cls_token.data.fill_(0)
        else:
            raise ValueError(f"Unknown init {init_param_style}")

    def forward(self, text):
        """Embed ``text`` token ids.

        Returns:
            A dict with ``"trunk"`` holding ``"tokens"`` (and ``"attn_mask"``
            when causal masking is on) and ``"head"`` holding ``"seq_len"``
            when ``supply_seq_len_to_head`` is set.
        """
        # text tokens are of shape B x L x D
        text_tokens = self.token_embedding(text)
        # concat CLS tokens if any
        if self.num_cls_tokens > 0:
            B = text_tokens.shape[0]
            class_tokens = self.cls_token.expand(
                B, -1, -1
            )  # stole class_tokens impl from Phil Wang, thanks
            text_tokens = torch.cat((class_tokens, text_tokens), dim=1)
        text_tokens = text_tokens + self.pos_embed
        return_dict = {
            "trunk": {
                "tokens": text_tokens,
            },
            "head": {},
        }
        # Compute sequence length after adding CLS tokens
        if self.supply_seq_len_to_head:
            # Index of the largest token id along the last dimension.
            text_lengths = text.argmax(dim=-1)
            return_dict["head"] = {
                "seq_len": text_lengths,
            }
        if self.causal_masking:
            return_dict["trunk"].update({"attn_mask": self.mask})
        return return_dict


class Im2Video(nn.Module):
    """Convert an image into a trivial video."""

    def __init__(self, time_dim=2):
        super().__init__()
        self.time_dim = time_dim

    def forward(self, x):
        if x.ndim == 4:
            # B, C, H, W -> B, C, T, H, W
            return x.unsqueeze(self.time_dim)
        elif x.ndim == 5:
            return x
        else:
            raise ValueError(f"Dimension incorrect {x.shape}")


class PadIm2Video(Im2Video):
    """Like ``Im2Video`` but extends a single frame to ``ntimes`` frames.

    ``pad_type`` selects how: ``"repeat"`` tiles the frame, ``"zero"`` appends
    zero frames after it.
    """

    def __init__(self, ntimes, pad_type, time_dim=2):
        super().__init__(time_dim=time_dim)
        assert ntimes > 0
        assert pad_type in ["zero", "repeat"]
        self.ntimes = ntimes
        self.pad_type = pad_type

    def forward(self, x):
        x = super().forward(x)
        if x.shape[self.time_dim] == 1:
            if self.pad_type == "repeat":
                new_shape = [1] * len(x.shape)
                new_shape[self.time_dim] = self.ntimes
                x = x.repeat(new_shape)
            elif self.pad_type == "zero":
                # F.pad lists (before, after) pairs starting from the LAST
                # dimension, so the time axis must be located from the end.
                # For the default time_dim=2 on 5-D input this is index 5,
                # exactly what the previous hard-coded formula produced.
                time_dim = self.time_dim % x.ndim
                padarg = [0, 0] * x.ndim
                padarg[2 * (x.ndim - 1 - time_dim) + 1] = (
                    self.ntimes - x.shape[time_dim]
                )
                x = nn.functional.pad(x, padarg)
        return x
# Modified from github.com/openai/CLIP
@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            # Bytes outside the printable ranges are remapped to code points
            # starting at 256.
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    An empty or single-symbol word yields an empty set.
    """
    return set(zip(word, word[1:]))


def basic_clean(text):
    """Fix text encoding with ftfy, unescape HTML entities twice, and strip."""
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    """Collapse every run of whitespace to a single space and strip the ends."""
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    """Byte-level BPE tokenizer, modified from github.com/openai/CLIP.

    Args:
        bpe_path: path to the gzip-compressed BPE merges file, opened through
            ``g_pathmgr``.
        context_length: default number of token ids per encoded text.
    """

    def __init__(self, bpe_path: str, context_length=77):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

        with g_pathmgr.open(bpe_path, "rb") as fh:
            bpe_bytes = io.BytesIO(fh.read())
        # Close the gzip reader deterministically once the merges are read.
        with gzip.open(bpe_bytes) as gz:
            merges = gz.read().decode("utf-8").split("\n")
        # Skip the first line and keep the merge rules that follow it.
        merges = merges[1 : 49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        # Every base symbol also exists in an end-of-word form.
        vocab = vocab + [v + "</w>" for v in vocab]
        for merge in merges:
            vocab.append("".join(merge))
        vocab.extend(["<|startoftext|>", "<|endoftext|>"])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {
            "<|startoftext|>": "<|startoftext|>",
            "<|endoftext|>": "<|endoftext|>",
        }
        self.pat = re.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )
        self.context_length = context_length

    def bpe(self, token):
        """Apply the learned merges to ``token``.

        Returns:
            The merged symbols joined by single spaces; the last symbol carries
            the ``</w>`` end-of-word marker. Results are memoized in
            ``self.cache``.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        pairs = get_pairs(word)

        if not pairs:
            return token + "</w>"

        while True:
            # Pick the adjacent pair with the lowest (earliest-learned) rank.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # ``first`` does not occur again: copy the tail unchanged.
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean, lowercase and BPE-encode ``text`` into a list of token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(
                self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
            )
        return bpe_tokens

    def decode(self, tokens):
        """Turn token ids back into text; end-of-word markers become spaces."""
        text = "".join([self.decoder[token] for token in tokens])
        text = (
            bytearray([self.byte_decoder[c] for c in text])
            .decode("utf-8", errors="replace")
            .replace("</w>", " ")
        )
        return text

    def __call__(self, texts, context_length=None):
        """Encode one string or a list of strings into a zero-padded id tensor.

        Args:
            texts: a string or a list of strings.
            context_length: row length; falls back to ``self.context_length``
                when falsy.

        Returns:
            A ``torch.long`` tensor of shape ``(len(texts), context_length)``,
            or a 1-D tensor when there is exactly one text. Sequences longer
            than ``context_length`` are truncated.
        """
        if not context_length:
            context_length = self.context_length

        if isinstance(texts, str):
            texts = [texts]

        sot_token = self.encoder["<|startoftext|>"]
        eot_token = self.encoder["<|endoftext|>"]
        all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            tokens = tokens[:context_length]
            result[i, : len(tokens)] = torch.tensor(tokens)

        if len(result) == 1:
            return result[0]
        return result


class IMUPreprocessor(VerboseNNModule):
    """Splits an IMU signal into fixed-size windows and tokenizes them."""

    def __init__(
        self,
        kernel_size: int,
        imu_stem: PatchEmbedGeneric,
        embed_dim: int,
        img_size: List = (6, 2000),
        num_cls_tokens: int = 1,
        pos_embed_fn: Optional[Callable] = None,
        init_param_style: str = "openclip",
    ) -> None:
        super().__init__()
        self.imu_stem = imu_stem
        self.embed_dim = embed_dim
        # pos_embed_fn is only used as an on/off switch; it is never called.
        self.use_pos_embed = pos_embed_fn is not None
        self.num_cls_tokens = num_cls_tokens
        self.kernel_size = kernel_size
        self.pos_embed = nn.Parameter(
            torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, embed_dim)
        )

        if self.num_cls_tokens > 0:
            self.cls_token = nn.Parameter(
                torch.zeros(1, self.num_cls_tokens, self.embed_dim)
            )

        self.init_parameters(init_param_style)

    @torch.no_grad()
    def init_parameters(self, init_param_style):
        """Initialize the position embedding and the CLS token.

        Args:
            init_param_style: ``"openclip"`` or ``"vit"``.

        Raises:
            ValueError: for any other style name.
        """
        nn.init.normal_(self.pos_embed, std=0.01)

        if init_param_style == "openclip":
            # OpenCLIP style initialization
            scale = self.embed_dim**-0.5

            if self.num_cls_tokens > 0:
                nn.init.normal_(self.cls_token)
                self.cls_token *= scale
        elif init_param_style == "vit":
            # cls_token only exists when num_cls_tokens > 0.
            if self.num_cls_tokens > 0:
                self.cls_token.data.fill_(0)
        else:
            raise ValueError(f"Unknown init {init_param_style}")

    def tokenize_input_and_cls_pos(self, input, stem):
        """Project and normalize ``input`` with ``stem``, then add CLS/position embeddings."""
        # tokens is of shape B x L x D
        tokens = stem.norm_layer(stem.proj(input))
        assert tokens.ndim == 3
        assert tokens.shape[2] == self.embed_dim
        B = tokens.shape[0]
        if self.num_cls_tokens > 0:
            class_tokens = self.cls_token.expand(
                B, -1, -1
            )  # stole class_tokens impl from Phil Wang, thanks
            tokens = torch.cat((class_tokens, tokens), dim=1)
        if self.use_pos_embed:
            tokens = tokens + self.pos_embed
        return tokens

    def forward(self, imu):
        """Patchify ``imu`` along its last dimension and tokenize the windows.

        Returns:
            ``{"trunk": {"tokens": ...}, "head": {}}``.
        """
        # Patchify
        imu = imu.unfold(
            -1,
            self.kernel_size,
            self.kernel_size,
        ).permute(0, 2, 1, 3)
        # Flatten each window into one feature vector per token.
        imu = imu.reshape(imu.size(0), imu.size(1), -1)

        imu_tokens = self.tokenize_input_and_cls_pos(
            imu,
            self.imu_stem,
        )

        return_dict = {
            "trunk": {
                "tokens": imu_tokens,
            },
            "head": {},
        }
        return return_dict