# TANGO / emage / skeleton.py
# Uploaded by ameerazam08 using huggingface_hub (commit af98a6c, verified).
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class SkeletonConv(nn.Module):
    """1-D temporal convolution whose weights are masked by a skeleton neighbourhood.

    Channels are laid out joint-major: input channel ``k * in_channels_per_joint + i``
    belongs to joint ``k``. Output block ``j`` (``out_channels_per_joint`` channels) may
    only read the input channels of the joints listed in ``neighbour_list[j]``; all other
    weights are multiplied by zero via ``self.mask`` in :meth:`forward`.

    Args:
        neighbour_list: list whose ``j``-th entry lists the joint indices visible to
            output joint ``j``.
        in_channels: total input channels; must be divisible by ``joint_num``.
        out_channels: total output channels; must be divisible by ``joint_num``.
        kernel_size: temporal kernel size.
        joint_num: number of joints the channels are split across.
        stride: temporal stride.
        padding: number of frames padded on each side (applied with ``F.pad``).
        bias: whether to learn a bias.
        padding_mode: ``'zeros'`` (mapped to ``'constant'``), ``'reflection'`` (mapped to
            ``'reflect'``) or any other mode string passed straight to ``F.pad``.
        add_offset: if true, an encoding of the offset given to :meth:`set_offset` is added
            to the output (scaled by 1/100).
        in_offset_channel: offset channels per joint when ``add_offset`` is set.

    Raises:
        ValueError: if ``in_channels`` or ``out_channels`` is not divisible by ``joint_num``.
    """

    def __init__(self, neighbour_list, in_channels, out_channels, kernel_size, joint_num, stride=1, padding=0,
                 bias=True, padding_mode='zeros', add_offset=False, in_offset_channel=0):
        if in_channels % joint_num != 0 or out_channels % joint_num != 0:
            raise ValueError(
                'SkeletonConv: in_channels ({}) and out_channels ({}) must both be divisible by '
                'joint_num ({})'.format(in_channels, out_channels, joint_num))
        super(SkeletonConv, self).__init__()
        self.in_channels_per_joint = in_channels // joint_num
        self.out_channels_per_joint = out_channels // joint_num

        # F.pad names these modes 'constant' and 'reflect'.
        if padding_mode == 'zeros':
            padding_mode = 'constant'
        if padding_mode == 'reflection':
            padding_mode = 'reflect'

        self.expanded_neighbour_list = []
        self.expanded_neighbour_list_offset = []
        self.neighbour_list = neighbour_list
        self.add_offset = add_offset
        self.joint_num = joint_num

        self.stride = stride
        self.dilation = 1
        self.groups = 1
        self.padding = padding
        self.padding_mode = padding_mode
        # Padding is applied explicitly with F.pad in forward(); conv1d itself uses padding=0.
        self._padding_repeated_twice = (padding, padding)

        # Expand each neighbourhood from joint indices to the input-channel indices they own.
        for neighbour in neighbour_list:
            self.expanded_neighbour_list.append(
                [k * self.in_channels_per_joint + i for k in neighbour for i in range(self.in_channels_per_joint)])

        if self.add_offset:
            self.offset_enc = SkeletonLinear(neighbour_list, in_offset_channel * len(neighbour_list), out_channels)
            # Each joint owns `in_offset_channel` offset channels (iterating range(add_offset)
            # would use the bool flag as a count and keep only one channel per joint).
            for neighbour in neighbour_list:
                self.expanded_neighbour_list_offset.append(
                    [k * in_offset_channel + i for k in neighbour for i in range(in_offset_channel)])

        self.weight = torch.zeros(out_channels, in_channels, kernel_size)
        if bias:
            self.bias = torch.zeros(out_channels)
        else:
            self.register_parameter('bias', None)

        # mask is 1 on (output block j, channels of joints in neighbour_list[j]) and 0 elsewhere.
        mask = torch.zeros_like(self.weight)
        for i, neighbour in enumerate(self.expanded_neighbour_list):
            mask[self._out_rows(i), neighbour, ...] = 1
        self.mask = nn.Parameter(mask, requires_grad=False)

        self.description = 'SkeletonConv(in_channels_per_armature={}, out_channels_per_armature={}, kernel_size={}, ' \
                           'joint_num={}, stride={}, padding={}, bias={})'.format(
            in_channels // joint_num, out_channels // joint_num, kernel_size, joint_num, stride, padding, bias
        )

        self.reset_parameters()

    def _out_rows(self, i):
        # Slice of output channels owned by joint i.
        return slice(self.out_channels_per_joint * i, self.out_channels_per_joint * (i + 1))

    def reset_parameters(self):
        """(Re)initialise weight and bias block by block; safe to call again after construction.

        Each joint block gets Kaiming-uniform weights (``a=sqrt(5)``) on its visible input
        channels, and a uniform bias in ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]`` where ``fan_in``
        is computed from that block's weight slice only. Weights outside the mask are not touched.
        """
        with torch.no_grad():
            for i, neighbour in enumerate(self.expanded_neighbour_list):
                rows = self._out_rows(i)
                # Advanced indexing with `neighbour` returns a copy, so initialise a temporary
                # and write it back instead of initialising the slice in place.
                tmp = torch.zeros_like(self.weight[rows, neighbour, ...])
                nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))
                self.weight[rows, neighbour, ...] = tmp
                if self.bias is not None:
                    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[rows, neighbour, ...])
                    bound = 1 / math.sqrt(fan_in)
                    tmp = torch.zeros_like(self.bias[rows])
                    nn.init.uniform_(tmp, -bound, bound)
                    self.bias[rows] = tmp
        # Only wrap on the first call so existing Parameter objects (e.g. held by an optimizer) survive.
        if not isinstance(self.weight, nn.Parameter):
            self.weight = nn.Parameter(self.weight)
        if self.bias is not None and not isinstance(self.bias, nn.Parameter):
            self.bias = nn.Parameter(self.bias)

    def set_offset(self, offset):
        """Store ``offset`` (flattened to ``(offset.shape[0], -1)``) for use in :meth:`forward`.

        Raises:
            Exception: if the layer was built without ``add_offset``.
        """
        if not self.add_offset:
            raise Exception('Wrong Combination of Parameters')
        self.offset = offset.reshape(offset.shape[0], -1)

    def forward(self, input):
        """Pad ``input`` in time, apply the masked conv1d, and add the offset term if enabled."""
        weight_masked = self.weight * self.mask
        res = F.conv1d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
                       weight_masked, self.bias, self.stride,
                       0, self.dilation, self.groups)
        if self.add_offset:
            offset_res = self.offset_enc(self.offset)
            # Append a length-1 time axis so the offset broadcasts over every frame.
            offset_res = offset_res.reshape(offset_res.shape + (1,))
            res += offset_res / 100
        return res
