import torch
import torch.nn as nn

import clip

from model.MDM import PositionalEncoding, TimestepEmbedder


class TextConditionalModel(nn.Module):

    def __init__(self, latent_dim=256, cond_mode="no_cond", cond_mask_prob=0.,
                 dropout=0.0, clip_dim=512, clip_version=None, **kargs):
        super().__init__()
        self.cond_mode = cond_mode
        assert self.cond_mode in ["no_cond", "text"]
        self.cond_mask_prob = cond_mask_prob
        self.latent_dim = latent_dim  # used later by compute_embedding
        # The dataset name is expected to arrive via **kargs; encode_text uses it
        # to pick the tokenizer context length.
        self.dataset = kargs.get('dataset', None)

        self.sequence_pos_encoder = PositionalEncoding(latent_dim, dropout)
        self.embed_timestep = TimestepEmbedder(latent_dim, self.sequence_pos_encoder)

        if cond_mode != 'no_cond':
            if 'text' in cond_mode:
                self.embed_text = nn.Linear(clip_dim, latent_dim)
                print('Loading CLIP...')
                self.clip_version = clip_version
                self.clip_model = self.load_and_freeze_clip(clip_version)
            else:
                raise NotImplementedError("only conditioning with text is implemented for now")

    def parameters_wo_clip(self):
        # All trainable parameters except the frozen CLIP encoder.
        return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')]

    def load_and_freeze_clip(self, clip_version):
        clip_model, clip_preprocess = clip.load(clip_version, device='cpu', jit=False)
        # Convert applicable CLIP weights to fp16, as in the original CLIP repo.
        clip.model.convert_weights(clip_model)

        # Freeze CLIP: it is used as a fixed text encoder.
        clip_model.eval()
        for p in clip_model.parameters():
            p.requires_grad = False

        return clip_model

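    # mask_cond implements the condition dropout used for classifier-free
    # guidance: during training, each sample's condition is zeroed with
    # probability cond_mask_prob, and force_mask=True reproduces the fully
    # unconditional case at sampling time.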
    def mask_cond(self, cond, force_mask=False):
        bs, d = cond.shape
        if force_mask:
            return torch.zeros_like(cond)
        elif self.training and self.cond_mask_prob > 0.:
            # mask[i] == 1 -> drop the condition for sample i, 0 -> keep it.
            mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_mask_prob).view(bs, 1)
            return cond * (1. - mask)
        else:
            return cond

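    # encode_text runs CLIP's tokenizer and text encoder. For HumanML3D the
    # context is truncated to 20 words (+2 for the SOT/EOT tokens) and then
    # zero-padded back to CLIP's default context length of 77, since CLIP's
    # positional embedding expects a fixed-width token tensor.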
    def encode_text(self, raw_text):
        device = next(self.parameters()).device
        max_text_len = 20 if self.dataset in ['humanml'] else None
        if max_text_len is not None:
            default_context_length = 77
            context_length = max_text_len + 2  # +2 for the SOT/EOT tokens
            assert context_length < default_context_length
            texts = clip.tokenize(raw_text, context_length=context_length, truncate=True).to(device)
            zero_pad = torch.zeros([texts.shape[0], default_context_length - context_length],
                                   dtype=texts.dtype, device=texts.device)
            texts = torch.cat([texts, zero_pad], dim=1)
        else:
            texts = clip.tokenize(raw_text, truncate=True).to(device)
        return self.clip_model.encode_text(texts).float()

    def compute_embedding(self, x, timesteps, y):
        """
        The cached buffers (all non-persistent) are:
        - emb: the embedding of the current condition, kept so it is not recomputed
          when the same condition repeats across diffusion steps (a big inference
          speedup).
        - emb_hash: the hash of the condition stored in emb, used to detect whether
          the incoming condition matches the cache.
        - emb_forcemask: same as emb but with the condition mask forced on, i.e. the
          cached embedding for the unconditional case.
        - emb_forcemask_hash: the hash of the condition stored in emb_forcemask.
        """
        bs, njoints, nfeats, nframes = x.shape

        # Multi-text mode: each sample carries a list of texts, one per condition slot.
        multitext_mode = "all_texts" in y or not isinstance(y['text'][0], str)
        key = "all_texts" if "all_texts" in y else "text"

        time_emb = self.embed_timestep(timesteps)
        force_mask = y.get('uncond', False)
        if not force_mask:
            if 'text' == self.cond_mode:
                # Order-preserving cache key (a tuple, so identical batches in a
                # different order do not collide).
                primitive = tuple(y[key]) if not multitext_mode else tuple(tuple(txts) for txts in y[key])
            else:
                raise ValueError

            hash_value = hash(primitive)
            recompute = not hasattr(self, 'emb_hash') or self.emb_hash != hash_value
            if not recompute:
                return time_emb + self.emb
        else:
            # The unconditional embedding only depends on the tensor shape.
            hash_value = hash(tuple(x.shape))
            recompute = not hasattr(self, 'emb_forcemask_hash') or self.emb_forcemask_hash != hash_value
            if not recompute:
                return time_emb + self.emb_forcemask

        if not multitext_mode:
            enc_text = self.encode_text(y['text']) if "text_embeddings" not in y else y["text_embeddings"]
            cond_emb = self.embed_text(self.mask_cond(enc_text, force_mask=force_mask))
            cond_emb = cond_emb.unsqueeze(0).expand(nframes, -1, -1)  # -> (nframes, bs, latent_dim)
        else:
            if "text_embeddings" in y:
                enc_text = y["text_embeddings"]
            else:
                # Pad every per-sample text list with empty strings so that all
                # samples have the same number of condition slots. Work on a copy
                # to avoid mutating the caller's dict.
                texts_list = [list(texts) for texts in y[key]]

                max_len = max([len(texts) for texts in texts_list])
                for i, texts in enumerate(texts_list):
                    if len(texts) < max_len:
                        texts_list[i] = texts + [''] * (max_len - len(texts))
                enc_text = [self.encode_text(texts) for texts in texts_list]
                enc_text = torch.stack(enc_text, dim=1)  # (num_slots, bs, clip_dim)

            I, N, d = enc_text.shape
            enc_text = enc_text.reshape(-1, enc_text.shape[-1])
            # embed_text maps clip_dim -> latent_dim, so reshape with -1 rather
            # than the CLIP dim d.
            embedded_text = self.embed_text(self.mask_cond(enc_text, force_mask=force_mask)).reshape(I, N, -1)

            # Sum the embeddings of the active condition slots per frame.
            conditions_mask = y['conditions_mask']
            conditions_mask = conditions_mask.unsqueeze(-1).expand(-1, -1, -1, self.latent_dim)
            cond_emb = torch.zeros(conditions_mask.shape[1:], device=embedded_text.device)
            for i in range(I):
                m = conditions_mask[i]
                cond_emb = cond_emb + m * embedded_text[i].unsqueeze(0)

        # Cache the embedding and its hash; non-persistent buffers are excluded
        # from the state_dict.
        if force_mask:
            self.register_buffer('emb_forcemask', cond_emb, persistent=False)
            self.register_buffer('emb_forcemask_hash', torch.tensor(hash(tuple(x.shape))), persistent=False)
        else:
            self.register_buffer('emb', cond_emb, persistent=False)
            self.register_buffer('emb_hash', torch.tensor(hash(primitive)), persistent=False)

        return time_emb + cond_emb
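

if __name__ == "__main__":
    # Minimal smoke test -- a sketch, not part of the training pipeline. It
    # assumes CLIP weights can be downloaded and that model.MDM is importable;
    # 'ViT-B/32' and the tensor sizes below are illustrative assumptions.
    # Note: load_and_freeze_clip converts CLIP to fp16, so running on a GPU is
    # recommended.
    model = TextConditionalModel(latent_dim=256, cond_mode="text", cond_mask_prob=0.1,
                                 clip_dim=512, clip_version='ViT-B/32', dataset='humanml')
    model.eval()

    bs, njoints, nfeats, nframes = 2, 25, 6, 60
    x = torch.randn(bs, njoints, nfeats, nframes)
    timesteps = torch.randint(0, 1000, (bs,))
    y = {'text': ['a person walks forward', 'a person jumps in place']}

    emb_first = model.compute_embedding(x, timesteps, y)   # computed and cached
    emb_cached = model.compute_embedding(x, timesteps, y)  # served from the cache
    emb_uncond = model.compute_embedding(x, timesteps, {**y, 'uncond': True})
    print(emb_first.shape, emb_uncond.shape)  # both (nframes, bs, latent_dim)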