# -*- coding: utf-8 -*-
import random
from argparse import Namespace
from typing import Dict, List, Tuple, Union

import torch
from torchnlp.utils import collate_tensors

import polos.clip as clip
from polos.models.estimators.estimator_base import Estimator
from polos.modules.feedforward import FeedForward
from polos.modules.scalar_mix import ScalarMixWithDropout

# Suppress noisy Shapely deprecation warnings when shapely is installed.
try:
    import warnings

    from shapely.errors import ShapelyDeprecationWarning

    warnings.filterwarnings("ignore", category=ShapelyDeprecationWarning)
except ImportError:
    pass


class PolosEstimator(Estimator):
    """
    Estimator class that uses a pretrained encoder to extract features from
    the candidate and reference sequences and then passes those features to a
    feed-forward estimator.

    :param hparams: Namespace containing the hyperparameters.
    """

    class ModelConfig(Estimator.ModelConfig):
        switch_prob: float = 0.0

    def __init__(
        self,
        hparams: Namespace,
    ) -> None:
        super().__init__(hparams)

    def _build_model(self) -> Estimator:
        """
        Initializes the estimator architecture.
        """
        super()._build_model()
        if self.hparams.encoder_model != "LASER":
            self.layer = (
                int(self.hparams.layer)
                if self.hparams.layer != "mix"
                else self.hparams.layer
            )
            self.scalar_mix = (
                ScalarMixWithDropout(
                    mixture_size=self.encoder.num_layers,
                    dropout=self.hparams.scalar_mix_dropout,
                    do_layer_norm=True,
                )
                if self.layer == "mix" and self.hparams.pool != "default"
                else None
            )
        # The feed-forward head consumes a concatenation of encoder and CLIP
        # features (see `forward`); the sizes below must match that concatenation.
        parallel_feature_extraction = True
        if parallel_feature_extraction:
            # 4 encoder features (ref, mt, |mt - ref|, mt * ref) plus
            # 6 CLIP features of dimension 512 each.
            input_emb_sz = (
                self.encoder.output_units * 4 + 512 * 6
                if self.hparams.pool != "cls+avg"
                else self.encoder.output_units * 2 * 8
            )
        else:
            # 2 encoder features (ref, mt) plus 3 CLIP features of dimension 512.
            input_emb_sz = (
                self.encoder.output_units * 2 + 512 * 3
                if self.hparams.pool != "cls+avg"
                else self.encoder.output_units * 2 * 8
            )
        self.ff = torch.nn.Sequential(
            FeedForward(
                in_dim=input_emb_sz,
                # out_dim=input_emb_sz,
                hidden_sizes=self.hparams.hidden_sizes,
                activations=self.hparams.activations,
                dropout=self.hparams.dropout,
                final_activation=(
                    self.hparams.final_activation
                    if hasattr(
                        self.hparams, "final_activation"
                    )  # compatibility with older checkpoints!
                    else "Sigmoid"
                ),
            ),
            torch.nn.Sigmoid(),
        )
        self.clip, self.clip_preprocess = clip.load("ViT-B/32", device="cpu")
        self.parallel_feature_extraction = parallel_feature_extraction
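
    # Worked example of `input_emb_sz` above (illustrative; output_units == 1024
    # is an assumption, e.g. a RoBERTa-large-style encoder):
    #
    #   input_emb_sz = 1024 * 4 + 512 * 6 = 4096 + 3072 = 7168
    #
    # The four encoder features are (ref, mt, |mt - ref|, mt * ref); the six
    # 512-dim CLIP features are (img, mt, |img - mt|, img * mt, |ref - mt|,
    # ref * mt), matching the concatenation in `forward`.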

    def configure_optimizers(
        self,
    ) -> Tuple[List[torch.optim.Optimizer], List[torch.optim.lr_scheduler.LambdaLR]]:
        """Sets different learning rates for different parameter groups."""
        layer_parameters = self.encoder.layerwise_lr(
            self.hparams.encoder_learning_rate, self.hparams.layerwise_decay
        )
        ff_parameters = [
            {"params": self.ff.parameters(), "lr": self.hparams.learning_rate}
        ]
        if self.hparams.encoder_model != "LASER" and self.scalar_mix:
            scalar_mix_parameters = [
                {
                    "params": self.scalar_mix.parameters(),
                    "lr": self.hparams.learning_rate,
                }
            ]
            optimizer = self._build_optimizer(
                layer_parameters + ff_parameters + scalar_mix_parameters
            )
        else:
            optimizer = self._build_optimizer(layer_parameters + ff_parameters)
        scheduler = self._build_scheduler(optimizer)
        return [optimizer], [scheduler]
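
    # Sketch of the intended grouping (illustrative): `layerwise_lr` is expected
    # to return one parameter group per encoder layer with a per-layer decayed
    # rate, e.g. roughly
    #
    #   lr_for_layer_n = encoder_learning_rate * layerwise_decay ** n
    #
    # while the feed-forward head (and the scalar mix, when present) train with
    # the flat `learning_rate`. The exact decay schedule is defined by the
    # encoder class, not here.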

    def prepare_sample(
        self, sample: List[Dict[str, Union[str, float]]], inference: bool = False
    ) -> Union[
        Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]],
        Dict[str, torch.Tensor],
    ]:
        """
        Prepares a sample to be input to the model.

        :param sample: list of dictionaries.
        :param inference: If set to True, prepares only the model inputs.

        :returns: Tuple with 2 dictionaries (model inputs and targets).
            If `inference=True`, returns only the model inputs.
        """
        sample = collate_tensors(sample)
        mt_inputs = self.encoder.prepare_sample(sample["mt"])
        ref_inputs = [self.encoder.prepare_sample(ref) for ref in sample["refs"]]
        inputs = {
            "mt_inputs": mt_inputs,
            "ref_inputs": ref_inputs,
            "refs": sample["refs"],
            "mt": sample["mt"],
            "imgs": sample["img"],
        }
        if inference:
            return inputs
        targets = {"score": torch.tensor(sample["score"], dtype=torch.float)}
        return inputs, targets
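
    # Illustrative input to `prepare_sample` (field values are hypothetical;
    # `img` must be something `clip_preprocess` accepts, e.g. a PIL image):
    #
    #   sample = [
    #       {
    #           "mt": "a dog runs along the beach",      # candidate caption
    #           "refs": ["a dog running on the shore"],  # reference captions
    #           "img": pil_image,                        # input image
    #           "score": 0.8,                            # human judgment
    #       },
    #   ]
    #   inputs, targets = model.prepare_sample(sample)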

    def masked_global_average_pooling(self, input_tensor, mask):
        """Averages over the sequence dimension, ignoring padded positions."""
        mask = mask.logical_not()  # mask[x] = True where input[x] is not padding
        mask_expanded = mask.unsqueeze(-1).expand_as(input_tensor).float()
        input_tensor_masked = input_tensor * mask_expanded
        num_elements = mask.sum(dim=1, keepdim=True).float()  # TODO: check
        output_tensor = input_tensor_masked.sum(dim=1) / num_elements
        return output_tensor
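
    # Minimal sketch of the masking convention assumed above (True marks padding):
    #
    #   x = torch.tensor([[[1.0], [3.0], [5.0]]])    # (batch=1, seq_len=3, dim=1)
    #   pad = torch.tensor([[False, False, True]])   # last position is padding
    #   model.masked_global_average_pooling(x, pad)  # -> tensor([[2.0]])
    #
    # (1.0 + 3.0) / 2 = 2.0: the padded position is excluded from the average.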

    def forward(
        self,
        refs,
        mt,
        ref_inputs,
        mt_inputs,
        imgs: torch.Tensor,
        alt_tokens: torch.Tensor = None,
        alt_lengths: torch.Tensor = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Scores each candidate caption against every reference and the image,
        then takes the maximum score over references.
        """
        mt_tokens, mt_lengths = mt_inputs["tokens"], mt_inputs["lengths"]
        mt_sentemb, mt_sentembs, mt_mask, padding_index = self.get_sentence_embedding(
            mt_tokens, mt_lengths, pooling=False
        )
        mt_mask = mt_mask.logical_not()

        # Encode each set of references with the text encoder.
        ref_sentemb_list = []
        ref_sentembs_list = []
        ref_mask_list = []
        for ref in ref_inputs:
            ref_tokens, ref_lengths = ref["tokens"], ref["lengths"]
            ref_sentemb, ref_sentembs, ref_mask, _ = self.get_sentence_embedding(
                ref_tokens, ref_lengths, pooling=False
            )
            ref_mask = ref_mask.logical_not()
            ref_sentemb_list.append(ref_sentemb)
            ref_sentembs_list.append(ref_sentembs)
            ref_mask_list.append(ref_mask)

        # Encode references, candidates, and images with CLIP.
        refs_clip = []
        for ref_list in refs:  # (ref_cnt, B, L)
            subset = [
                clip.tokenize("A photo depicts " + ref, truncate=True).to(self.device)
                for ref in ref_list
            ]
            subset = torch.cat(subset, dim=0)
            refs_tensor = self.clip.encode_text(subset)
            refs_clip.append(refs_tensor)
        mts_clip = clip.tokenize(
            ["A photo depicts " + x for x in mt], truncate=True
        ).to(self.device)
        imgs_clip = torch.cat(
            [self.clip_preprocess(img).unsqueeze(0) for img in imgs], dim=0
        ).to(self.device)
        imgs_clip = self.clip.encode_image(imgs_clip)
        mts_clip = self.clip.encode_text(mts_clip)
        del imgs

        # Score the candidate against each reference and keep the maximum.
        scores = []
        for ref_sentemb, ref_clip in zip(ref_sentemb_list, refs_clip):
            diff = torch.abs(mt_sentemb - ref_sentemb)
            mul = mt_sentemb * ref_sentemb
            diff_clip = torch.abs(imgs_clip - mts_clip)
            mul_clip = imgs_clip * mts_clip
            diff_clip_txt = torch.abs(ref_clip - mts_clip)
            mul_clip_txt = ref_clip * mts_clip
            if self.parallel_feature_extraction:
                x = torch.cat(
                    (
                        ref_sentemb,
                        mt_sentemb,
                        diff,
                        mul,
                        imgs_clip,
                        mts_clip,
                        diff_clip,
                        mul_clip,
                        diff_clip_txt,
                        mul_clip_txt,
                    ),
                    dim=1,
                )
            else:
                x = torch.cat(
                    (ref_sentemb, mt_sentemb, ref_clip, imgs_clip, mts_clip), dim=1
                )
            score = self.ff(x)
            scores.append(score)
        score = torch.max(torch.stack(scores), dim=0).values
        return {"score": score}