class SkeletonLinear(nn.Module):
    """Fully connected layer whose weights are masked by a skeleton neighbourhood.

    The input is flattened to ``(batch, in_channels)``; channels are joint-major with
    ``in_channels // len(neighbour_list)`` channels per joint. Output block ``j``
    (``out_channels // len(neighbour_list)`` channels) only sees the input channels of the
    joints in ``neighbour_list[j]``.

    Args:
        neighbour_list: list whose ``j``-th entry lists the joint indices visible to output joint ``j``.
        in_channels: total input channels.
        out_channels: total output channels.
        extra_dim1: if true, a trailing dimension of size 1 is appended to the output.
    """

    def __init__(self, neighbour_list, in_channels, out_channels, extra_dim1=False):
        super(SkeletonLinear, self).__init__()
        self.neighbour_list = neighbour_list
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_channels_per_joint = in_channels // len(neighbour_list)
        self.out_channels_per_joint = out_channels // len(neighbour_list)
        self.extra_dim1 = extra_dim1
        # Joint indices -> input-channel indices owned by those joints.
        self.expanded_neighbour_list = [
            [k * self.in_channels_per_joint + i for k in neighbour for i in range(self.in_channels_per_joint)]
            for neighbour in neighbour_list
        ]
        self.weight = torch.zeros(out_channels, in_channels)
        self.mask = torch.zeros(out_channels, in_channels)
        self.bias = nn.Parameter(torch.Tensor(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """(Re)build the mask and initialise weight and bias; safe to call again after construction.

        Each output block gets Kaiming-uniform weights (``a=sqrt(5)``) on its visible input
        channels. The bias is uniform in ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]`` with ``fan_in``
        taken from the full weight matrix.
        """
        with torch.no_grad():
            for i, neighbour in enumerate(self.expanded_neighbour_list):
                rows = slice(i * self.out_channels_per_joint, (i + 1) * self.out_channels_per_joint)
                # Advanced indexing returns a copy: initialise a temporary and write it back.
                tmp = torch.zeros_like(self.weight[rows, neighbour])
                self.mask[rows, neighbour] = 1
                nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))
                self.weight[rows, neighbour] = tmp
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)
        # Wrap only on the first call so existing Parameter objects are kept.
        if not isinstance(self.weight, nn.Parameter):
            self.weight = nn.Parameter(self.weight)
        if not isinstance(self.mask, nn.Parameter):
            self.mask = nn.Parameter(self.mask, requires_grad=False)

    def forward(self, input):
        """Flatten ``input`` to ``(batch, -1)`` and apply the masked linear map."""
        input = input.reshape(input.shape[0], -1)
        weight_masked = self.weight * self.mask
        res = F.linear(input, weight_masked, self.bias)
        if self.extra_dim1:
            res = res.reshape(res.shape + (1,))
        return res
