# motionfix-demo / model_utils.py
# Last change by atnikos: "fix retrieval placeholders" (commit 6837c8b)
import torch
import torch.nn as nn
import numpy as np
import torch
from torch import nn
import torch
from typing import List, Dict, Optional
from torch import Tensor
class TimestepEmbedderMDM(nn.Module):
    """Embed diffusion timesteps via a sinusoidal table lookup and a small MLP.

    The timestep indexes into the ``pe`` buffer of a ``PositionalEncoding``
    and the selected rows are passed through Linear -> SiLU -> Linear.

    Args:
        latent_dim: Width of the sinusoidal table rows and of both linear
            layers (input and output width are the same).
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim
        time_embed_dim = self.latent_dim
        self.sequence_pos_encoder = PositionalEncoding(d_model=self.latent_dim)
        # TODO add time embedding learnable
        # Place the MLP on whatever device the positional table lives on
        # rather than a hard-coded one, so the lookup result in `forward`
        # and the MLP weights are always on the same device.
        self.time_embed = nn.Sequential(
            nn.Linear(self.latent_dim, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        ).to(self.sequence_pos_encoder.pe.device)

    def forward(self, timesteps):
        """Return the embedding of ``timesteps``.

        Args:
            timesteps: Index into the first axis of the positional table.
                NOTE(review): presumably a 1-D integer tensor of shape
                (batch,), which would give an output of shape
                (1, batch, latent_dim) -- confirm against callers.

        Returns:
            The MLP output with its first two axes swapped.
        """
        table_rows = self.sequence_pos_encoder.pe[timesteps]
        return self.time_embed(table_rows).permute(1, 0, 2)
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding that is added to a sequence tensor.

    A table ``pe`` of shape ``(num_positions, 1, d_model)`` is precomputed
    and registered as a non-persistent buffer (it is not saved in the
    ``state_dict``). Even feature columns hold sines, odd columns cosines.

    Args:
        d_model: Feature width of the encoding (odd values are supported).
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of non-negative positions in the table.
        batch_first: If True, ``forward`` expects the sequence axis at
            index 1 of ``x``; otherwise at index 0.
        negative: If True, the table covers positions
            ``[-max_len, max_len)`` so that history frames can be given
            negative positions; otherwise it covers ``[0, max_len)``.
    """

    def __init__(self, d_model, dropout=0.1,
                 max_len=5000, batch_first=False, negative=False):
        super().__init__()
        self.batch_first = batch_first
        self.dropout = nn.Dropout(p=dropout)
        self.max_len = max_len
        self.negative = negative

        first_position = -max_len if negative else 0
        position = torch.arange(
            first_position, max_len, dtype=torch.float
        ).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        )
        angles = position * div_term

        pe = torch.zeros(position.shape[0], d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so trim the angles to fit the odd-indexed slots.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)

        # Keep the table on the GPU when one is present (as before), but
        # fall back to the CPU instead of crashing on CUDA-less hosts.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.register_buffer('pe', pe.to(device), persistent=False)

    def forward(self, x, hist_frames=0):
        """Add the positional encoding to ``x`` and apply dropout.

        Args:
            x: Input whose sequence axis is 1 when ``batch_first`` is set
                and 0 otherwise; its last axis must broadcast with
                ``d_model``.
            hist_frames: Number of leading frames of ``x`` that receive
                negative positions. Must be 0 unless the module was built
                with ``negative=True``.

        Returns:
            ``dropout(x + pe[first:last])`` with the same shape as ``x``.
        """
        if not self.negative:
            assert hist_frames == 0
            first = 0
        else:
            # Position 0 sits at row `max_len` of the table, so the first
            # frame of `x` gets position `-hist_frames`.
            first = self.max_len - hist_frames

        if self.batch_first:
            last = first + x.shape[1]
            x = x + self.pe.permute(1, 0, 2)[:, first:last, :]
        else:
            last = first + x.shape[0]
            x = x + self.pe[first:last, :]
        return self.dropout(x)
def collate_tensor_with_padding(batch: List[Tensor]) -> Tensor:
    """Stack tensors of differing sizes into one zero-padded tensor.

    Every axis of the result (after the leading batch axis) is as large as
    the largest corresponding axis among the inputs. Each input is written
    into the low-index corner of its slot; the remainder stays zero. The
    output takes its dtype and device from the first tensor in ``batch``.

    Args:
        batch: Non-empty list of tensors with the same number of axes.

    Returns:
        Tensor of shape ``(len(batch), *per_axis_max)``.
    """
    reference = batch[0]
    n_axes = reference.dim()
    padded_shape = tuple(
        max(item.size(axis) for item in batch) for axis in range(n_axes)
    )
    padded = reference.new_zeros(size=(len(batch),) + padded_shape)

    for slot, item in enumerate(batch):
        # Basic slicing yields a view, so the in-place add writes straight
        # into `padded`.
        corner = (slot,) + tuple(
            slice(0, item.size(axis)) for axis in range(n_axes)
        )
        padded[corner].add_(item)
    return padded
def collate_x_dict(lst_x_dict: List, *, device: Optional[str] = 'cuda') -> Dict:
    """Collate per-sample dicts into one padded batch with a validity mask.

    Args:
        lst_x_dict: List of dicts, each providing an ``"x"`` tensor and a
            ``"length"`` entry.
            NOTE(review): presumably ``"length"`` is the number of valid
            steps along the first axis of ``"x"`` -- confirm with callers.
        device: Device for the returned tensors. When ``None``, ``"x"`` is
            left on the device it was collated on, and ``"length"`` and
            ``"mask"`` are created on the default device.

    Returns:
        Dict with keys ``"x"`` (zero-padded stack of the inputs),
        ``"length"`` (1-D tensor of the lengths) and ``"mask"`` (boolean
        tensor of shape ``(batch, max(length))`` that is True where the
        column index is smaller than the sample's length).
    """
    x = collate_tensor_with_padding([x_dict["x"] for x_dict in lst_x_dict])
    if device is not None:
        x = x.to(device)

    lengths = [x_dict["length"] for x_dict in lst_x_dict]
    # Take the maximum from the Python list as a plain int: calling the
    # builtin max() on the tensor would iterate it element by element
    # and hand a 0-dim tensor to arange/expand.
    max_len = int(max(lengths))
    length = torch.tensor(lengths, device=device)

    mask = torch.arange(max_len, device=device).expand(
        len(lengths), max_len
    ) < length.unsqueeze(1)
    return {"x": x, "length": length, "mask": mask}