|
from typing import List, Optional, Tuple, Dict, Callable |
|
|
|
import torch |
|
from torch import Tensor |
|
from torch.nn import Module |
|
|
|
from tha3.poser.poser import PoseParameterGroup, Poser |
|
from tha3.compute.cached_computation_func import TensorListCachedComputationFunc |
|
|
|
|
|
class GeneralPoser02(Poser):
    """Poser that evaluates a cached computation over lazily loaded modules.

    Modules are not built at construction time. Each callable in
    ``module_loaders`` is invoked on first use (see ``get_modules``); the
    resulting module is moved to ``device`` and put into inference mode.
    Posing delegates to ``output_list_func``, which receives the loaded
    modules and the ``[image, pose]`` batch and returns a list of tensors.
    """

    def __init__(self,
                 module_loaders: Dict[str, Callable[[], Module]],
                 device: torch.device,
                 output_length: int,
                 pose_parameters: List[PoseParameterGroup],
                 output_list_func: TensorListCachedComputationFunc,
                 subrect: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] = None,
                 default_output_index: int = 0,
                 image_size: int = 256,
                 dtype: torch.dtype = torch.float):
        """Store configuration; no module is loaded here.

        Args:
            module_loaders: Maps a module name to a zero-argument factory that
                builds the module. The factories are called lazily by
                ``get_modules``.
            device: Device every loaded module is moved to.
            output_length: Value reported by ``get_output_length``. It is
                presumably the number of tensors ``output_list_func`` returns
                (TODO confirm).
            pose_parameters: Pose parameter groups. Their arities are summed
                into ``num_parameters``.
            output_list_func: Callable invoked as
                ``output_list_func(modules, batch, outputs)`` to produce the
                list of output tensors.
            subrect: Optional ``((start, end), (start, end))`` crop. The first
                pair slices dimension 2 of the batched image and the second
                pair slices dimension 3.
            default_output_index: Index into the output list that ``pose``
                returns when no index is given.
            image_size: Value reported by ``get_image_size``.
            dtype: Value reported by ``get_dtype``. NOTE(review): this class
                does not cast modules or inputs to this dtype; presumably the
                loaders handle that (confirm).
        """
        self.dtype = dtype
        self.image_size = image_size
        self.default_output_index = default_output_index
        self.output_list_func = output_list_func
        self.subrect = subrect
        self.pose_parameters = pose_parameters
        self.device = device
        self.module_loaders = module_loaders

        # Populated lazily by get_modules(); reset to None by free().
        self.modules: Optional[Dict[str, Module]] = None

        # Total number of scalar pose inputs across all parameter groups.
        self.num_parameters = sum(group.get_arity() for group in self.pose_parameters)

        self.output_length = output_length

    def get_image_size(self) -> int:
        """Return the configured image size."""
        return self.image_size

    def get_modules(self) -> Dict[str, Module]:
        """Return the loaded modules, building them on the first call.

        Each loader runs once, in ``module_loaders`` iteration order. Its
        module is moved to ``self.device`` and switched to inference mode via
        ``train(False)``. The resulting dict is cached until ``free()`` is
        called. If any loader raises, nothing is cached and the next call
        retries from scratch.
        """
        if self.modules is None:
            # Publish the dict only after every loader has succeeded.
            # Previously ``self.modules`` was assigned before the loop, so an
            # exception from a loader left a partially filled dict cached, and
            # later calls silently returned it with modules missing.
            loaded: Dict[str, Module] = {}
            for key, loader in self.module_loaders.items():
                module = loader()
                module.to(self.device)
                module.train(False)
                loaded[key] = module
            self.modules = loaded
        return self.modules

    def get_pose_parameter_groups(self) -> List[PoseParameterGroup]:
        """Return the pose parameter groups passed at construction."""
        return self.pose_parameters

    def get_num_parameters(self) -> int:
        """Return the total arity of all pose parameter groups."""
        return self.num_parameters

    def pose(self, image: Tensor, pose: Tensor, output_index: Optional[int] = None) -> Tensor:
        """Run the poser and return a single output tensor.

        Args:
            image: Input image, either unbatched or batched (see
                ``get_posing_outputs``).
            pose: Pose vector, either unbatched or batched.
            output_index: Index into the output list. Defaults to
                ``default_output_index`` when ``None``.

        Returns:
            ``get_posing_outputs(image, pose)[output_index]``.
        """
        if output_index is None:
            output_index = self.default_output_index
        return self.get_posing_outputs(image, pose)[output_index]

    def get_posing_outputs(self, image: Tensor, pose: Tensor) -> List[Tensor]:
        """Run the computation and return every output tensor.

        A 3-dim ``image`` and a 1-dim ``pose`` are given a leading batch
        dimension. If ``subrect`` is set, the image is then cropped along
        dimensions 2 and 3. Presumably the layout is (batch, channels, height,
        width); TODO confirm against callers.

        Returns:
            Whatever ``output_list_func(modules, [image, pose], {})`` returns.
        """
        modules = self.get_modules()

        if image.dim() == 3:
            image = image.unsqueeze(0)
        if pose.dim() == 1:
            pose = pose.unsqueeze(0)
        if self.subrect is not None:
            (row_start, row_end), (col_start, col_end) = self.subrect
            image = image[:, :, row_start:row_end, col_start:col_end]
        batch = [image, pose]

        # Fresh per call; presumably a memo of intermediate results used by the
        # cached computation function (confirm in TensorListCachedComputationFunc).
        outputs = {}
        return self.output_list_func(modules, batch, outputs)

    def get_output_length(self) -> int:
        """Return the configured output length."""
        return self.output_length

    def free(self):
        """Drop this poser's references to the loaded modules.

        The next ``get_modules()`` call invokes the loaders again.
        """
        self.modules = None

    def get_dtype(self) -> torch.dtype:
        """Return the dtype passed at construction."""
        return self.dtype
|
|