# TMR / model.py
# Mathis Petrovich: "remove tuple in type annotation" (commit 007d500)
# File size: 4.83 kB
from typing import List
import torch.nn as nn
import os
import torch
import numpy as np
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from transformers import logging
from torch.nn.functional import normalize
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence-first tensor.

    The encoding table is computed once at construction time and stored in
    the buffer ``pe`` with shape ``(max_len, 1, d_model)``.  The buffer is
    registered with ``persistent=False``, so it is not saved in the module's
    state dict.

    Args:
        d_model: Size of the last (feature) dimension of the inputs.
            Both even and odd values are supported.
        max_len: Number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        )
        # One column of angles per sine channel: (max_len, ceil(d_model / 2)).
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine channel than sine
        # channels, so drop the surplus angle column instead of failing on a
        # shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Insert a singleton batch axis so the table broadcasts over the
        # batch dimension of [sequence, batch, d_model] inputs.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.shape[0]`` positions.

        ``x`` is indexed as ``[sequence, batch, d_model]``; the sequence
        length must not exceed ``max_len`` for the addition to broadcast.
        """
        return x + self.pe[:x.shape[0], :]
class TMR_textencoder(nn.Module):
    """Encode sentences into one latent vector each, for text-based retrieval.

    A pretrained ``transformers`` model produces per-token features.  These
    are projected to ``latent_dim`` and fed, together with two learned
    distribution tokens (``mu`` and ``logvar``), to a Transformer encoder.
    Only the output at the ``mu`` token position is returned by ``forward``.
    """

    def __init__(self, modelpath: str, latent_dim: int, ff_size: int,
                 num_layers: int, num_heads: int, activation: str, **kwargs) -> None:
        """Load the tokenizer and text model and build the latent encoder.

        Args:
            modelpath: Name or path given to ``AutoTokenizer.from_pretrained``
                and ``AutoModel.from_pretrained``.
            latent_dim: Size of the latent space (Transformer ``d_model``).
            ff_size: Feed-forward width of each Transformer layer.
            num_layers: Number of Transformer encoder layers.
            num_heads: Number of attention heads per layer.
            activation: Activation name passed to ``nn.TransformerEncoderLayer``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()

        # Side effect: lowers the verbosity of the `transformers` logger.
        logging.set_verbosity_error()

        # Tokenizer
        # Side effect: sets a process-wide environment variable.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        self.tokenizer = AutoTokenizer.from_pretrained(modelpath)

        # Text model
        self.text_model = AutoModel.from_pretrained(modelpath)

        # Then configure the model
        self.text_encoded_dim = self.text_model.config.hidden_size

        # Projection of the text-outputs into the latent space
        self.projection = nn.Sequential(
            nn.ReLU(),
            nn.Linear(self.text_encoded_dim, latent_dim),
        )

        # Learned tokens prepended to every sequence; the encoder outputs at
        # these two positions are the distribution parameters.
        self.mu_token = nn.Parameter(torch.randn(latent_dim))
        self.logvar_token = nn.Parameter(torch.randn(latent_dim))

        self.sequence_pos_encoding = PositionalEncoding(latent_dim)

        seq_trans_encoder_layer = nn.TransformerEncoderLayer(
            d_model=latent_dim,
            nhead=num_heads,
            dim_feedforward=ff_size,
            dropout=0.0,
            activation=activation,
        )
        self.seqTransEncoder = nn.TransformerEncoder(
            seq_trans_encoder_layer,
            num_layers=num_layers,
        )

    def get_last_hidden_state(self, texts: List[str],
                              return_mask: bool = False):
        """Tokenize ``texts`` and run them through the text model.

        Args:
            texts: Sentences to encode; they are padded to a common length.
            return_mask: When true, also return the attention mask.

        Returns:
            The model's ``last_hidden_state``, or the pair
            ``(last_hidden_state, mask)`` where ``mask`` is the tokenizer's
            attention mask cast to ``bool``.
        """
        # NOTE(review): no truncation is requested here, so very long texts
        # are passed to the model at full length -- confirm this is intended.
        encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
        output = self.text_model(**encoded_inputs.to(self.text_model.device))
        if not return_mask:
            return output.last_hidden_state
        return output.last_hidden_state, encoded_inputs.attention_mask.to(dtype=bool)

    def forward(self, texts: List[str]) -> Tensor:
        """Return the ``mu`` latent vector of each text, indexed ``[bs, latent_dim]``."""
        text_encoded, mask = self.get_last_hidden_state(texts, return_mask=True)

        x = self.projection(text_encoded)
        bs = x.shape[0]

        # Switch sequence and batch_size because the input of
        # Pytorch Transformer is [Sequence, Batch size, ...]
        x = x.permute(1, 0, 2)  # now it is [ntokens, bs, latent_dim]

        # One copy of each distribution token per batch element.
        mu_token = self.mu_token.expand(bs, -1)
        logvar_token = self.logvar_token.expand(bs, -1)

        # adding the distribution tokens for all sequences
        xseq = torch.cat((mu_token[None], logvar_token[None], x), 0)

        # create a bigger mask, to allow attend to mu and logvar
        token_mask = torch.ones((bs, 2), dtype=bool, device=x.device)
        aug_mask = torch.cat((token_mask, mask), 1)

        # add positional encoding
        xseq = self.sequence_pos_encoding(xseq)

        # src_key_padding_mask marks the positions to ignore, hence the
        # inversion of the "valid token" mask.
        final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask)

        # only mu for inference
        return final[0]

    # compute score for retrieval
    def compute_scores(self, texts, unit_embs=None, embs=None):
        """Score one or several texts against a set of candidate embeddings.

        Exactly one of ``unit_embs`` and ``embs`` must be given.

        Args:
            texts: A single string or a list of strings.
            unit_embs: Candidate embeddings, expected to be already
                L2-normalized, one per row.
            embs: Candidate embeddings that are not normalized yet; they are
                normalized here.

        Returns:
            A NumPy array of cosine similarities rescaled to ``[0, 1]``:
            one row per text and one column per candidate when ``texts`` is
            a list, or a single row of scores when ``texts`` is a string.

        Raises:
            ValueError: If neither or both of ``unit_embs`` and ``embs`` are
                provided.
        """
        # Explicit check rather than `assert`, which is removed under `-O`.
        if (unit_embs is None) == (embs is None):
            raise ValueError(
                "compute_scores expects exactly one of `unit_embs` or `embs`."
            )

        # if one input, squeeze the output
        single_text = isinstance(texts, str)
        if single_text:
            texts = [texts]

        with torch.no_grad():
            latent_unit_texts = normalize(self(texts))

            # Align the candidates with the text latents so the matrix
            # product works for NumPy inputs and for tensors that live on
            # another device or use another floating point precision.
            candidates = unit_embs if embs is None else embs
            candidates = torch.as_tensor(candidates).to(
                device=latent_unit_texts.device,
                dtype=latent_unit_texts.dtype,
            )

            # compute unit_embs from embs if not given
            if embs is not None:
                candidates = normalize(candidates)

            # compute cosine similarity between 0 and 1
            scores = (candidates @ latent_unit_texts.T).T / 2 + 0.5
            scores = scores.cpu().numpy()

        return scores[0] if single_text else scores