class SkeletonPool(nn.Module):
    """Mean-pool consecutive edges along each kinematic chain of a skeleton.

    The edge tree (rooted at joint 0) is split into chains that run from the root or a
    branching joint (degree > 2) to a leaf or the next branching joint. Within each chain,
    a leading odd edge is kept alone and the rest are merged pairwise. With ``last_pool``,
    each whole chain is merged into a single edge. Pooling is a fixed (non-trainable)
    matrix applied by ``torch.matmul``.

    Args:
        edges: sequence of edges; ``edge[0]`` is the parent joint and ``edge[1]`` the child joint.
        pooling_mode: only ``'mean'`` is supported.
        channels_per_edge: number of channels per edge in the input.
        last_pool: merge each whole chain instead of pairs.

    Raises:
        Exception: if ``pooling_mode`` is not ``'mean'``.
    """

    def __init__(self, edges, pooling_mode, channels_per_edge, last_pool=False):
        super(SkeletonPool, self).__init__()
        if pooling_mode != 'mean':
            raise Exception('Unimplemented pooling mode in matrix_implementation')
        self.channels_per_edge = channels_per_edge
        self.pooling_mode = pooling_mode
        self.edge_num = len(edges)
        self.seq_list = []
        self.pooling_list = []
        self.new_edges = []

        # degree[j] = number of edges touching joint j. Sized from the largest joint index
        # (at least 1 entry for the root) rather than a fixed 100, so larger skeletons work.
        joint_count = max([0] + [max(edge[0], edge[1]) for edge in edges]) + 1
        degree = [0] * joint_count
        for edge in edges:
            degree[edge[0]] += 1
            degree[edge[1]] += 1

        # Collect chains of edge indices: a chain ends at a leaf or at a branching joint
        # (degree > 2), where a new chain starts. The root is never treated as a chain end,
        # even when it has a single child.
        def find_seq(j, seq):
            if degree[j] > 2 and j != 0:
                self.seq_list.append(seq)
                seq = []
            if degree[j] == 1 and j != 0:
                self.seq_list.append(seq)
                return
            for idx, edge in enumerate(edges):
                if edge[0] == j:
                    find_seq(edge[1], seq + [idx])

        find_seq(0, [])

        for seq in self.seq_list:
            if last_pool:
                self.pooling_list.append(seq)
                continue
            # Odd-length chain: keep its first edge unpooled so the rest pair up evenly.
            if len(seq) % 2 == 1:
                self.pooling_list.append([seq[0]])
                self.new_edges.append(edges[seq[0]])
                seq = seq[1:]
            for i in range(0, len(seq), 2):
                self.pooling_list.append([seq[i], seq[i + 1]])
                self.new_edges.append([edges[seq[i]][0], edges[seq[i + 1]][1]])

        self.description = 'SkeletonPool(in_edge_num={}, out_edge_num={})'.format(
            len(edges), len(self.pooling_list)
        )

        # weight[out_edge * C + c, in_edge * C + c] = 1 / (number of edges merged into out_edge).
        weight = torch.zeros(len(self.pooling_list) * channels_per_edge, self.edge_num * channels_per_edge)
        for i, pair in enumerate(self.pooling_list):
            for j in pair:
                for c in range(channels_per_edge):
                    weight[i * channels_per_edge + c, j * channels_per_edge + c] = 1.0 / len(pair)
        self.weight = nn.Parameter(weight, requires_grad=False)

    def forward(self, input: torch.Tensor):
        """Return ``self.weight @ input`` (pooled edges along the channel dimension)."""
        return torch.matmul(self.weight, input)
