# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
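
# Multi-modal data2vec model: modality-specific encoders (audio, image, text) feed a
# shared transformer. During pre-training, a student network predicts, for masked
# time-steps, targets built from the averaged top-K layer outputs of an EMA-updated
# teacher copy of the same network (self-distillation).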

import logging
import math
from dataclasses import dataclass, field
from typing import Optional, Callable
from functools import partial

import numpy as np
from omegaconf import II

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from fairseq.modules import EMAModule, EMAModuleConfig
from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model

from examples.data2vec.data.modality import Modality
from examples.data2vec.models.modalities.base import (
    MaskSeed,
    D2vModalityConfig,
    ModalitySpecificEncoder,
    get_annealed_rate,
)
from examples.data2vec.models.modalities.modules import (
    D2vDecoderConfig,
    AltBlock,
    Decoder1d,
)
from examples.data2vec.models.modalities.audio import (
    D2vAudioConfig,
    AudioEncoder,
)
from examples.data2vec.models.modalities.images import (
    D2vImageConfig,
    ImageEncoder,
)
from examples.data2vec.models.modalities.text import (
    D2vTextConfig,
    TextEncoder,
)


logger = logging.getLogger(__name__)


@dataclass
class D2vModalitiesConfig(FairseqDataclass):
    audio: D2vAudioConfig = D2vAudioConfig()
    image: D2vImageConfig = D2vImageConfig()
    text: D2vTextConfig = D2vTextConfig()


@dataclass
class Data2VecMultiConfig(FairseqDataclass):
    loss_beta: float = field(
        default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
    )
    loss_scale: Optional[float] = field(
        default=None,
        metadata={
            "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
        },
    )

    input_feature_ndim: int = 40

    depth: int = 8
    start_drop_path_rate: float = 0
    end_drop_path_rate: float = 0
    num_heads: int = 12
    norm_eps: float = 1e-6
    norm_affine: bool = True
    encoder_dropout: float = 0.1
    post_mlp_drop: float = 0.1
    attention_dropout: float = 0.1
    activation_dropout: float = 0.0
    dropout_input: float = 0.0
    layerdrop: float = 0.0
    embed_dim: int = 768
    mlp_ratio: float = 4
    layer_norm_first: bool = False

    average_top_k_layers: int = field(
        default=8, metadata={"help": "how many layers to average"}
    )

    end_of_block_targets: bool = False

    clone_batch: int = 1

    layer_norm_target_layer: bool = False
    batch_norm_target_layer: bool = False
    instance_norm_target_layer: bool = False
    instance_norm_targets: bool = False
    layer_norm_targets: bool = False

    ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
    ema_same_dtype: bool = True
    log_norms: bool = True
    ema_end_decay: float = field(
        default=0.9999, metadata={"help": "final ema decay rate"}
    )

    # when to finish annealing ema decay rate
    ema_anneal_end_step: int = II("optimization.max_update")

    ema_encoder_only: bool = field(
        default=True,
        metadata={
            "help": "whether to momentum update only the shared transformer encoder"
        },
    )

    max_update: int = II("optimization.max_update")

    modalities: D2vModalitiesConfig = D2vModalitiesConfig()

    shared_decoder: Optional[D2vDecoderConfig] = None

    min_target_var: float = field(
        default=0.1, metadata={"help": "stop training if target var falls below this"}
    )
    min_pred_var: float = field(
        default=0.01,
        metadata={"help": "stop training if prediction var falls below this"},
    )

    supported_modality: Optional[Modality] = None
    mae_init: bool = False

    seed: int = II("common.seed")

    skip_ema: bool = False

    cls_loss: float = 0
    recon_loss: float = 0
    d2v_loss: float = 1

    decoder_group: bool = False
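
# The main pre-training knobs, as used in Data2VecMultiModel.forward() below:
#   * average_top_k_layers -- how many of the teacher's final transformer layers are
#     averaged to build the regression target.
#   * clone_batch -- each sample is repeated this many times with different masks, so
#     one teacher forward pass serves several student views.
#   * ema_decay / ema_end_decay / ema_anneal_end_step -- schedule for the teacher's
#     momentum update.
#   * d2v_loss / recon_loss / cls_loss -- weights of the masked-prediction loss, the
#     optional patch-reconstruction loss, and the optional CLS-token loss.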


@register_model("data2vec_multi", dataclass=Data2VecMultiConfig)
class Data2VecMultiModel(BaseFairseqModel):
    def make_modality_encoder(
        self,
        cfg: D2vModalityConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases,
        task,
    ) -> ModalitySpecificEncoder:
        if cfg.type == Modality.AUDIO:
            enc_cls = AudioEncoder
        elif cfg.type == Modality.IMAGE:
            enc_cls = ImageEncoder
        elif cfg.type == Modality.TEXT:
            enc_cls = TextEncoder
            if hasattr(task, "text_task"):
                task = task.text_task
        else:
            raise Exception(f"unsupported modality {cfg.type}")

        return enc_cls(
            cfg,
            embed_dim,
            make_block,
            norm_layer,
            layer_norm_first,
            alibi_biases,
            task,
        )

    def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None):
        super().__init__()
        self.cfg = cfg
        self.modalities = modalities
        self.task = task

        make_layer_norm = partial(
            nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
        )

        def make_block(drop_path, dim=None, heads=None):
            return AltBlock(
                cfg.embed_dim if dim is None else dim,
                cfg.num_heads if heads is None else heads,
                cfg.mlp_ratio,
                qkv_bias=True,
                drop=cfg.encoder_dropout,
                attn_drop=cfg.attention_dropout,
                mlp_drop=cfg.activation_dropout,
                post_mlp_drop=cfg.post_mlp_drop,
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=cfg.layer_norm_first,
                ffn_targets=not cfg.end_of_block_targets,
            )

        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()
        for mod in self.modalities:
            mod_cfg = getattr(cfg.modalities, mod.name.lower())
            enc = self.make_modality_encoder(
                mod_cfg,
                cfg.embed_dim,
                make_block,
                make_layer_norm,
                cfg.layer_norm_first,
                self.alibi_biases,
                task,
            )
            self.modality_encoders[mod.name] = enc

        self.ema = None

        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_beta = cfg.loss_beta
        self.loss_scale = cfg.loss_scale

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)

        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])

        self.norm = None
        if cfg.layer_norm_first:
            self.norm = make_layer_norm(cfg.embed_dim)

        if self.cfg.mae_init:
            self.apply(self._init_weights)
        else:
            from fairseq.modules.transformer_sentence_encoder import init_bert_params

            self.apply(init_bert_params)

        for mod_enc in self.modality_encoders.values():
            mod_enc.reset_parameters()

        if not skip_ema:
            self.ema = self.make_ema_teacher(cfg.ema_decay)
            self.shared_decoder = (
                Decoder1d(cfg.shared_decoder, cfg.embed_dim)
                if self.cfg.shared_decoder is not None
                else None
            )
            if self.shared_decoder is not None:
                self.shared_decoder.apply(self._init_weights)

            self.recon_proj = None
            if cfg.recon_loss > 0:
                self.recon_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim)

        for pn, p in self.named_parameters():
            if len(p.shape) == 1 or pn.endswith(".bias") or "alibi_scale" in pn:
                p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
            if cfg.decoder_group and "decoder" in pn:
                p.param_group = "decoder"

        self.num_updates = 0

    def _init_weights(self, m):
        try:
            from apex.normalization import FusedLayerNorm

            fn = FusedLayerNorm
        except:
            fn = nn.LayerNorm

        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm) or isinstance(m, fn):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def make_ema_teacher(self, ema_decay):
        ema_config = EMAModuleConfig(
            ema_decay=ema_decay,
            ema_fp32=True,
            log_norms=self.cfg.log_norms,
            add_missing_params=False,
        )

        model_copy = self.make_target_model()

        return EMAModule(
            model_copy,
            ema_config,
            copy_model=False,
        )

    def make_target_model(self):
        logger.info("making target model")

        model_copy = Data2VecMultiModel(
            self.cfg, self.modalities, skip_ema=True, task=self.task
        )

        if self.cfg.ema_encoder_only:
            model_copy = model_copy.blocks
            for p_s, p_t in zip(self.blocks.parameters(), model_copy.parameters()):
                p_t.data.copy_(p_s.data)
        else:
            for p_s, p_t in zip(self.parameters(), model_copy.parameters()):
                p_t.data.copy_(p_s.data)

            for mod_enc in model_copy.modality_encoders.values():
                mod_enc.decoder = None
                if not mod_enc.modality_cfg.ema_local_encoder:
                    mod_enc.local_encoder = None
                    mod_enc.project_features = None

        model_copy.requires_grad_(False)
        return model_copy
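
    # The teacher is a frozen copy of the student, wrapped in an EMAModule that keeps
    # fp32 shadow parameters. With ema_encoder_only=True only the shared transformer
    # blocks are tracked; otherwise the whole model is copied, minus the decoders and
    # (optionally) the local feature encoders, which the teacher never needs.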

    def set_num_updates(self, num_updates):
        super().set_num_updates(num_updates)

        if self.ema is not None and (
            (self.num_updates == 0 and num_updates > 1)
            or self.num_updates >= num_updates
        ):
            pass
        elif self.training and self.ema is not None:
            ema_weight_decay = None
            if self.cfg.ema_decay != self.cfg.ema_end_decay:
                if num_updates >= self.cfg.ema_anneal_end_step:
                    decay = self.cfg.ema_end_decay
                else:
                    decay = get_annealed_rate(
                        self.cfg.ema_decay,
                        self.cfg.ema_end_decay,
                        num_updates,
                        self.cfg.ema_anneal_end_step,
                    )
                self.ema.set_decay(decay, weight_decay=ema_weight_decay)
            if self.ema.get_decay() < 1:
                self.ema.step(self.blocks if self.cfg.ema_encoder_only else self)

        self.num_updates = num_updates
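
    # Decay schedule: get_annealed_rate (imported from modalities.base) anneals the
    # decay from ema_decay to ema_end_decay over ema_anneal_end_step updates
    # (effectively a linear ramp), after which the final value is held. The teacher is
    # only stepped while its decay is below 1.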

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state = super().state_dict(destination, prefix, keep_vars)
        if self.ema is not None:
            state[prefix + "_ema"] = self.ema.fp32_params
        return state

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        k = prefix + "_ema"
        if self.ema is not None:
            assert k in state_dict
            self.ema.restore(state_dict[k], True)
            del state_dict[k]
        elif k in state_dict:
            del state_dict[k]
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    @classmethod
    def build_model(cls, cfg: Data2VecMultiConfig, task=None):
        """Build a new model instance."""
        if task is None or not hasattr(task, "supported_modalities"):
            modalities = (
                [cfg.supported_modality]
                if cfg.supported_modality is not None
                else [
                    Modality.AUDIO,
                    Modality.IMAGE,
                    Modality.TEXT,
                ]
            )
        else:
            modalities = task.supported_modalities
        return cls(cfg, modalities, task=task, skip_ema=cfg.skip_ema)

    def forward(
        self,
        source,
        target=None,
        id=None,
        mode=None,
        padding_mask=None,
        mask=True,
        features_only=False,
        force_remove_masked=False,
        remove_extra_tokens=True,
        precomputed_mask=None,
        corpus_key=None,  # for config compatibility
    ):
        if mode is None:
            assert self.cfg.supported_modality is not None
            mode = self.cfg.supported_modality

        if isinstance(mode, Modality):
            mode = mode.name

        feature_extractor = self.modality_encoders[mode]

        mask_seeds = None
        if id is not None:
            mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id)

        extractor_out = feature_extractor(
            source,
            padding_mask,
            mask,
            remove_masked=not features_only or force_remove_masked,
            clone_batch=self.cfg.clone_batch if not features_only else 1,
            mask_seeds=mask_seeds,
            precomputed_mask=precomputed_mask,
        )

        x = extractor_out["x"]
        encoder_mask = extractor_out["encoder_mask"]
        masked_padding_mask = extractor_out["padding_mask"]
        masked_alibi_bias = extractor_out.get("alibi_bias", None)
        alibi_scale = extractor_out.get("alibi_scale", None)

        if self.dropout_input is not None:
            x = self.dropout_input(x)

        layer_results = []
        for i, blk in enumerate(self.blocks):
            if (
                not self.training
                or self.cfg.layerdrop == 0
                or (np.random.random() > self.cfg.layerdrop)
            ):
                ab = masked_alibi_bias
                if ab is not None and alibi_scale is not None:
                    scale = (
                        alibi_scale[i]
                        if alibi_scale.size(0) > 1
                        else alibi_scale.squeeze(0)
                    )
                    ab = ab * scale.type_as(ab)

                x, lr = blk(
                    x,
                    padding_mask=masked_padding_mask,
                    alibi_bias=ab,
                )
                if features_only:
                    layer_results.append((x, lr))

        if self.norm is not None:
            x = self.norm(x)

        if features_only:
            if remove_extra_tokens:
                x = x[:, feature_extractor.modality_cfg.num_extra_tokens :]
                if masked_padding_mask is not None:
                    masked_padding_mask = masked_padding_mask[
                        :, feature_extractor.modality_cfg.num_extra_tokens :
                    ]

            return {
                "x": x,
                "padding_mask": masked_padding_mask,
                "layer_results": layer_results,
                "mask": encoder_mask,
            }
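
        # Past this point we are pre-training: run the (optional) shared decoder and
        # the modality-specific decoder over the student output, then run the frozen
        # EMA teacher on the unmasked input to build the regression targets.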

        xs = []

        if self.shared_decoder is not None:
            dx = self.forward_decoder(
                x,
                feature_extractor,
                self.shared_decoder,
                encoder_mask,
            )
            xs.append(dx)
        if feature_extractor.decoder is not None:
            dx = self.forward_decoder(
                x,
                feature_extractor,
                feature_extractor.decoder,
                encoder_mask,
            )
            xs.append(dx)
            orig_x = x

        assert len(xs) > 0

        p = next(self.ema.model.parameters())
        device = x.device
        dtype = x.dtype
        ema_device = p.device
        ema_dtype = p.dtype

        if not self.cfg.ema_same_dtype:
            dtype = ema_dtype

        if ema_device != device or ema_dtype != dtype:
            logger.info(f"adjusting ema dtype to {dtype} and device to {device}")
            self.ema.model = self.ema.model.to(dtype=dtype, device=device)
            ema_dtype = dtype

            def to_device(d):
                for k, p in d.items():
                    if isinstance(d[k], dict):
                        to_device(d[k])
                    else:
                        d[k] = p.to(device=device)

            to_device(self.ema.fp32_params)
with torch.no_grad(): | |
tm.eval() | |
if self.cfg.ema_encoder_only: | |
assert target is None | |
ema_input = extractor_out["local_features"] | |
ema_input = feature_extractor.contextualized_features( | |
ema_input.to(dtype=ema_dtype), | |
padding_mask, | |
mask=False, | |
remove_masked=False, | |
) | |
ema_blocks = tm | |
else: | |
ema_blocks = tm.blocks | |
if feature_extractor.modality_cfg.ema_local_encoder: | |
inp = ( | |
target.to(dtype=ema_dtype) | |
if target is not None | |
else source.to(dtype=ema_dtype) | |
) | |
ema_input = tm.modality_encoders[mode]( | |
inp, | |
padding_mask, | |
mask=False, | |
remove_masked=False, | |
) | |
else: | |
assert target is None | |
ema_input = extractor_out["local_features"] | |
ema_feature_enc = tm.modality_encoders[mode] | |
ema_input = ema_feature_enc.contextualized_features( | |
ema_input.to(dtype=ema_dtype), | |
padding_mask, | |
mask=False, | |
remove_masked=False, | |
) | |
ema_padding_mask = ema_input["padding_mask"] | |
ema_alibi_bias = ema_input.get("alibi_bias", None) | |
ema_alibi_scale = ema_input.get("alibi_scale", None) | |
ema_input = ema_input["x"] | |
y = [] | |
ema_x = [] | |
extra_tokens = feature_extractor.modality_cfg.num_extra_tokens | |
for i, blk in enumerate(ema_blocks): | |
ab = ema_alibi_bias | |
if ab is not None and alibi_scale is not None: | |
scale = ( | |
ema_alibi_scale[i] | |
if ema_alibi_scale.size(0) > 1 | |
else ema_alibi_scale.squeeze(0) | |
) | |
ab = ab * scale.type_as(ab) | |
ema_input, lr = blk( | |
ema_input, | |
padding_mask=ema_padding_mask, | |
alibi_bias=ab, | |
) | |
y.append(lr[:, extra_tokens:]) | |
ema_x.append(ema_input[:, extra_tokens:]) | |

        y = self.make_targets(y, self.average_top_k_layers)
        orig_targets = y

        if self.cfg.clone_batch > 1:
            y = y.repeat_interleave(self.cfg.clone_batch, 0)

        masked = encoder_mask.mask.unsqueeze(-1)
        masked_b = encoder_mask.mask.bool()
        y = y[masked_b]

        if xs[0].size(1) == masked_b.size(1):
            xs = [x[masked_b] for x in xs]
        else:
            xs = [x.reshape(-1, x.size(-1)) for x in xs]

        sample_size = masked.sum().long()
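
        # Targets and predictions are gathered at masked positions only; sample_size
        # therefore counts masked time-steps (across all clone_batch copies) and is
        # returned alongside the losses.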

        result = {
            "losses": {},
            "sample_size": sample_size,
        }

        sample_size = result["sample_size"]

        if self.cfg.cls_loss > 0:
            assert extra_tokens > 0
            cls_target = orig_targets.mean(dim=1)
            if self.cfg.clone_batch > 1:
                cls_target = cls_target.repeat_interleave(self.cfg.clone_batch, 0)
            cls_pred = x[:, extra_tokens - 1]
            result["losses"]["cls"] = self.d2v_loss(cls_pred, cls_target) * (
                self.cfg.cls_loss * sample_size
            )

        if self.cfg.recon_loss > 0:

            with torch.no_grad():
                target = feature_extractor.patchify(source)
                mean = target.mean(dim=-1, keepdim=True)
                var = target.var(dim=-1, keepdim=True)
                target = (target - mean) / (var + 1.0e-6) ** 0.5

                if self.cfg.clone_batch > 1:
                    target = target.repeat_interleave(self.cfg.clone_batch, 0)

                if masked_b is not None:
                    target = target[masked_b]

            recon = xs[0]
            if self.recon_proj is not None:
                recon = self.recon_proj(recon)

            result["losses"]["recon"] = (
                self.d2v_loss(recon, target.float()) * self.cfg.recon_loss
            )

        if self.cfg.d2v_loss > 0:
            for i, x in enumerate(xs):
                reg_loss = self.d2v_loss(x, y)
                n = f"{mode}_regression_{i}" if len(xs) > 1 else f"{mode}_regression"
                result["losses"][n] = reg_loss * self.cfg.d2v_loss

        suffix = "" if len(self.modalities) == 1 else f"_{mode}"
        with torch.no_grad():
            if encoder_mask is not None:
                result["masked_pct"] = 1 - (
                    encoder_mask.ids_keep.size(1) / encoder_mask.ids_restore.size(1)
                )
            for i, x in enumerate(xs):
                n = f"pred_var{suffix}_{i}" if len(xs) > 1 else f"pred_var{suffix}"
                result[n] = self.compute_var(x.float())
            if self.ema is not None:
                for k, v in self.ema.logs.items():
                    result[k] = v

            y = y.float()
            result[f"target_var{suffix}"] = self.compute_var(y)

            if self.num_updates > 5000:
                if result[f"target_var{suffix}"] < self.cfg.min_target_var:
                    logger.error(
                        f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})"
                    )
                    raise Exception(
                        f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})"
                    )

                for k in result.keys():
                    if k.startswith("pred_var") and result[k] < self.cfg.min_pred_var:
                        logger.error(
                            f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})"
                        )
                        raise Exception(
                            f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})"
                        )

        result["ema_decay"] = self.ema.get_decay() * 1000

        return result

    def forward_decoder(
        self,
        x,
        feature_extractor,
        decoder,
        mask_info,
    ):
        x = feature_extractor.decoder_input(x, mask_info)
        x = decoder(*x)

        return x

    def d2v_loss(self, x, y):
        x = x.view(-1, x.size(-1)).float()
        y = y.view(-1, x.size(-1))

        if self.loss_beta == 0:
            loss = F.mse_loss(x, y, reduction="none")
        else:
            loss = F.smooth_l1_loss(x, y, reduction="none", beta=self.loss_beta)

        if self.loss_scale is not None:
            scale = self.loss_scale
        else:
            scale = 1 / math.sqrt(x.size(-1))

        reg_loss = loss * scale

        return reg_loss
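
    # Element-wise regression loss between student prediction x and teacher target y:
    #   loss_beta == 0:  scale * (x - y)^2               (MSE)
    #   loss_beta  > 0:  scale * smooth_l1(x - y; beta)
    # with scale = loss_scale if set, otherwise 1 / sqrt(D) for feature dimension D.
    # The per-element losses are returned unreduced.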

    def make_targets(self, y, num_layers):
        with torch.no_grad():
            target_layer_results = y[-num_layers:]

            permuted = False
            if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
                target_layer_results = [
                    tl.transpose(1, 2) for tl in target_layer_results  # BTC -> BCT
                ]
                permuted = True
            if self.cfg.batch_norm_target_layer:
                target_layer_results = [
                    F.batch_norm(
                        tl.float(), running_mean=None, running_var=None, training=True
                    )
                    for tl in target_layer_results
                ]
            if self.cfg.instance_norm_target_layer:
                target_layer_results = [
                    F.instance_norm(tl.float()) for tl in target_layer_results
                ]
            if permuted:
                target_layer_results = [
                    tl.transpose(1, 2) for tl in target_layer_results  # BCT -> BTC
                ]
            if self.cfg.layer_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-1:])
                    for tl in target_layer_results
                ]

        y = target_layer_results[0].float()
        for tl in target_layer_results[1:]:
            y.add_(tl.float())
        y = y.div_(len(target_layer_results))

        if self.cfg.layer_norm_targets:
            y = F.layer_norm(y, y.shape[-1:])

        if self.cfg.instance_norm_targets:
            y = F.instance_norm(y.transpose(1, 2)).transpose(1, 2)

        return y
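
    # Targets are the average of the teacher's last `num_layers` block outputs, with
    # optional per-layer normalization (batch/instance/layer norm) before averaging
    # and optional normalization of the averaged result.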

    @staticmethod
    def compute_var(y):
        y = y.view(-1, y.size(-1))
        if dist.is_initialized():
            zc = torch.tensor(y.size(0)).cuda()
            zs = y.sum(dim=0)
            zss = (y**2).sum(dim=0)

            dist.all_reduce(zc)
            dist.all_reduce(zs)
            dist.all_reduce(zss)

            var = zss / (zc - 1) - (zs**2) / (zc * (zc - 1))
            return torch.sqrt(var + 1e-6).mean()
        else:
            return torch.sqrt(y.var(dim=0) + 1e-6).mean()
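
    # Mean per-dimension standard deviation of the flattened representations, used in
    # forward() as a representation-collapse check. In the distributed case the
    # unbiased variance is computed from all-reduced counts, sums and sums of squares:
    #   var = (sum(y^2) - (sum(y))^2 / n) / (n - 1)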

    def extract_features(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True
    ):
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            remove_extra_tokens=remove_extra_tokens,
        )
        return res

    def remove_pretraining_modules(self, modality=None, keep_decoder=False):
        self.ema = None
        self.cfg.clone_batch = 1
        self.recon_proj = None

        if not keep_decoder:
            self.shared_decoder = None

        modality = modality.lower() if modality is not None else None
        for k in list(self.modality_encoders.keys()):
            if modality is not None and k.lower() != modality:
                del self.modality_encoders[k]
            else:
                self.modality_encoders[k].remove_pretraining_modules(
                    keep_decoder=keep_decoder
                )
                if not keep_decoder:
                    self.modality_encoders[k].decoder = None
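

# Minimal feature-extraction sketch (illustrative only; checkpoint loading and task
# setup are handled elsewhere in fairseq, and the tensor shapes are assumptions):
#
#   cfg = Data2VecMultiConfig(supported_modality=Modality.AUDIO)
#   model = Data2VecMultiModel.build_model(cfg, task=None)
#   model.remove_pretraining_modules(modality="audio")  # drop EMA teacher + decoders
#   out = model.extract_features(waveform, mode="AUDIO", mask=False)
#   features = out["x"]  # assumed (batch, time, embed_dim)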