| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Model utils for PolyViT.""" |
|
|
| from typing import Any, Optional |
|
|
| from absl import logging |
| import flax |
| import jax.numpy as jnp |
| import ml_collections |
| import numpy as np |
| from scenic.common_lib import debug_utils |
| import scenic.projects.mbt.model_utils as mbt_utils |
| import scenic.projects.vivit.model_utils as vivit_utils |
| import scipy |
|
|
|
|
def initialize_from_polyvit_train_state(
    train_state: Any,
    restored_train_state: Any,
    tokenizer_to_init_from: Optional[str] = None,
    tokenizer_to_init: Optional[str] = None,
    resolution_to_init: Optional[Any] = None,
    initialize_heads: bool = False) -> Any:
  """Initializes PolyViT with other PolyViT body.

  The restored 'vit_encoder' is always copied over. Tokenizers are handled in
  one of two ways:

  * `tokenizer_to_init_from is None`: every restored tokenizer that also exists
    in the target model is copied unchanged.
  * otherwise: only the target tokenizer `tokenizer_to_init` is initialized,
    from the restored tokenizer `tokenizer_to_init_from`. The CLS token and
    embedding bias are copied. The embedding kernel and the positional
    embeddings are converted between the 3D ('tokenizer3d') layout and the
    other tokenizers' layout.

  Args:
    train_state: TrainState of the target PolyViT model.
    restored_train_state: TrainState of the PolyViT model to restore from.
    tokenizer_to_init_from: Name of the restored tokenizer to transfer from, or
      None to copy all tokenizers shared by both models.
    tokenizer_to_init: Name of the target tokenizer to initialize; only used
      when `tokenizer_to_init_from` is set.
    resolution_to_init: Optional 2-element target input size. The source
      positional grid is zoomed by `resolution_to_init[i] / source_size[i]`.
      The source sizes are hard-coded: 384x384 for 'tokenizer2d', 800x128 for
      'tokenizer_spec', and 224x224 otherwise.
    initialize_heads: If True, also copy every other restored top-level
      collection (e.g. heads) that exists in the target model.

  Returns:
    `train_state` with the updated, frozen parameters.
  """
  params = flax.core.unfreeze(train_state.params)
  restored_params = flax.core.unfreeze(restored_train_state.params)

  for m_key, m_params in restored_params.items():
    if m_key == 'vit_encoder':
      params[m_key] = m_params
    elif m_key == 'tokenizer':
      if tokenizer_to_init_from is None:
        # Copy every tokenizer that is present in both models.
        for tm_key, tm_params in m_params.items():
          if tm_key in params[m_key]:
            params[m_key][tm_key] = tm_params
      else:
        source = m_params[tokenizer_to_init_from]
        target = params['tokenizer'][tokenizer_to_init]

        # The CLS token and the embedding bias are layout-independent.
        target['cls'] = source['cls']
        target['embedding']['bias'] = source['embedding']['bias']

        source_kernel = source['embedding']['kernel']
        if tokenizer_to_init_from == 'tokenizer3d':
          # 3D -> other: collapse the leading (temporal) kernel axis by sum.
          target['embedding']['kernel'] = source_kernel.sum(axis=0)
        elif tokenizer_to_init == 'tokenizer3d':
          # Other -> 3D: central-frame initialization. The source kernel goes
          # into the middle slice of the leading axis; all other slices are
          # zero. The shape comes from the target kernel, so it is not tied to
          # one patch size / hidden width.
          target_kernel_shape = target['embedding']['kernel'].shape
          kernel3d = np.zeros(target_kernel_shape)
          kernel3d[target_kernel_shape[0] // 2] = source_kernel
          target['embedding']['kernel'] = kernel3d
        else:
          target['embedding']['kernel'] = source_kernel

        pos_embedding = source['posembed_input']['pos_embedding']
        # Position 0 is treated as the CLS position; it is re-attached below.
        new_pos_embedding = pos_embedding[:, 1:]
        if tokenizer_to_init_from == 'tokenizer3d':
          # Split the tokens into 8 groups along axis 1 and average them.
          # The group count of 8 is hard-coded (presumably temporal
          # positions of the source checkpoint).
          new_pos_embedding = new_pos_embedding.reshape(
              1, 8, -1, pos_embedding.shape[2]).mean(axis=1)

        if resolution_to_init is not None:
          # Source grid shapes and input sizes are hard-coded per tokenizer.
          if tokenizer_to_init_from == 'tokenizer2d':
            grid, source_size = (24, 24), (384, 384)
          elif tokenizer_to_init_from == 'tokenizer_spec':
            grid, source_size = (50, 8), (800, 128)
          else:
            grid, source_size = (14, 14), (224, 224)
          new_pos_embedding = new_pos_embedding.reshape(*grid, -1)
          zoom = (resolution_to_init[0] / source_size[0],
                  resolution_to_init[1] / source_size[1], 1)
          new_pos_embedding = scipy.ndimage.zoom(
              new_pos_embedding, zoom, order=1)
          new_pos_embedding = new_pos_embedding.reshape(
              1, new_pos_embedding.shape[0] * new_pos_embedding.shape[1], -1)
        if tokenizer_to_init == 'tokenizer3d':
          # Repeat the grid 8 times along the token axis for the 3D tokenizer.
          new_pos_embedding = np.tile(new_pos_embedding, (1, 8, 1))

        target['posembed_input']['pos_embedding'] = np.concatenate(
            [pos_embedding[:, :1], new_pos_embedding], axis=1)
    elif initialize_heads and m_key in params:
      params[m_key] = m_params

  return train_state.replace(params=flax.core.freeze(params))
|
|
|
|
def initialize_from_mbt_train_state(
    train_state: Any,
    restored_train_state: Any,
    tokenizer_to_init: str = 'tokenizer_spec',
    resolution_to_init: Optional[Any] = None,
    initialize_head: bool = False,
) -> Any:
  """Initializes PolyViT with AViT body.

  Parameters under the restored 'Transformer' are mapped onto PolyViT's shared
  'vit_encoder'. The exception is the spectrogram positional embedding
  ('posembed_input_spec'), which goes into `tokenizer_to_init` and is converted
  to that tokenizer's layout when needed. The restored 'cls' and
  'embedding_spec' parameters also initialize `tokenizer_to_init`.

  Args:
    train_state: TrainState of the target PolyViT model.
    restored_train_state: TrainState of the AViT/MBT model to restore from.
    tokenizer_to_init: Name of the target tokenizer to initialize.
    resolution_to_init: Optional 2-element target input size. The source
      positional grid (hard-coded as 50x8 tokens for an 800x128 input) is
      zoomed by `resolution_to_init[i] / source_size[i]`. Only used when
      `tokenizer_to_init != 'tokenizer_spec'`.
    initialize_head: If True, copy the restored 'output_projection' into the
      first top-level target collection that is neither 'tokenizer' nor
      'vit_encoder'.

  Returns:
    `train_state` with the updated, frozen parameters.
  """
  params = flax.core.unfreeze(train_state.params)
  restored_params = flax.core.unfreeze(restored_train_state.params)

  for m_key, m_params in restored_params.items():
    if m_key == 'Transformer':
      for tm_key, tm_params in m_params.items():
        if tm_key == 'posembed_input_spec':
          if tokenizer_to_init == 'tokenizer_spec':
            # Same layout: copy the positional embeddings as-is.
            params['tokenizer']['tokenizer_spec']['posembed_input'] = tm_params
          else:
            pos_embedding = tm_params['pos_embedding']
            # Position 0 is treated as the CLS position; re-attached below.
            new_pos_embedding = pos_embedding[:, 1:]
            if resolution_to_init is not None:
              new_pos_embedding = new_pos_embedding.reshape(50, 8, -1)
              zoom = (resolution_to_init[0] / 800, resolution_to_init[1] / 128,
                      1)
              new_pos_embedding = scipy.ndimage.zoom(
                  new_pos_embedding, zoom, order=1)
              new_pos_embedding = new_pos_embedding.reshape(
                  1, new_pos_embedding.shape[0] * new_pos_embedding.shape[1],
                  -1)
            if tokenizer_to_init == 'tokenizer3d':
              # Repeat the grid 8 times along the token axis for 3D input.
              new_pos_embedding = np.tile(new_pos_embedding, (1, 8, 1))
            new_pos_embedding = np.concatenate(
                [pos_embedding[:, :1], new_pos_embedding], axis=1)
            params['tokenizer'][tokenizer_to_init][
                'posembed_input']['pos_embedding'] = new_pos_embedding
        elif tm_key.startswith('encoderblock'):
          # Drop the last 5 characters of the key. Presumably this is the
          # '_spec' modality suffix, so 'encoderblock_<i>_spec' maps to the
          # shared encoder's 'encoderblock_<i>'; verify for other checkpoints.
          params['vit_encoder'][tm_key[:-5]] = tm_params
        elif tm_key in params['vit_encoder']:
          params['vit_encoder'][tm_key] = tm_params
    elif m_key == 'cls':
      params['tokenizer'][tokenizer_to_init]['cls'] = m_params
    elif m_key == 'embedding_spec':
      if tokenizer_to_init in ['tokenizer2d', 'tokenizer_spec']:
        params['tokenizer'][tokenizer_to_init]['embedding'] = m_params
      else:
        # Central-frame initialization: put the restored kernel in the middle
        # slice of the leading axis and leave all other slices zero. The shape
        # comes from the target kernel rather than a hard-coded constant.
        target_embedding = params['tokenizer'][tokenizer_to_init]['embedding']
        kernel_shape = target_embedding['kernel'].shape
        kernel3d = np.zeros(kernel_shape)
        kernel3d[kernel_shape[0] // 2] = m_params['kernel']
        target_embedding['kernel'] = kernel3d
        target_embedding['bias'] = m_params['bias']
    elif m_key == 'output_projection' and initialize_head:
      # Take the first collection that is not a tokenizer/encoder (presumably
      # the single classification head).
      head_name = [
          x for x in params.keys() if x not in ['tokenizer', 'vit_encoder']
      ][0]
      params[head_name]['output_projection'] = m_params

  return train_state.replace(params=flax.core.freeze(params))
|
|
|
|
def initialize_from_vivit_train_state(
    train_state: Any,
    restored_train_state: Any,
    tokenizer_to_init: str = 'tokenizer3d',
    resolution_to_init: Optional[Any] = None,
    initialize_head: bool = False) -> Any:
  """Initializes PolyViT with ViViT body.

  Restored 'Transformer' entries that the target's 'vit_encoder' also has are
  copied directly. The ViViT positional embeddings, 'cls' and 'embedding'
  initialize `tokenizer_to_init`. When that tokenizer is not 'tokenizer3d',
  they are first collapsed from the 3D layout.

  Args:
    train_state: TrainState of the target PolyViT model.
    restored_train_state: TrainState of the ViViT model to restore from.
    tokenizer_to_init: Name of the target tokenizer to initialize.
    resolution_to_init: Optional 2-element target input size. The averaged
      14x14 grid (hard-coded for a 224x224 input) is zoomed by
      `resolution_to_init[i] / 224`. Only used when
      `tokenizer_to_init != 'tokenizer3d'`.
    initialize_head: If True, copy the restored 'output_projection' into the
      first top-level target collection that is neither 'tokenizer' nor
      'vit_encoder'.

  Returns:
    `train_state` with the updated, frozen parameters.
  """
  params = flax.core.unfreeze(train_state.params)
  restored = flax.core.unfreeze(restored_train_state.params)

  for top_key, top_value in restored.items():
    if top_key == 'Transformer':
      for sub_key, sub_value in top_value.items():
        if sub_key != 'posembed_input':
          if sub_key in params['vit_encoder']:
            params['vit_encoder'][sub_key] = sub_value
          continue

        if tokenizer_to_init == 'tokenizer3d':
          # Identical layout: reuse ViViT's positional embeddings verbatim.
          params['tokenizer']['tokenizer3d']['posembed_input'] = sub_value
          continue

        full_posemb = sub_value['pos_embedding']
        embed_dim = full_posemb.shape[2]
        # Skip position 0 (CLS), split the rest into 8 groups along the token
        # axis and average those groups into one grid.
        spatial = full_posemb[:, 1:].reshape(1, 8, -1, embed_dim)
        spatial = spatial.mean(axis=1)

        if resolution_to_init is not None:
          grid = spatial.reshape(14, 14, -1)
          factors = (resolution_to_init[0] / 224, resolution_to_init[1] / 224,
                     1)
          grid = scipy.ndimage.zoom(grid, factors, order=1)
          rows, cols = grid.shape[0], grid.shape[1]
          spatial = grid.reshape(1, rows * cols, -1)

        target_posembed = params['tokenizer'][tokenizer_to_init][
            'posembed_input']
        target_posembed['pos_embedding'] = np.concatenate(
            [full_posemb[:, :1], spatial], axis=1)
    elif top_key == 'cls':
      params['tokenizer'][tokenizer_to_init]['cls'] = top_value
    elif top_key == 'embedding':
      if tokenizer_to_init == 'tokenizer3d':
        params['tokenizer']['tokenizer3d']['embedding'] = top_value
      else:
        target_embedding = params['tokenizer'][tokenizer_to_init]['embedding']
        target_embedding['bias'] = top_value['bias']
        # Collapse the leading (temporal) axis of the 3D kernel by summation.
        target_embedding['kernel'] = top_value['kernel'].sum(axis=0)
    elif top_key == 'output_projection' and initialize_head:
      candidates = [
          name for name in params if name not in ('tokenizer', 'vit_encoder')
      ]
      params[candidates[0]]['output_projection'] = top_value

  return train_state.replace(params=flax.core.freeze(params))
|
|
|
|
def initialise_from_vit_train_state(
    config,
    train_state: Any,
    restored_train_state: Any,
    restored_model_cfg: ml_collections.ConfigDict,
    log_initialised_param_shapes: bool = True) -> Any:
  """Updates the train_state with data from restored_train_state (ViT model).

  Intended for fine-tuning. Positional embeddings are adapted per tokenizer
  (2D, 3D, spectrogram) so that larger resolutions / longer sequences are
  supported. Encoder blocks go through `init_encoderblock`, the input
  embedding is copied (or inflated for the 3D tokenizer), and other
  collections shared with the target are copied as-is.

  Args:
    config: Configurations for the model being updated.
    train_state: A raw TrainState for the model.
    restored_train_state: A TrainState that is loaded with parameters/state of a
      pretrained model.
    restored_model_cfg: Configuration of the model from which the
      restored_train_state comes. Used to detect its CLS-token usage.
    log_initialised_param_shapes: If true, print a tabular summary of all the
      variables in the model once they have been initialised.

  Returns:
    Updated train_state.
  """
  params = flax.core.unfreeze(train_state.params)
  restored_params = flax.core.unfreeze(restored_train_state.params)
  every_tokenizer = ('tokenizer2d', 'tokenizer3d', 'tokenizer_spec')

  for m_key, m_params in restored_params.items():
    if m_key in ('Transformer', 'SpatialTransformer'):
      for tm_key, tm_params in m_params.items():
        if tm_key == 'posembed_input':
          tokenizers = params['tokenizer']
          if 'tokenizer2d' in tokenizers:
            init_posemb(tokenizers['tokenizer2d'], m_params, config,
                        restored_model_cfg, 'resize')
          if 'tokenizer3d' in tokenizers:
            init_posemb(tokenizers['tokenizer3d'], m_params, config,
                        restored_model_cfg,
                        config.init_from.positional_embed_size_change)
          if 'tokenizer_spec' in tokenizers:
            init_spec_posemb(tokenizers['tokenizer_spec'], m_params, config,
                             restored_model_cfg)
        elif 'encoderblock' in tm_key:
          init_encoderblock(params, m_params, tm_key)
        else:
          params['vit_encoder'][tm_key] = tm_params
    elif m_key == 'cls':
      for name in every_tokenizer:
        if name in params['tokenizer']:
          params['tokenizer'][name]['cls'] = m_params
    elif m_key == 'embedding':
      for name in ('tokenizer2d', 'tokenizer_spec'):
        if name in params['tokenizer']:
          params['tokenizer'][name]['embedding'] = m_params
      if 'tokenizer3d' in params['tokenizer']:
        init_embedding(params['tokenizer']['tokenizer3d'], m_params, config)
    elif m_key in train_state.params:
      params[m_key] = m_params
    else:
      logging.info('Skipping %s. In restored model but not in target', m_key)

  if log_initialised_param_shapes:
    logging.info('Parameter summary after initialising from train state')
    debug_utils.log_param_shapes(params)
  return train_state.replace(params=flax.core.freeze(params))
|
|
|
|
def init_posemb(to_params, from_params, config, restored_model_cfg,
                positional_embed_size_change):
  """Initialize the positional embeddings.

  Copies `from_params['posembed_input']['pos_embedding']` into
  `to_params['posembed_input']['pos_embedding']`. When the two shapes differ,
  the restored embedding is adapted:

  * The leading CLS position is dropped from, or supplied for, the restored
    embedding, depending on whether the target / restored configs use a CLS
    token (per `get_cls_token_and_video_frames`). If the target needs a CLS
    position that the restored model lacks, the target's own current value is
    kept.
  * If the integer square roots of the token counts differ, the grid is
    changed with `positional_embed_size_change`: 'resize', 'tile', or
    'resize_tile' (resize to one frame's tokens, then tile).

  Args:
    to_params: Target tokenizer parameters; updated in place.
    from_params: Restored parameters containing 'posembed_input'.
    config: Config of the target model.
    restored_model_cfg: Config of the restored model.
    positional_embed_size_change: 'resize', 'tile' or 'resize_tile'.

  Raises:
    AssertionError: If the grid sizes differ and `positional_embed_size_change`
      is not one of the supported methods.
  """
  with_cls_token, num_video_frames = get_cls_token_and_video_frames(config)
  restored_with_cls_token, _ = get_cls_token_and_video_frames(
      restored_model_cfg)
  if config.init_from.restore_positional_embedding:
    posemb = to_params['posembed_input']['pos_embedding']
    restored_posemb = from_params['posembed_input']['pos_embedding']
    if restored_posemb.shape != posemb.shape:
      # Indexing below treats axis 1 as the token axis and uses only index 0
      # of axis 0 (presumably a batch dim of size 1).
      logging.info('Adapting positional embeddings from %s to %s',
                   restored_posemb.shape, posemb.shape)
      ntok = posemb.shape[1]
      if restored_with_cls_token:
        # The first restored position belongs to the CLS token.
        cls_tok = restored_posemb[:, :1]
        restored_posemb_grid = restored_posemb[0, 1:]
      else:
        cls_tok = restored_posemb[:, :0]
        restored_posemb_grid = restored_posemb[0]
      if with_cls_token:
        ntok -= 1
      restored_gs = int(np.sqrt(len(restored_posemb_grid)))
      gs = int(np.sqrt(ntok))
      if with_cls_token != restored_with_cls_token:
        logging.warning('Only one of target and restored model uses '
                        'classification token')
        if with_cls_token:
          # There is no restored CLS position to carry over. Keep the target's
          # own one; otherwise the result is one token short of `posemb`.
          cls_tok = posemb[:, :1]
        if restored_gs == gs:
          # Only the CLS position differs; the grid itself is reused as-is.
          restored_posemb = restored_posemb_grid[None, ...]
          if with_cls_token:
            restored_posemb = jnp.array(
                np.concatenate([cls_tok, restored_posemb], axis=1))

      if restored_gs != gs:
        if positional_embed_size_change == 'resize':
          restored_posemb_grid = vivit_utils.interpolate_positional_embeddings(
              restored_posemb_grid, ntok)
        elif positional_embed_size_change == 'tile':
          restored_posemb_grid = vivit_utils.tile_positional_embeddings(
              restored_posemb_grid, ntok)
        elif positional_embed_size_change == 'resize_tile':
          n_frames = (
              num_video_frames // config.model.modalities.video.patches.size[2])
          tokens_per_frame = ntok // n_frames
          restored_posemb_grid = vivit_utils.interpolate_positional_embeddings(
              restored_posemb_grid, tokens_per_frame)
          restored_posemb_grid = restored_posemb_grid[0]
          restored_posemb_grid = vivit_utils.tile_positional_embeddings(
              restored_posemb_grid, ntok)
        else:
          raise AssertionError(
              'Unknown positional embedding size changing method')
        # Re-attach the CLS position when the target model uses one.
        if with_cls_token:
          restored_posemb = jnp.array(
              np.concatenate([cls_tok, restored_posemb_grid], axis=1))
        else:
          restored_posemb = restored_posemb_grid

    to_params['posembed_input']['pos_embedding'] = restored_posemb
  else:
    logging.info('Not restoring positional encodings from pretrained model')
|
|
|
|
def init_spec_posemb(to_params, from_params, config, restored_model_cfg):
  """Initialize the spectrogram positional embeddings.

  Unconditionally re-grids the restored positional embeddings for the
  spectrogram tokenizer. The restored grid (minus any CLS position) is
  interpolated to a (gh, gw) token grid derived from
  `config.model.modalities.audio`, then tiled to the target token count. A CLS
  position is re-attached when the target model uses one. If the restored model
  has none, the target's own current CLS position is kept.

  Args:
    to_params: Target tokenizer parameters; updated in place.
    from_params: Restored parameters containing 'posembed_input'.
    config: Config of the target model.
    restored_model_cfg: Config of the restored model.
  """
  with_cls_token, _ = get_cls_token_and_video_frames(config)
  restored_with_cls_token, _ = get_cls_token_and_video_frames(
      restored_model_cfg)
  if config.init_from.restore_positional_embedding:
    posemb = to_params['posembed_input']['pos_embedding']
    restored_posemb = from_params['posembed_input']['pos_embedding']
    logging.info('Adapting spectrogram positional embeddings from %s to %s',
                 restored_posemb.shape, posemb.shape)
    ntok = posemb.shape[1]
    if restored_with_cls_token:
      # The first restored position belongs to the CLS token.
      cls_tok = restored_posemb[:, :1]
      restored_posemb_grid = restored_posemb[0, 1:]
    else:
      cls_tok = restored_posemb[:, :0]
      restored_posemb_grid = restored_posemb[0]
    if with_cls_token:
      ntok -= 1
    if with_cls_token != restored_with_cls_token:
      logging.warning('Only one of target and restored model uses '
                      'classification token')
      if with_cls_token:
        # There is no restored CLS position to carry over. Keep the target's
        # own one; otherwise the result is one token short of `posemb`.
        cls_tok = posemb[:, :1]

    # Target token grid implied by the spectrogram shape, the number of stacked
    # spectrogram frames and the patch size.
    audio_cfg = config.model.modalities.audio
    gh = ((audio_cfg.spec_shape[0] * audio_cfg.num_spec_frames) //
          audio_cfg.patches.size[0])
    gw = audio_cfg.spec_shape[1] // audio_cfg.patches.size[1]
    tokens_per_frame = (gh, gw)

    restored_posemb_grid = mbt_utils.interpolate_positional_embeddings(
        restored_posemb_grid, tokens_per_frame)
    restored_posemb_grid = restored_posemb_grid[0]
    restored_posemb_grid = mbt_utils.tile_positional_embeddings(
        restored_posemb_grid, ntok)

    if with_cls_token:
      restored_posemb = jnp.array(
          np.concatenate([cls_tok, restored_posemb_grid], axis=1))
    else:
      restored_posemb = restored_posemb_grid

    to_params['posembed_input']['pos_embedding'] = restored_posemb
  else:
    logging.info('Not restoring positional encodings from pretrained model')
|
|
|
|
def init_encoderblock(to_params, from_params, tm_key):
  """Initialize encoder_block_parameters.

  Each entry of `from_params[tm_key]` is copied one key at a time, so target
  entries absent from the restored block keep their values. The destination is
  `to_params['vit_encoder'][tm_key]` when the shared encoder has that block.
  Otherwise it is the `tm_key` block of every tokenizer ('tokenizer2d',
  'tokenizer3d', 'tokenizer_spec') present in `to_params['tokenizer']`.

  Args:
    to_params: Target PolyViT parameters; updated in place.
    from_params: Restored transformer parameters containing `tm_key`.
    tm_key: Name of the encoder block to copy.
  """
  restored_block = from_params[tm_key]
  for enc_key, enc_value in restored_block.items():
    if tm_key in to_params['vit_encoder']:
      to_params['vit_encoder'][tm_key][enc_key] = enc_value
      continue
    for tokenizer_name in ('tokenizer2d', 'tokenizer3d', 'tokenizer_spec'):
      if tokenizer_name in to_params['tokenizer']:
        to_params['tokenizer'][tokenizer_name][tm_key][enc_key] = enc_value
|
|
|
|
def init_embedding(to_params, from_params, config):
  """Initialize input embedding.

  Copies `from_params['kernel']` / `['bias']` into `to_params['embedding']`.
  The copy is skipped when `config.init_from.restore_input_embedding` is set
  and false (default: restore). If the kernel shapes differ, the restored
  kernel is inflated along a new leading axis of length t =
  `to_params['embedding']['kernel'].shape[0]`, according to
  `config.model.modalities.video.kernel_init_method`:

  * 'average_frame_initializer': tile t times and divide by t.
  * 'average_arp_frame_initializer': tile t times and scale slice i by
    arp_i / t, with arp_i = 2(t - i + 1) - (t + 1)(H_t - H_{i-1}) and H_k the
    k-th harmonic number.
  * 'central_frame_initializer': place the kernel in slice t // 2; all other
    slices are zero.

  Args:
    to_params: Target tokenizer parameters containing 'embedding'; updated in
      place.
    from_params: Restored embedding parameters with 'kernel' and 'bias'.
    config: Config of the target model.

  Raises:
    AssertionError: If the shapes differ and the init method is unknown.
  """
  if not config.init_from.get('restore_input_embedding', True):
    logging.info('Not restoring input embedding parameters')
    return

  input_kernel = to_params['embedding']['kernel']
  restored_kernel = from_params['kernel']
  restored_bias = from_params['bias']

  if input_kernel.shape != restored_kernel.shape:
    method = config.model.modalities.video.kernel_init_method
    if method == 'average_frame_initializer':
      logging.info('Initializing input kernel with filter inflation.')
      num_frames = input_kernel.shape[0]
      stacked = np.tile(np.expand_dims(restored_kernel, axis=0),
                        [num_frames, 1, 1, 1, 1])
      restored_kernel = stacked / num_frames
    elif method == 'average_arp_frame_initializer':
      logging.info('Initialzing input kernel with ARP inflation')
      num_frames = input_kernel.shape[0]
      stacked = np.tile(np.expand_dims(restored_kernel, axis=0),
                        [num_frames, 1, 1, 1, 1])
      # Approximate-rank-pooling weights, built from harmonic numbers.
      positions = np.arange(1, num_frames + 1)
      harmonic = np.zeros((num_frames + 1))
      harmonic[1:] = np.cumsum(1.0 / positions)
      arp = 2 * (num_frames - positions + 1) - (num_frames + 1) * (
          harmonic[-1] - harmonic[:-1])
      weights = np.reshape(arp / num_frames, [num_frames, 1, 1, 1, 1])
      restored_kernel = stacked * weights
    elif method == 'central_frame_initializer':
      logging.info('Initializing input kernel to select centre frame.')
      centre = input_kernel.shape[0] // 2
      inflated = np.zeros(input_kernel.shape)
      inflated[centre] = restored_kernel.copy()
      restored_kernel = inflated
    else:
      raise AssertionError(
          'Unknown input kernel initialization {}'.format(method))

  to_params['embedding']['kernel'] = restored_kernel
  to_params['embedding']['bias'] = restored_bias
|
|
|
|
def get_cls_token_and_video_frames(config):
  """Returns whether there is CLS token and the number of video frames.

  Args:
    config: Model config with `datasets` (name -> dataset config) and
      `model.heads` (head type -> {head name -> head config}).

  Returns:
    A tuple `(has_cls_token, num_video_frames)`:
    * `has_cls_token` is True iff some head of type 'label', 'multilabel' or
      'bow' has `classifier` equal to 'token' or '0'.
    * `num_video_frames` is `num_frames` of the last dataset named
      'kinetics400', 'moments_in_time' or 'epic_kitchens', or None if there is
      no such dataset.
  """
  video_datasets = ('kinetics400', 'moments_in_time', 'epic_kitchens')
  frame_counts = [
      ds_cfg.num_frames
      for ds_name, ds_cfg in config.datasets.items()
      if ds_name in video_datasets
  ]
  num_video_frames = frame_counts[-1] if frame_counts else None

  cls_head_types = ('label', 'multilabel', 'bow')
  cls_classifiers = ('token', '0')
  # A list (not a generator) so that every head is inspected, as before.
  has_cls_token = any([
      head_type in cls_head_types and head_cfg.classifier in cls_classifiers
      for head_type, heads in config.model.heads.items()
      for head_cfg in heads.values()
  ])

  return has_cls_token, num_video_frames
|
|