class SkeletonUnpool(nn.Module):
    """Inverse of ``SkeletonPool``: copy each pooled edge back to every edge it was built from.

    Args:
        pooling_list: entry ``i`` lists the output (unpooled) edge indices fed by input edge ``i``.
        channels_per_edge: channels per edge; the copy is done channel by channel.

    The fixed 0/1 matrix ``self.weight`` (not trainable) is applied with ``torch.matmul``.
    """

    def __init__(self, pooling_list, channels_per_edge):
        super(SkeletonUnpool, self).__init__()
        self.pooling_list = pooling_list
        self.input_edge_num = len(pooling_list)
        self.output_edge_num = sum(len(group) for group in self.pooling_list)
        self.channels_per_edge = channels_per_edge
        self.description = 'SkeletonUnpool(in_edge_num={}, out_edge_num={})'.format(
            self.input_edge_num, self.output_edge_num,
        )

        width = channels_per_edge
        matrix = torch.zeros(self.output_edge_num * width, self.input_edge_num * width)
        for src, group in enumerate(self.pooling_list):
            for dst in group:
                for ch in range(width):
                    matrix[dst * width + ch, src * width + ch] = 1
        self.weight = nn.Parameter(matrix, requires_grad=False)

    def forward(self, input: torch.Tensor):
        """Return ``self.weight @ input`` (edges expanded along the channel dimension)."""
        expanded = torch.matmul(self.weight, input)
        return expanded
"""
Helper functions for skeleton operation
"""
def dfs(x, fa, vis, dist):
    """Recursive depth-first walk over the undirected tree described by parent array ``fa``.

    Nodes ``x`` and ``y`` are adjacent when ``fa[y] == x`` or ``fa[x] == y``. Every node
    reached from ``x`` gets ``vis[node] = 1`` and ``dist[node] = dist[parent_in_walk] + 1``;
    ``vis`` and ``dist`` are mutated in place and nothing is returned.
    """
    vis[x] = 1
    for y, parent_of_y in enumerate(fa):
        if vis[y] != 0:
            continue
        if parent_of_y == x or fa[x] == y:
            dist[y] = dist[x] + 1
            dfs(y, fa, vis, dist)
"""
def find_neighbor_joint(fa, threshold):
neighbor_list = [[]]
for x in range(1, len(fa)):
vis = [0 for _ in range(len(fa))]
dist = [0 for _ in range(len(fa))]
dist[0] = 10000
dfs(x, fa, vis, dist)
neighbor = []
for j in range(1, len(fa)):
if dist[j] <= threshold:
neighbor.append(j)
neighbor_list.append(neighbor)
neighbor = [0]
for i, x in enumerate(neighbor_list):
if i == 0: continue
if 1 in x:
neighbor.append(i)
neighbor_list[i] = [0] + neighbor_list[i]
neighbor_list[0] = neighbor
return neighbor_list
def build_edge_topology(topology, offset):
# get all edges (pa, child, offset)
edges = []
joint_num = len(topology)
for i in range(1, joint_num):
edges.append((topology[i], i, offset[i]))
return edges
"""
def build_edge_topology(topology):
    """Turn a parent-index array into a list of ``(parent, child)`` edges.

    The list starts with ``(0, len(topology))``, an edge from the root joint 0 to a virtual
    joint whose index is one past the last real joint. It is followed by
    ``(topology[i], i)`` for every non-root joint ``i``.
    """
    joint_num = len(topology)
    virtual_root_edge = (0, joint_num)
    return [virtual_root_edge] + [(topology[child], child) for child in range(1, joint_num)]
