|
|
|
|
|
"""
|
|
Motion extractor (M), which directly predicts the canonical keypoints, head pose, and expression deformation of the input image.
|
|
"""
|
|
|
|
from torch import nn
|
|
import torch
|
|
|
|
from .convnextv2 import convnextv2_tiny
|
|
from .util import filter_state_dict
|
|
|
|
# Registry of available backbones: maps the name given via the
# ``backbone`` keyword argument to the callable that constructs it.
model_dict = dict(
    convnextv2_tiny=convnextv2_tiny,
)
|
|
|
|
|
|
class MotionExtractor(nn.Module):
    """Motion extractor: a thin wrapper around a registry-selected backbone.

    The backbone constructor is looked up by name in ``model_dict`` and is
    called with every keyword argument given to this constructor (including
    ``backbone`` itself). ``forward`` returns the backbone's output unchanged.
    """

    def __init__(self, **kwargs):
        """Build the backbone network stored as ``self.detector``.

        Args:
            **kwargs: forwarded verbatim to the backbone constructor. The
                ``backbone`` entry selects the registry key and defaults to
                ``'convnextv2_tiny'``.

        Raises:
            ValueError: if ``backbone`` is not a key of ``model_dict``.
        """
        super().__init__()

        backbone = kwargs.get('backbone', 'convnextv2_tiny')
        builder = model_dict.get(backbone)
        if builder is None:
            # Fail fast with a readable message instead of calling None,
            # which would raise an opaque "'NoneType' object is not callable".
            raise ValueError(
                f"Unknown backbone {backbone!r}; "
                f"expected one of {sorted(model_dict)}"
            )
        self.detector = builder(**kwargs)

    def load_pretrained(self, init_path: str):
        """Load checkpoint weights into the backbone, dropping ``head`` entries.

        Args:
            init_path: path to a checkpoint whose ``'model'`` entry holds the
                state dict. ``None`` or ``''`` makes this a no-op.

        Note:
            ``torch.load`` may unpickle arbitrary objects (depending on the
            torch version's ``weights_only`` default); only load trusted files.
        """
        if init_path in (None, ''):
            return

        # 'cpu' is the documented equivalent of the identity map_location
        # lambda: every tensor is materialised on the CPU.
        checkpoint = torch.load(init_path, map_location='cpu')
        # Presumably removes keys matching 'head' (see util.filter_state_dict)
        # so only the feature-extractor weights are transferred — confirm.
        state_dict = filter_state_dict(checkpoint['model'], remove_name='head')
        # strict=False: key mismatches are reported in ``ret`` (missing /
        # unexpected keys) instead of raising.
        ret = self.detector.load_state_dict(state_dict, strict=False)
        print(f'Load pretrained model from {init_path}, ret: {ret}')

    def forward(self, x):
        """Run the backbone on ``x`` and return its output as-is."""
        return self.detector(x)
|
|
|