# Polos-Demo/polos/models/estimators/polos_estimator.py
# -*- coding: utf-8 -*-
import random
from argparse import Namespace
from typing import Dict, List, Tuple, Union
import torch
from polos.models.estimators.estimator_base import Estimator
from polos.modules.feedforward import FeedForward
from polos.modules.scalar_mix import ScalarMixWithDropout
from torchnlp.utils import collate_tensors
import polos.clip as clip
# Silence Shapely's deprecation warnings when shapely is installed.
try:
    import warnings
    from shapely.errors import ShapelyDeprecationWarning
    warnings.filterwarnings("ignore", category=ShapelyDeprecationWarning)
except ImportError:
    pass
class PolosEstimator(Estimator):
"""
Estimator class that uses a pretrained encoder to extract features from
    the sequences and then passes those features to a feed-forward estimator.
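    At a high level the model encodes the candidate and reference captions with
    the pretrained text encoder, encodes the image and the captions with CLIP
    (ViT-B/32), feeds absolute-difference and element-wise product features to a
    feed-forward regressor with a sigmoid output, and returns the maximum score
    over the available references.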
:param hparams: Namespace containing the hyperparameters.
"""
class ModelConfig(Estimator.ModelConfig):
switch_prob: float = 0.0
def __init__(
self,
hparams: Namespace,
) -> None:
super().__init__(hparams)
def _build_model(self) -> Estimator:
"""
Initializes the estimator architecture.
"""
super()._build_model()
if self.hparams.encoder_model != "LASER":
self.layer = (
int(self.hparams.layer)
if self.hparams.layer != "mix"
else self.hparams.layer
)
self.scalar_mix = (
ScalarMixWithDropout(
mixture_size=self.encoder.num_layers,
dropout=self.hparams.scalar_mix_dropout,
do_layer_norm=True,
)
if self.layer == "mix" and self.hparams.pool != "default"
else None
)
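        # Feed-forward input size: with parallel feature extraction the input
        # concatenates four encoder-based vectors (ref, mt, their absolute
        # difference and element-wise product) plus six 512-dim CLIP vectors
        # (image, candidate text, image-text difference/product, and
        # reference-text difference/product); see `forward` below.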
parallel_feature_extraction = True
if parallel_feature_extraction:
input_emb_sz = (
self.encoder.output_units * 4 + 512*6
if self.hparams.pool != "cls+avg"
else self.encoder.output_units * 2 * 8
)
else:
input_emb_sz = (
self.encoder.output_units * 2 + 512*3
if self.hparams.pool != "cls+avg"
else self.encoder.output_units * 2 * 8
)
self.ff = torch.nn.Sequential(*[
FeedForward(
in_dim=input_emb_sz,
# out_dim=input_emb_sz,
hidden_sizes=self.hparams.hidden_sizes,
activations=self.hparams.activations,
dropout=self.hparams.dropout,
final_activation=(
self.hparams.final_activation
if hasattr(
self.hparams, "final_activation"
                    )  # compatibility with older checkpoints!
else "Sigmoid"
),
),
torch.nn.Sigmoid()
])
self.clip, self.clip_preprocess = clip.load("ViT-B/32", device="cpu")
self.parallel_feature_extraction = parallel_feature_extraction
def configure_optimizers(
self,
) -> Tuple[List[torch.optim.Optimizer], List[torch.optim.lr_scheduler.LambdaLR]]:
""" Sets different Learning rates for different parameter groups. """
layer_parameters = self.encoder.layerwise_lr(
self.hparams.encoder_learning_rate, self.hparams.layerwise_decay
)
ff_parameters = [
{"params": self.ff.parameters(), "lr": self.hparams.learning_rate}
]
if self.hparams.encoder_model != "LASER" and self.scalar_mix:
scalar_mix_parameters = [
{
"params": self.scalar_mix.parameters(),
"lr": self.hparams.learning_rate,
}
]
optimizer = self._build_optimizer(
layer_parameters + ff_parameters + scalar_mix_parameters
)
else:
optimizer = self._build_optimizer(layer_parameters + ff_parameters)
scheduler = self._build_scheduler(optimizer)
return [optimizer], [scheduler]
def prepare_sample(
self, sample: List[Dict[str, Union[str, float]]], inference: bool = False
) -> Union[
Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]], Dict[str, torch.Tensor]
]:
"""
        Prepares a sample for input to the model.
        :param sample: list of dictionaries.
        :param inference: If set to True, prepares only the model inputs.
        :returns: Tuple with 2 dictionaries (model inputs and targets).
            If `inference=True`, returns only the model inputs.
"""
sample = collate_tensors(sample)
mt_inputs = self.encoder.prepare_sample(sample["mt"])
ref_inputs = [self.encoder.prepare_sample(ref) for ref in sample["refs"]]
inputs = {
"mt_inputs": mt_inputs,
"ref_inputs": ref_inputs,
"refs": sample["refs"],
"mt": sample["mt"],
"imgs": sample["img"]
}
if inference:
return inputs
targets = {"score": torch.tensor(sample["score"], dtype=torch.float)}
return inputs, targets
    def masked_global_average_pooling(self, input_tensor, mask):
        """
        Averages token embeddings over the non-padded positions of each sequence.
        :param input_tensor: [batch_size x seq_len x hidden_size] token embeddings.
        :param mask: [batch_size x seq_len] boolean mask, True at padding positions.
        :returns: [batch_size x hidden_size] pooled embeddings.
        """
        mask = mask.logical_not()  # mask[x] = input[x] is not pad
        mask_expanded = mask.unsqueeze(-1).expand_as(input_tensor).float()
        input_tensor_masked = input_tensor * mask_expanded
        num_elements = mask.sum(dim=1, keepdim=True).float()  # TODO: check
        output_tensor = input_tensor_masked.sum(dim=1) / num_elements
        return output_tensor
def forward(
self,
refs,
mt,
ref_inputs,
mt_inputs,
        imgs,
        alt_tokens: torch.Tensor = None,
        alt_lengths: torch.Tensor = None,
**kwargs
) -> Dict[str, torch.Tensor]:
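        """
        Scores each candidate caption against the image and its references.
        :param refs: reference captions, grouped as one batch-sized list per
            reference (see the `(ref_cnt, B, L)` comment below).
        :param mt: list of candidate captions for the batch.
        :param ref_inputs: tokenized references produced by `prepare_sample`.
        :param mt_inputs: tokenized candidates produced by `prepare_sample`.
        :param imgs: images to be preprocessed and encoded with CLIP.
        :returns: Dictionary with a "score" tensor holding one score per candidate.
        """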
mt_tokens, mt_lengths = mt_inputs["tokens"], mt_inputs["lengths"]
        mt_sentemb, mt_sentembs, mt_mask, padding_index = self.get_sentence_embedding(
            mt_tokens, mt_lengths, pooling=False
        )
mt_mask = mt_mask.logical_not()
ref_sentemb_list = []
ref_sentembs_list = []
ref_mask_list = []
for ref in ref_inputs:
ref_tokens, ref_lengths = ref["tokens"], ref["lengths"]
            ref_sentemb, ref_sentembs, ref_mask, _ = self.get_sentence_embedding(
                ref_tokens, ref_lengths, pooling=False
            )
ref_mask = ref_mask.logical_not()
ref_sentemb_list.append(ref_sentemb)
ref_sentembs_list.append(ref_sentembs)
ref_mask_list.append(ref_mask)
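        # CLIP text/image features: every caption is prefixed with "A photo depicts "
        # before being encoded with the CLIP text encoder, and the raw images are
        # preprocessed and encoded with the CLIP image encoder.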
        refs_clip = []
        for ref_list in refs:  # (ref_cnt, B, L)
            subset = [
                clip.tokenize("A photo depicts " + ref, truncate=True).to(self.device)
                for ref in ref_list
            ]
            subset = torch.cat(subset, dim=0)
            refs_tensor = self.clip.encode_text(subset)
            refs_clip.append(refs_tensor)
        mts_clip = clip.tokenize(
            ["A photo depicts " + x for x in mt], truncate=True
        ).to(self.device)
        imgs_clip = torch.cat(
            [self.clip_preprocess(img).unsqueeze(0) for img in imgs], dim=0
        ).to(self.device)
        imgs_clip = self.clip.encode_image(imgs_clip)
        mts_clip = self.clip.encode_text(mts_clip)
del imgs
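        # For each reference: build absolute-difference and element-wise product
        # features from the text and CLIP embeddings, score them with the
        # feed-forward head, and finally keep the maximum score over references.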
        scores = []
        for ref_sentemb, ref_clip in zip(ref_sentemb_list, refs_clip):
            diff = torch.abs(mt_sentemb - ref_sentemb)
            mul = mt_sentemb * ref_sentemb
            diff_clip = torch.abs(imgs_clip - mts_clip)
            mul_clip = imgs_clip * mts_clip
            diff_clip_txt = torch.abs(ref_clip - mts_clip)
            mul_clip_txt = ref_clip * mts_clip
            if self.parallel_feature_extraction:
                x = torch.cat(
                    (ref_sentemb, mt_sentemb, diff, mul,
                     imgs_clip, mts_clip, diff_clip, mul_clip,
                     diff_clip_txt, mul_clip_txt),
                    dim=1,
                )
            else:
                x = torch.cat(
                    (ref_sentemb, mt_sentemb, ref_clip, imgs_clip, mts_clip), dim=1
                )
            score = self.ff(x)
            scores.append(score)
        score = torch.max(torch.stack(scores), dim=0).values
        return {"score": score}