def build_joint_topology(edges, origin_names):
    """Rebuild a joint hierarchy from an edge list, inserting virtual joints at branches.

    Joint 0 is the root: named ``origin_names[0]``, zero offset, and parent 0 (itself).
    Edges are then visited depth-first (pre-order), starting from every edge whose parent
    joint is 0. For each edge, if its parent joint has more than one outgoing edge, a
    zero-offset joint named ``origin_names[child] + '_virtual'`` is appended first. The
    edge's child joint is then appended with offset ``edge[2]``.

    Args:
        edges: sequence of ``(parent, child, offset)`` triples. ``edge[2]`` is read as the
            joint offset, so plain ``(parent, child)`` 2-tuples would raise IndexError here.
        origin_names: joint names indexed by the original joint index (``edge[1]``).

    Returns:
        ``(parent, offset, names, edge2joint)``: per-joint parent indices, offsets and
        names, plus ``edge2joint``. That list holds the source edge index for each joint
        appended by ``make_topology``, or -1 for a virtual joint.
        NOTE(review): the root joint gets no ``edge2joint`` entry, so ``edge2joint[k]``
        describes joint ``k + 1`` -- confirm callers expect this offset.
    """
    parent = []
    offset = []
    names = []
    edge2joint = []
    # NOTE(review): joint_from_edge is only appended to for the root and is never returned.
    joint_from_edge = []  # -1 means virtual joint
    joint_cnt = 0
    # Outgoing-edge count per joint; assumes every edge[0] < len(edges) + 10.
    out_degree = [0] * (len(edges) + 10)
    for edge in edges:
        out_degree[edge[0]] += 1
    # add root joint
    joint_from_edge.append(-1)
    parent.append(0)
    offset.append(np.array([0, 0, 0]))
    names.append(origin_names[0])
    joint_cnt += 1

    def make_topology(edge_idx, pa):
        # Append the joint(s) for edges[edge_idx] under joint `pa`, then recurse into child edges.
        nonlocal edges, parent, offset, names, edge2joint, joint_from_edge, joint_cnt
        edge = edges[edge_idx]
        # Branching parent: insert a zero-offset virtual joint between `pa` and this child.
        if out_degree[edge[0]] > 1:
            parent.append(pa)
            offset.append(np.array([0, 0, 0]))
            names.append(origin_names[edge[1]] + '_virtual')
            edge2joint.append(-1)
            pa = joint_cnt
            joint_cnt += 1
        parent.append(pa)
        offset.append(edge[2])
        names.append(origin_names[edge[1]])
        edge2joint.append(edge_idx)
        pa = joint_cnt
        joint_cnt += 1
        # Depth-first: every edge that starts at this edge's child joint.
        for idx, e in enumerate(edges):
            if e[0] == edge[1]:
                make_topology(idx, pa)

    for idx, e in enumerate(edges):
        if e[0] == 0:
            make_topology(idx, 0)
    return parent, offset, names, edge2joint
def calc_edge_mat(edges):
    """Return the all-pairs hop distance between edges.

    Two distinct edges are adjacent (distance 1) when they share a joint, i.e. one of
    ``a[0]``/``a[1]`` equals one of ``b[0]``/``b[1]``. Every edge is at distance 0 from
    itself. Pairs that are not connected keep the sentinel value 100000.

    Args:
        edges: sequence of edges; only ``edge[0]`` and ``edge[1]`` (joint indices) are read.

    Returns:
        list of lists of ints: ``edge_mat[i][j]`` is the distance from edge ``i`` to edge ``j``.
    """
    edge_num = len(edges)
    unreachable = 100000
    edge_mat = [[unreachable] * edge_num for _ in range(edge_num)]

    # Direct neighbours. Skip i == j: an edge trivially shares its joints with itself,
    # which would otherwise overwrite the zero diagonal with 1.
    for i, a in enumerate(edges):
        for j, b in enumerate(edges):
            if i != j and any(a[x] == b[y] for x in range(2) for y in range(2)):
                edge_mat[i][j] = 1
    for i in range(edge_num):
        edge_mat[i][i] = 0

    # Floyd-Warshall relaxation through every intermediate edge k.
    for k in range(edge_num):
        row_k = edge_mat[k]
        for i in range(edge_num):
            row_i = edge_mat[i]
            via_k = row_i[k]
            for j in range(edge_num):
                candidate = via_k + row_k[j]
                if candidate < row_i[j]:
                    row_i[j] = candidate
    return edge_mat
def find_neighbor(edges, d):
    """List, for every edge, the indices of all edges within distance ``d`` of it.

    Args:
        edges: list of N edges, each ``(parent, child)``.
        d: maximum distance, as computed by ``calc_edge_mat`` (adjacent edges share a joint
            and are at distance 1).

    Returns:
        list of N lists; entry ``i`` holds, in ascending order, every ``j`` with
        ``calc_edge_mat(edges)[i][j] <= d``.
    """
    distances = calc_edge_mat(edges)
    # NOTE: upstream (DeepMotionEditing/deep-motion-editing, issue #30) also appended a
    # "global part" neighbour list here; that extension is intentionally disabled in this copy.
    return [[j for j, dist in enumerate(row) if dist <= d] for row in distances]
