"""
 Copyright (c) 2023, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import contextlib
import logging
import os
import time
import datetime

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn.functional as F

import lavis.common.dist_utils as dist_utils
from lavis.common.dist_utils import download_cached_file
from lavis.common.utils import is_url
from lavis.common.logger import MetricLogger
from lavis.models.base_model import BaseModel
from lavis.models.blip2_models.Qformer import BertConfig, BertLMHeadModel
from lavis.models.eva_vit import create_eva_vit_g
from lavis.models.clip_vit import create_clip_vit_L
from transformers import BertTokenizer


class Blip2ProteinBase(BaseModel):
    @classmethod
    def init_tokenizer(cls, truncation_side="right"):
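        """Build the BERT tokenizer shared by the BLIP-2 protein models, with
        "[DEC]" registered as the beginning-of-sequence token used for decoding."""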
        tokenizer = BertTokenizer.from_pretrained(
            "bert-base-uncased", truncation_side=truncation_side
        )
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        return tokenizer

    def maybe_autocast(self, dtype=torch.float16):
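        """Return an autocast context on GPU and a null context on CPU, so callers
        can wrap forward passes in ``with self.maybe_autocast():`` regardless of
        the device the model lives on."""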
        # if on cpu, don't use autocast
        # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
        enable_autocast = self.device != torch.device("cpu")

        if enable_autocast:
            return torch.cuda.amp.autocast(dtype=dtype)
        else:
            return contextlib.nullcontext()

    @classmethod
    def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
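        """Create the Q-Former (a BERT LM head model with cross-attention to the
        frozen encoder's features of width ``vision_width``) together with its
        learnable query tokens."""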
        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
        encoder_config.encoder_width = vision_width
        # insert cross-attention layer every other block
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel.from_pretrained(
            "bert-base-uncased", config=encoder_config
        )
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens

    def load_from_pretrained(self, url_or_filename):
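        """Load a checkpoint from a URL or a local file path. Loading is
        non-strict, so missing or unexpected keys are tolerated; the
        ``load_state_dict`` report is returned for inspection."""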
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]

        msg = self.load_state_dict(state_dict, strict=False)

        logging.info("load checkpoint from %s" % url_or_filename)

        return msg

    def get_optimizer_params(self, weight_decay, lr_scale=1):
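        """Group parameters for the optimizer: biases and 1-D tensors (e.g. norm
        weights) get zero weight decay, everything else gets ``weight_decay``;
        each group also carries an ``lr_scale`` for layer-wise learning rates.
        A hypothetical usage sketch (optimizer choice is an assumption)::

            optim = torch.optim.AdamW(model.get_optimizer_params(0.05), lr=1e-4)
        """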
        vit_num_layers = self.ln_vision.num_layers
        lr_scales = list(lr_scale ** (vit_num_layers + 1 - i) for i in range(vit_num_layers + 2))

        parameter_group_names = {}
        parameter_group_vars = {}

        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue  # frozen weights
            if len(param.shape) == 1 or name.endswith(".bias"):
                group_name = "no_decay"
                this_weight_decay = 0.0
            else:
                group_name = "decay"
                this_weight_decay = weight_decay
            # group encoder parameters by layer so each layer can receive its
            # own lr scale, following the upstream LAVIS BLIP-2 implementation
            if "visual_encoder" in name:
                layer_id = self.visual_encoder.get_num_layer(name.replace("visual_encoder.", ""))
                group_name = "vit_layer_%d_%s" % (layer_id, group_name)
            else:
                layer_id = None

            if group_name not in parameter_group_names:
                if layer_id is not None:
                    scale = lr_scales[layer_id]
                else:
                    scale = 1

                parameter_group_names[group_name] = {
                    "weight_decay": this_weight_decay,
                    "params": [],
                    "lr_scale": scale,
                }
                parameter_group_vars[group_name] = {
                    "weight_decay": this_weight_decay,
                    "params": [],
                    "lr_scale": scale,
                }
            parameter_group_vars[group_name]["params"].append(param)
            parameter_group_names[group_name]["params"].append(name)

        optim_params = list(parameter_group_vars.values())
        return optim_params

    def _lemmatize(self, answers):
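        """Lemmatize nouns and verbs in each answer string, leaving all other
        tokens unchanged."""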
        def apply(answer):
            doc = self.lemmatizer(answer)

            words = []
            for token in doc:
                if token.pos_ in ["NOUN", "VERB"]:
                    words.append(token.lemma_)
                else:
                    words.append(token.text)
            answer = " ".join(words)

            return answer

        return [apply(answer) for answer in answers]

    @property
    def lemmatizer(self):
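        """Lazily load the spaCy ``en_core_web_sm`` pipeline the first time it is
        needed; exits with an installation hint if spaCy is not available."""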
        if self._lemmatizer is None:
            try:
                import spacy

                self._lemmatizer = spacy.load("en_core_web_sm")
            except ImportError:
                logging.error(
                    """
                    Please install spacy and en_core_web_sm model to apply lemmatization.
                    python -m spacy download en_core_web_sm
                    OR
                    import spacy.cli
                    spacy.cli.download("en_core_web_sm")
                    """
                )
                exit(1)

        return self._lemmatizer