"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import logging
import os

import numpy as np
import torch
import torch.nn as nn
from lavis.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
from lavis.common.utils import get_abs_path, is_url
from omegaconf import OmegaConf
from transformers import PreTrainedModel, PretrainedConfig

class FAPMConfig(PretrainedConfig):
    model_type = 'FAPM'

    def __init__(self, important_param=42, **kwargs):
        super().__init__(**kwargs)

class BaseModel(PreTrainedModel):
    """Base class for models."""

    config_class = FAPMConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

    @property
    def device(self):
        return list(self.parameters())[0].device

    def load_checkpoint(self, url_or_filename):
        """
        Load from a finetuned checkpoint.

        This should expect no mismatch between the model keys and the checkpoint keys.
        """

        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        if "model" in checkpoint.keys():
            state_dict = checkpoint["model"]
        else:
            state_dict = checkpoint

        msg = self.load_state_dict(state_dict, strict=False)

        logging.info("Missing keys {}".format(msg.missing_keys))
        logging.info("load checkpoint from %s" % url_or_filename)

        return msg

    @classmethod
    def from_pretrained(cls, model_type):
        """
        Build a pretrained model from the default configuration file, specified by model_type.

        Args:
            model_type (str): model type, specifying architecture and checkpoints.

        Returns:
            model (nn.Module): pretrained or finetuned model, depending on the configuration.
        """
        model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
        model = cls.from_config(model_cfg)

        return model

    @classmethod
    def default_config_path(cls, model_type):
        assert (
            model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
        ), "Unknown model type {}".format(model_type)
        return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])

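    # Illustrative subclass sketch (hypothetical names, not defined in this repo):
    # `from_pretrained` relies on the concrete subclass providing a
    # PRETRAINED_MODEL_CONFIG_DICT mapping model types to config paths, plus a
    # `from_config` constructor, roughly:
    #
    #     class MyModel(BaseModel):
    #         PRETRAINED_MODEL_CONFIG_DICT = {"base": "configs/models/my_model_base.yaml"}
    #
    #         @classmethod
    #         def from_config(cls, cfg):
    #             ...
    #
    #     model = MyModel.from_pretrained("base")
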
    def load_checkpoint_from_config(self, cfg, **kwargs):
        """
        Load checkpoint as specified in the config file.

        If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
        When loading the pretrained model, each task-specific architecture may define their
        own load_from_pretrained() method.
        """
        load_finetuned = cfg.get("load_finetuned", True)
        if load_finetuned:
            finetune_path = cfg.get("finetuned", None)
            assert (
                finetune_path is not None
            ), "Found load_finetuned is True, but finetune_path is None."
            self.load_checkpoint(url_or_filename=finetune_path)
        else:
            load_pretrained = cfg.get("load_pretrained", True)
            if load_pretrained:
                pretrain_path = cfg.get("pretrained", None)
                assert (
                    pretrain_path is not None
                ), "Found load_finetuned is False, but pretrain_path is None."
                self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)

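    # Keys read above, shown as an illustrative YAML sketch (paths are placeholders;
    # the real values live in the repo's model config files):
    #
    #     model:
    #       load_finetuned: True
    #       finetuned: "/path/or/url/to/finetuned.pth"
    #       load_pretrained: True
    #       pretrained: "/path/or/url/to/pretrained.pth"
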
    def before_training(self, **kwargs):
        pass

    def get_optimizer_params(self, weight_decay, lr_scale=1):
        p_wd, p_non_wd = [], []
        for n, p in self.named_parameters():
            if not p.requires_grad:
                continue
            if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n:
                p_non_wd.append(p)
            else:
                p_wd.append(p)
        optim_params = [
            {"params": p_wd, "weight_decay": weight_decay, "lr_scale": lr_scale},
            {"params": p_non_wd, "weight_decay": 0, "lr_scale": lr_scale},
        ]
        return optim_params

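    # Usage sketch (assumed, not prescribed by this file): the returned groups can be
    # passed straight to a torch optimizer; biases and 1-D/norm parameters get zero
    # weight decay. Extra keys such as "lr_scale" are kept in the param groups for
    # schedulers that know how to use them.
    #
    #     optimizer = torch.optim.AdamW(
    #         model.get_optimizer_params(weight_decay=0.05), lr=1e-4
    #     )
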
    def before_evaluation(self, **kwargs):
        pass

    def show_n_params(self, return_str=True):
        tot = 0
        for p in self.parameters():
            w = 1
            for x in p.shape:
                w *= x
            tot += w
        if return_str:
            if tot >= 1e6:
                return "{:.1f}M".format(tot / 1e6)
            else:
                return "{:.1f}K".format(tot / 1e3)
        else:
            return tot


class BaseEncoder(nn.Module):
    """
    Base class for primitive encoders, such as ViT, TimeSformer, etc.
    """

    def __init__(self):
        super().__init__()

    def forward_features(self, samples, **kwargs):
        raise NotImplementedError

    @property
    def device(self):
        return list(self.parameters())[0].device

class SharedQueueMixin:
    @torch.no_grad()
    def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
        # gather features from all workers before updating the queue
        image_feats = concat_all_gather(image_feat)
        text_feats = concat_all_gather(text_feat)

        batch_size = image_feats.shape[0]

        ptr = int(self.queue_ptr)
        assert self.queue_size % batch_size == 0  # for simplicity

        # replace the entries at ptr (dequeue and enqueue)
        self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
        self.text_queue[:, ptr : ptr + batch_size] = text_feats.T

        if idxs is not None:
            idxs = concat_all_gather(idxs)
            self.idx_queue[:, ptr : ptr + batch_size] = idxs.T

        ptr = (ptr + batch_size) % self.queue_size  # move pointer
        self.queue_ptr[0] = ptr

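# Note: SharedQueueMixin assumes the host nn.Module defines `queue_size` and registers
# `queue_ptr`, `image_queue`, `text_queue` (and `idx_queue` when `idxs` is passed) as
# buffers, e.g. (hypothetical shapes, feature dimension first):
#
#     self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
#     self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
#     self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
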
class MomentumDistilationMixin:
    @torch.no_grad()
    def copy_params(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(
                model_pair[0].parameters(), model_pair[1].parameters()
            ):
                param_m.data.copy_(param.data)
                param_m.requires_grad = False

    @torch.no_grad()
    def _momentum_update(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(
                model_pair[0].parameters(), model_pair[1].parameters()
            ):
                param_m.data = param_m.data * self.momentum + param.data * (
                    1.0 - self.momentum
                )

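# Note: MomentumDistilationMixin assumes the host module defines `self.model_pairs`
# (a list of [online_module, momentum_module] pairs) and `self.momentum` (the EMA
# coefficient). `copy_params` initializes the momentum copies and freezes them;
# `_momentum_update` applies param_m = momentum * param_m + (1 - momentum) * param.
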
class GatherLayer(torch.autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    This implementation does not cut the gradients as torch.distributed.all_gather does.
    """

    @staticmethod
    def forward(ctx, x):
        output = [
            torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
        ]
        torch.distributed.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        all_gradients = torch.stack(grads)
        torch.distributed.all_reduce(all_gradients)
        return all_gradients[torch.distributed.get_rank()]

def all_gather_with_grad(tensors):
    """
    Performs all_gather operation on the provided tensors.
    Graph remains connected for backward grad computation.
    """
    world_size = torch.distributed.get_world_size()

    # no gathering needed in the single-process case
    if world_size == 1:
        return tensors

    tensor_all = GatherLayer.apply(tensors)

    return torch.cat(tensor_all, dim=0)

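# Choosing between the two gather helpers: `all_gather_with_grad` keeps the autograd
# graph connected via GatherLayer, so it suits gathered tensors that feed a loss;
# `concat_all_gather` below runs under no_grad and is intended for bookkeeping such
# as queue updates, where gradients are not needed.
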
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    if not is_dist_avail_and_initialized():
        return tensor

    tensors_gather = [
        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output

def tile(x, dim, n_tile):
    init_dim = x.size(dim)
    repeat_idx = [1] * x.dim()
    repeat_idx[dim] = n_tile
    x = x.repeat(*repeat_idx)
    order_index = torch.LongTensor(
        np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
    )
    return torch.index_select(x, dim, order_index.to(x.device))
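# Behavior sketch for `tile` (equivalence inferred from the implementation above, not
# asserted anywhere in the repo): it repeats each slice along `dim` n_tile times
# consecutively, matching torch.repeat_interleave(x, n_tile, dim=dim). For example:
#
#     tile(torch.tensor([1, 2]), 0, 3)  # -> tensor([1, 1, 1, 2, 2, 2])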