def calc_node_depth(topology):
    """Return the depth of every node of a parent-index array.

    ``topology[i]`` is the parent index of node ``i``. A node whose parent is negative is a
    root with depth 0. So is a node that is its own parent, the convention
    ``build_joint_topology`` uses for its root. Every other node is one deeper than its
    parent. Depths are memoised and computed iteratively, so the cost is O(n) and deep
    chains cannot overflow the recursion limit.

    Raises:
        ValueError: if following parent links from some node never reaches a root (cycle).
    """
    depth = [None] * len(topology)
    for start in range(len(topology)):
        chain = []  # nodes whose depth is still unknown, ordered from `start` upward
        seen = set()
        node = start
        while depth[node] is None:
            parent = topology[node]
            if parent < 0 or parent == node:
                depth[node] = 0
                break
            if node in seen:
                raise ValueError(
                    'calc_node_depth: parent links starting at node {} contain a cycle'.format(start))
            seen.add(node)
            chain.append(node)
            node = parent
        # `node` now has a known depth; fill in its descendants along the walked chain.
        base = depth[node]
        for steps, child in enumerate(reversed(chain), 1):
            depth[child] = base + steps
    return depth
def residual_ratio(k):
    """Return ``1 / (k + 1)``, the weight given to the residual branch at index ``k``."""
    denominator = k + 1
    return 1 / denominator
class Affine(nn.Module):
    """Learnable per-channel scale and bias applied along dimension 1 of the input.

    Args:
        num_parameters: number of channels (size of dimension 1 of the input).
        scale: learn a multiplicative scale, initialised to ``scale_init``.
        bias: learn an additive bias, initialised to zero.
        scale_init: initial value of every scale entry.
    """

    def __init__(self, num_parameters, scale=True, bias=True, scale_init=1.0):
        super(Affine, self).__init__()
        if scale:
            self.scale = nn.Parameter(torch.ones(num_parameters) * scale_init)
        else:
            self.register_parameter('scale', None)
        if bias:
            self.bias = nn.Parameter(torch.zeros(num_parameters))
        else:
            self.register_parameter('bias', None)

    @staticmethod
    def _broadcast(param, ndim):
        # (C,) -> (1, C, 1, ..., 1) so it broadcasts over dimension 1 of an ndim-D input.
        param = param.unsqueeze(0)
        while param.dim() < ndim:
            param = param.unsqueeze(2)
        return param

    def forward(self, input):
        """Return ``input * scale + bias``, skipping whichever term is disabled."""
        output = input
        if self.scale is not None:
            output = output * self._broadcast(self.scale, input.dim())
        if self.bias is not None:
            # Out-of-place add: with scale disabled `output` is still the caller's tensor,
            # and an in-place `+=` would modify it (and fail on leaf tensors needing grad).
            output = output + self._broadcast(self.bias, input.dim())
        return output
class BatchStatistics(nn.Module):
    """Pass-through (or affine) layer that records a per-channel statistics loss.

    ``forward`` stores ``mean_c(mu_c ** 2) + mean_c(log(sigma_c) ** 2)`` in ``self.loss``.
    Here ``mu_c`` and ``sigma_c`` are the mean and population standard deviation of
    channel ``c`` (dimension 1) over all other dimensions. It then returns
    ``self.affine(input)``.

    Args:
        affine: number of channels for a learnable :class:`Affine`; ``-1`` (default) means
            no affine transform (identity).
    """

    def __init__(self, affine=-1):
        super(BatchStatistics, self).__init__()
        self.affine = nn.Sequential() if affine == -1 else Affine(affine)
        self.loss = 0

    def clear_loss(self):
        """Reset the recorded loss to 0."""
        self.loss = 0

    def compute_loss(self, input):
        """Compute the per-channel statistics loss of ``input`` and store it in ``self.loss``."""
        channels = input.size(1)
        # Move channels to the front before flattening; a plain view(C, -1) of a (B, C, ...)
        # tensor would mix channels whenever B > 1, and fails on non-contiguous input.
        input_flat = input.transpose(0, 1).reshape(channels, -1)
        mu = input_flat.mean(1)
        # Two-pass variance cannot go negative, unlike E[x^2] - mu^2 under rounding.
        var = (input_flat - mu.unsqueeze(1)).pow(2).mean(1)
        log_std = var.sqrt().log()
        self.loss = mu.pow(2).mean() + log_std.pow(2).mean()

    def forward(self, input):
        """Record the statistics loss, then return ``self.affine(input)``."""
        self.compute_loss(input)
        return self.affine(input)
