File size: 1,507 Bytes
93091a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
# coding: utf-8
"""
Stitching module (S) and two retargeting modules (R) defined in the paper.
- The stitching module pastes the animated portrait back into the original image space without pixel misalignment, such as in
the stitching region.
- The eye retargeting module is designed to address the issue of incomplete eye closure during cross-identity reenactment, especially
when a person with small eyes drives a person with larger eyes.
- The lip retargeting module is designed similarly to the eye retargeting module, and can also normalize the input by ensuring that
the lips are in a closed state, which facilitates better animation driving.
"""
from torch import nn
class StitchingRetargetingNetwork(nn.Module):
    """Fully connected network: Linear -> ReLU -> ... -> Linear.

    The class name suggests it backs the stitching (S) and retargeting (R)
    modules; only the layer sizes are configurable here.

    Args:
        input_size: number of input features of the first Linear layer.
        hidden_sizes: widths of the hidden layers, in order (any iterable of
            ints). Every hidden Linear layer is followed by an in-place ReLU.
            May be empty, in which case the network is a single
            ``Linear(input_size, output_size)``.
        output_size: number of output features of the final Linear layer.
            No activation is applied after it.
    """

    def __init__(self, input_size, hidden_sizes, output_size):
        super().__init__()
        # Feature sizes along the chain: input -> h0 -> h1 -> ... -> output.
        sizes = [input_size, *hidden_sizes, output_size]

        layers = [nn.Linear(sizes[0], sizes[1])]
        for in_features, out_features in zip(sizes[1:-1], sizes[2:]):
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.Linear(in_features, out_features))

        # Layer order is Linear, ReLU, Linear, ..., Linear, so the state_dict
        # keys are "mlp.0.*", "mlp.2.*", ...; keep this layout so existing
        # checkpoints still load.
        self.mlp = nn.Sequential(*layers)

    def initialize_weights_to_zero(self):
        """Set the weight and bias of every nn.Linear submodule to zero.

        Because the final Linear layer is zeroed, the network outputs all
        zeros for any input until it is trained.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.zeros_(module.weight)
                # Guard for Linear layers built with bias=False.
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x):
        """Apply the MLP to ``x``.

        nn.Linear acts on the last dimension, so ``x`` is expected to have
        shape ``(..., input_size)``; the result has shape ``(..., output_size)``.
        """
        return self.mlp(x)
|