import torch
import clip
import torch.nn as nn
from model.MDM import PositionalEncoding, TimestepEmbedder


class TextConditionalModel(nn.Module):
    def __init__(self, latent_dim=256, cond_mode="no_cond", cond_mask_prob=0., dropout=0.0, clip_dim=512, clip_version=None, **kargs):
        super().__init__()
        self.cond_mode = cond_mode
        assert self.cond_mode in ["no_cond", "text"]
        self.cond_mask_prob = cond_mask_prob
        self.latent_dim = latent_dim  # used to expand the per-segment condition mask in compute_embedding
        self.dataset = kargs.get('dataset', None)  # encode_text uses a shorter CLIP context for 'humanml'

        self.sequence_pos_encoder = PositionalEncoding(latent_dim, dropout)
        self.embed_timestep = TimestepEmbedder(latent_dim, self.sequence_pos_encoder)
        
        if cond_mode != 'no_cond':
            if 'text' in cond_mode:
                self.embed_text = nn.Linear(clip_dim, latent_dim)
                print('Loading CLIP...')
                self.clip_version = clip_version
                self.clip_model = self.load_and_freeze_clip(clip_version)
            else:
                raise NotImplementedError("only conditioning with text is implemented for now")

    def parameters_wo_clip(self):
        return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')]

    def load_and_freeze_clip(self, clip_version):
        clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
                                                jit=False)  # Must set jit=False for training
        clip.model.convert_weights(
            clip_model)  # casts weights to fp16; largely redundant since CLIP checkpoints already ship in float16

        # Freeze CLIP weights
        clip_model.eval()
        for p in clip_model.parameters():
            p.requires_grad = False

        return clip_model

    def mask_cond(self, cond, force_mask=False):
        bs, d = cond.shape
        if force_mask:
            return torch.zeros_like(cond)
        elif self.training and self.cond_mask_prob > 0.:
            mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_mask_prob).view(bs, 1)  # 1-> use null_cond, 0-> use real cond
            return cond * (1. - mask)
        else:
            return cond

    def encode_text(self, raw_text):
        # raw_text - list (batch_size length) of strings with input text prompts
        device = next(self.parameters()).device
        max_text_len = 20 if self.dataset in ['humanml'] else None  # Specific hardcoding for humanml dataset
        if max_text_len is not None:
            default_context_length = 77
            context_length = max_text_len + 2 # start_token + 20 + end_token
            assert context_length < default_context_length
            texts = clip.tokenize(raw_text, context_length=context_length, truncate=True).to(device) # [bs, context_length] # if n_tokens > context_length -> will truncate
            zero_pad = torch.zeros([texts.shape[0], default_context_length-context_length], dtype=texts.dtype, device=texts.device)
            texts = torch.cat([texts, zero_pad], dim=1)
        else:
            texts = clip.tokenize(raw_text, truncate=True).to(device) # [bs, context_length] # if n_tokens > 77 -> will truncate
        return self.clip_model.encode_text(texts).float()
    
    def compute_embedding(self, x, timesteps, y):
        """
            Explanation on what the buffers do:
            - emb: stores the embedding for the current condition. It is used to avoid recomputing the embedding if the condition is the same (big inference speedup)
            - emb_hash: stores the hash of the condition. It is used to check if the condition is the same as the one stored in emb
            - emb_forcemask: stores the embedding for the current condition, but with the mask forced to True. It is used to avoid recomputing the embedding for the unconditional case
            - emb_forcemask_hash: stores the hash of the condition. It is used to check if the condition is the same as the one stored in emb_forcemask
        """
        bs, njoints, nfeats, nframes = x.shape

        multitext_mode = "all_texts" in y or not isinstance(y['text'][0], str)
        key = "all_texts" if "all_texts" in y else "text"

        time_emb = self.embed_timestep(timesteps)  # [1, bs, d]
        force_mask = y.get('uncond', False)
        if not force_mask:
            if 'text' == self.cond_mode:
                primitive = frozenset(y[key]) if not multitext_mode else frozenset((frozenset(txts) for txts in y[key]))
            else:
                raise ValueError(f"Unsupported cond_mode for conditional embedding: {self.cond_mode}")
            
            hash_value = hash(primitive)
            recompute = not hasattr(self, 'emb_hash') or self.emb_hash != hash_value
            if not recompute:
                return time_emb + self.emb
        else:
            hash_value = hash(tuple(x.shape))  # tuple preserves order and repeats, so different shapes cannot collide
            recompute = not hasattr(self, 'emb_forcemask_hash') or self.emb_forcemask_hash != hash_value
            if not recompute:
                return time_emb + self.emb_forcemask

        # compute embedding
        if not multitext_mode: # --> single text training (e.g. HumanML3D dataset) / inference
            enc_text = self.encode_text(y['text']) if "text_embeddings" not in y else y["text_embeddings"] # if precomputed --> faster
            cond_emb = self.embed_text(self.mask_cond(enc_text, force_mask=force_mask))
            cond_emb = cond_emb.unsqueeze(0).expand(nframes, -1, -1) # [T, N, d]
        else: # --> multi-text training / inference (e.g. Babel dataset)
            if "text_embeddings" in y: # preloaded for fast training / eval
                enc_text = y["text_embeddings"]
            else:
                # 'conditions_mask' has shape [I, T, N], where I is the max number of text segments per sample, N is batch size, T is sequence length.
                # y[key] is a list of length N (one entry per sample), each entry being the list of segment texts for that sample.
                # We encode every segment text and build the per-frame embedding matrix below.
                texts_list = y[key]
                # pad all lists to the same length so they can be stacked later
                max_len = max([len(texts) for texts in texts_list])
                for i, texts in enumerate(texts_list):
                    if len(texts) < max_len:
                        texts_list[i] = texts + [''] * (max_len - len(texts))
                enc_text = [self.encode_text(texts) for texts in texts_list]  # N tensors, each [I, clip_dim]
                enc_text = torch.stack(enc_text, dim=1)  # [I, N, clip_dim]

            I, N, _ = enc_text.shape  # [I, N, clip_dim]
            enc_text = enc_text.reshape(-1, enc_text.shape[-1]) # [I*N, clip_dim]
            embedded_text = self.embed_text(self.mask_cond(enc_text, force_mask=force_mask)).reshape(I, N, -1) # [I, N, latent_dim]

            conditions_mask = y['conditions_mask'] # [I, T, N]
            conditions_mask = conditions_mask.unsqueeze(-1).expand(-1, -1, -1, self.latent_dim) # [I, T, N, d]
            cond_emb = torch.zeros(conditions_mask.shape[1:], device=embedded_text.device) # [T, N, d]
            for i in range(I):
                m = conditions_mask[i] # [T, N, d]
                cond_emb = cond_emb + m * embedded_text[i].unsqueeze(0) # [T, N, d] --> [T, N, d]
        
        # send to buffer
        if force_mask:
            self.register_buffer('emb_forcemask', cond_emb, persistent=False)
            self.register_buffer('emb_forcemask_hash', torch.tensor(hash_value), persistent=False)
        else:
            self.register_buffer('emb', cond_emb, persistent=False)
            self.register_buffer('emb_hash', torch.tensor(hash_value), persistent=False)

        return time_emb + cond_emb
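

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): it assumes the OpenAI
    # `clip` package can fetch the "ViT-B/32" checkpoint and uses toy HumanML3D-like
    # dimensions; the `dataset` kwarg value here is illustrative.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = TextConditionalModel(latent_dim=256, cond_mode="text", cond_mask_prob=0.1,
                                 clip_version="ViT-B/32", dataset="humanml").to(device)
    if device == "cpu":
        model.clip_model.float()  # fp16 CLIP ops may be unsupported on CPU
    x = torch.randn(2, 263, 1, 60, device=device)          # [bs, njoints, nfeats, nframes]
    timesteps = torch.randint(0, 1000, (2,), device=device)
    y = {"text": ["a person walks forward", "a person jumps in place"]}
    emb = model.compute_embedding(x, timesteps, y)
    print(emb.shape)  # expected: [nframes, bs, latent_dim] -> torch.Size([60, 2, 256])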