class ResidualBlock(nn.Module):
    """1-D convolutional residual block mixing two branches with fixed weights.

    ``forward(x) = residual_ratio * residual(x) + (1 - residual_ratio) * shortcut(x)``.

    * ``residual``: Conv1d, then BatchStatistics if ``batch_statistics``, then an activation
      (PReLU for ``'relu'``, otherwise Tanh) unless ``last_layer``.
    * ``shortcut``: AvgPool1d(2) when ``stride == 2`` (identity otherwise), a 1x1 Conv1d, then
      BatchStatistics when the channel counts differ and ``batch_statistics is True``
      (identity otherwise).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, residual_ratio, activation,
                 batch_statistics=False, last_layer=False):
        super(ResidualBlock, self).__init__()
        self.residual_ratio = residual_ratio
        self.shortcut_ratio = 1 - residual_ratio

        layers = [nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)]
        if batch_statistics:
            layers.append(BatchStatistics(out_channels))
        if not last_layer:
            layers.append(nn.Tanh() if activation != 'relu' else nn.PReLU())
        self.residual = nn.Sequential(*layers)

        downsample = nn.AvgPool1d(kernel_size=2) if stride == 2 else nn.Sequential()
        projection = nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        wants_stats = batch_statistics is True and in_channels != out_channels
        normalise = BatchStatistics(out_channels) if wants_stats else nn.Sequential()
        self.shortcut = nn.Sequential(downsample, projection, normalise)

    def forward(self, input):
        main = self.residual(input) * self.residual_ratio
        skip = self.shortcut(input) * self.shortcut_ratio
        return main + skip
class ResidualBlockTranspose(nn.Module):
    """Upsampling residual block built on ``ConvTranspose1d``.

    ``forward(x) = residual_ratio * residual(x) + (1 - residual_ratio) * shortcut(x)``, where
    ``residual`` is ConvTranspose1d followed by PReLU (``'relu'``) or Tanh (otherwise), and
    ``shortcut`` is a 2x linear Upsample (only when ``stride == 2``) followed by a 1x1 Conv1d.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, residual_ratio, activation):
        super(ResidualBlockTranspose, self).__init__()
        self.residual_ratio = residual_ratio
        self.shortcut_ratio = 1 - residual_ratio

        deconv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, padding)
        nonlinearity = nn.Tanh() if activation != 'relu' else nn.PReLU()
        self.residual = nn.Sequential(deconv, nonlinearity)

        if stride == 2:
            upsample = nn.Upsample(scale_factor=2, mode='linear', align_corners=False)
        else:
            upsample = nn.Sequential()
        projection = nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.shortcut = nn.Sequential(upsample, projection)

    def forward(self, input):
        main = self.residual(input) * self.residual_ratio
        skip = self.shortcut(input) * self.shortcut_ratio
        return main + skip
class SkeletonResidual(nn.Module):
    """Skeleton-aware downsampling residual block.

    * ``residual``: ``extra_conv`` x (stride-1 SkeletonConv + activation), then a strided
      SkeletonConv to ``out_channels`` and ``GroupNorm(10, out_channels)``.
    * ``shortcut``: a 1x1 strided SkeletonConv to ``out_channels``.
    * ``common``: applied to ``residual(x) + shortcut(x)``. It is a SkeletonPool over
      ``topology`` (included only if it actually changes the edge count) followed by an
      activation.

    Activations are PReLU for ``activation == 'relu'`` and Tanh otherwise.
    """

    def __init__(self, topology, neighbour_list, joint_num, in_channels, out_channels, kernel_size, stride, padding,
                 padding_mode, bias, extra_conv, pooling_mode, activation, last_pool):
        super(SkeletonResidual, self).__init__()
        # The extra stride-1 convs use an odd kernel: even kernel sizes are reduced by one.
        extra_kernel = kernel_size - 1 if kernel_size % 2 == 0 else kernel_size

        def make_activation():
            return nn.Tanh() if activation != 'relu' else nn.PReLU()

        residual_layers = []
        for _ in range(extra_conv):
            # (T, J, D) => (T, J, D)
            residual_layers += [
                SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=in_channels,
                             joint_num=joint_num, kernel_size=extra_kernel, stride=1,
                             padding=padding, padding_mode=padding_mode, bias=bias),
                make_activation(),
            ]
        # (T, J, D) => (T/2, J, 2D)
        residual_layers.append(
            SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels,
                         joint_num=joint_num, kernel_size=kernel_size, stride=stride,
                         padding=padding, padding_mode=padding_mode, bias=bias, add_offset=False))
        # FIXME (upstream: "REMEMBER TO CHANGE BACK"): group count is hard-coded to 10, so
        # out_channels must be divisible by 10.
        residual_layers.append(nn.GroupNorm(10, out_channels))
        self.residual = nn.Sequential(*residual_layers)

        # (T, J, D) => (T/2, J, 2D)
        self.shortcut = SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels,
                                     joint_num=joint_num, kernel_size=1, stride=stride, padding=0,
                                     bias=True, add_offset=False)

        # (T/2, J, 2D) => (T/2, J', 2D)
        pool = SkeletonPool(edges=topology, pooling_mode=pooling_mode,
                            channels_per_edge=out_channels // len(neighbour_list), last_pool=last_pool)
        common_layers = [pool] if len(pool.pooling_list) != pool.edge_num else []
        common_layers.append(make_activation())
        self.common = nn.Sequential(*common_layers)

    def forward(self, input):
        merged = self.residual(input) + self.shortcut(input)
        return self.common(merged)
class SkeletonResidualTranspose(nn.Module):
    """Skeleton-aware upsampling residual block.

    * ``common``: optional 2x temporal ``nn.Upsample`` (when ``upsampling`` is not None),
      then a SkeletonUnpool over ``pooling_list`` (included only if it changes the edge
      count).
    * ``residual``: ``extra_conv`` x (stride-1 SkeletonConv + activation), then a stride-1
      SkeletonConv to ``out_channels``.
    * ``shortcut``: a 1x1 SkeletonConv to ``out_channels``.

    ``forward(x) = act(residual(common(x)) + shortcut(common(x)))``. ``act`` is PReLU for
    ``'relu'`` and Tanh otherwise; it is omitted when ``last_layer`` is set.
    """

    def __init__(self, neighbour_list, joint_num, in_channels, out_channels, kernel_size, padding, padding_mode, bias,
                 extra_conv, pooling_list, upsampling, activation, last_layer):
        super(SkeletonResidualTranspose, self).__init__()
        kernel_even = False if kernel_size % 2 else True
        seq = []
        # (T, J, D) => (2T, J, D)
        if upsampling is not None:
            # F.interpolate rejects any align_corners value (even False) for non-interpolating
            # modes such as 'nearest' or 'area', so only pass it where it is meaningful.
            interpolating = upsampling in ('linear', 'bilinear', 'bicubic', 'trilinear')
            seq.append(nn.Upsample(scale_factor=2, mode=upsampling,
                                   align_corners=False if interpolating else None))
        # (2T, J, D) => (2T, J', D)
        unpool = SkeletonUnpool(pooling_list, in_channels // len(neighbour_list))
        if unpool.input_edge_num != unpool.output_edge_num:
            seq.append(unpool)
        self.common = nn.Sequential(*seq)

        seq = []
        for _ in range(extra_conv):
            # (2T, J', D) => (2T, J', D)
            seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=in_channels,
                                    joint_num=joint_num, kernel_size=kernel_size - 1 if kernel_even else kernel_size,
                                    stride=1,
                                    padding=padding, padding_mode=padding_mode, bias=bias))
            seq.append(nn.PReLU() if activation == 'relu' else nn.Tanh())
        # (2T, J', D) => (2T, J', D/2)
        seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels,
                                joint_num=joint_num, kernel_size=kernel_size - 1 if kernel_even else kernel_size,
                                stride=1,
                                padding=padding, padding_mode=padding_mode, bias=bias, add_offset=False))
        self.residual = nn.Sequential(*seq)

        # (2T, J', D) => (2T, J', D/2)
        self.shortcut = SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels,
                                     joint_num=joint_num, kernel_size=1, stride=1, padding=0,
                                     bias=True, add_offset=False)

        if activation == 'relu':
            self.activation = nn.PReLU() if not last_layer else None
        else:
            self.activation = nn.Tanh() if not last_layer else None

    def forward(self, input):
        output = self.common(input)
        output = self.residual(output) + self.shortcut(output)
        if self.activation is not None:
            return self.activation(output)
        return output