TomatoCocotree
上传
6a62ffb
from abc import ABC, abstractmethod
from enum import Enum
from typing import Tuple, List, Optional
import torch
from torch import Tensor
class PoseParameterCategory(Enum):
    """Category tag attached to each ``PoseParameterGroup``.

    Members are numbered consecutively starting at 1, in declaration order.
    """

    EYEBROW = 1
    EYE = 2
    IRIS_MORPH = 3
    IRIS_ROTATION = 4
    MOUTH = 5
    FACE_ROTATION = 6
    BODY_ROTATION = 7
    BREATHING = 8
class PoseParameterGroup:
    """A named group of one or two pose parameters that share the same settings.

    A group of arity 1 exposes a single parameter named exactly like the group.
    A group of arity 2 exposes two parameters, ``<group_name>_left`` and
    ``<group_name>_right``, in that order.
    """

    def __init__(self,
                 group_name: str,
                 parameter_index: int,
                 category: PoseParameterCategory,
                 arity: int = 1,
                 discrete: bool = False,
                 default_value: float = 0.0,
                 range: Optional[Tuple[float, float]] = None):
        """Create a parameter group.

        Args:
            group_name: Name of the group; also the base of its parameter names.
            parameter_index: Index stored on the group. ``PoseParameters.Builder``
                passes the running flat index of the group's first parameter.
            category: Category the group belongs to.
            arity: Number of parameters in the group; must be 1 or 2.
            discrete: Flag stored as-is and reported by ``is_discrete``.
            default_value: Value reported by ``get_default_value``.
            range: Value range as a 2-tuple (presumably ``(min, max)`` -- confirm);
                ``None`` selects ``(0.0, 1.0)``.

        Raises:
            AssertionError: If ``arity`` is neither 1 nor 2 (only while
                assertions are enabled).
        """
        assert arity in (1, 2)
        self.group_name = group_name
        self.parameter_index = parameter_index
        self.category = category
        self.arity = arity
        self.discrete = discrete
        self.default_value = default_value
        self.range = (0.0, 1.0) if range is None else range
        if arity == 1:
            self.parameter_names = [group_name]
        else:
            # Two-parameter groups get one name per side, left first.
            self.parameter_names = [group_name + side for side in ("_left", "_right")]

    def get_arity(self) -> int:
        """Return how many parameters this group contains (1 or 2)."""
        return self.arity

    def get_group_name(self) -> str:
        """Return the group's name."""
        return self.group_name

    def get_parameter_names(self) -> List[str]:
        """Return the group's parameter names, in order (the stored list itself)."""
        return self.parameter_names

    def is_discrete(self) -> bool:
        """Return the ``discrete`` flag given at construction."""
        return self.discrete

    def get_range(self) -> Tuple[float, float]:
        """Return the group's value range tuple."""
        return self.range

    def get_default_value(self) -> float:
        """Return the default value given at construction."""
        return self.default_value

    def get_parameter_index(self) -> int:
        """Return the index stored for this group at construction."""
        return self.parameter_index

    def get_category(self) -> PoseParameterCategory:
        """Return the category the group belongs to."""
        return self.category
class PoseParameters:
    """An ordered collection of ``PoseParameterGroup`` objects.

    Individual parameters share one flat index space that runs across the
    groups in list order. Each group contributes ``get_arity()`` consecutive
    indices, one per entry of its ``get_parameter_names()``.
    """

    def __init__(self, pose_parameter_groups: List[PoseParameterGroup]):
        """Wrap ``pose_parameter_groups``; the list is stored as given, not copied."""
        self.pose_parameter_groups = pose_parameter_groups

    def get_parameter_index(self, name: str) -> int:
        """Return the flat index of the parameter called ``name``.

        Raises:
            RuntimeError: If no group has a parameter with that name.
        """
        index = 0
        for group in self.pose_parameter_groups:
            for param_name in group.get_parameter_names():
                if name == param_name:
                    return index
                index += 1
        # f-string rather than "%": "%" would raise TypeError for tuple names.
        raise RuntimeError(f"Cannot find parameter with name {name}")

    def get_parameter_name(self, index: int) -> str:
        """Return the name of the parameter at flat position ``index``.

        Raises:
            AssertionError: If ``index`` is outside ``[0, get_parameter_count())``.
        """
        # NOTE(review): this bounds check is an ``assert`` and disappears under
        # ``python -O``. A negative index would then silently index from the
        # end of the first group's names.
        assert 0 <= index < self.get_parameter_count()
        for group in self.pose_parameter_groups:
            arity = group.get_arity()
            if index < arity:
                return group.get_parameter_names()[index]
            # Skip past this group's parameters into the next group.
            index -= arity
        # Only reachable if the bounds check above was skipped or the groups
        # are inconsistent with their arities.
        raise RuntimeError("Something is wrong here!!!")

    def get_pose_parameter_groups(self) -> List[PoseParameterGroup]:
        """Return the stored list of groups (not a copy)."""
        return self.pose_parameter_groups

    def get_parameter_count(self) -> int:
        """Return the total number of individual parameters across all groups."""
        return sum(group.get_arity() for group in self.pose_parameter_groups)

    class Builder:
        """Fluent builder that lays groups out consecutively in the flat index space."""

        def __init__(self):
            # Flat index that the next added group's first parameter will get.
            self.index = 0
            self.pose_parameter_groups = []

        def add_parameter_group(self,
                                group_name: str,
                                category: PoseParameterCategory,
                                arity: int = 1,
                                discrete: bool = False,
                                default_value: float = 0.0,
                                range: Optional[Tuple[float, float]] = None) -> 'PoseParameters.Builder':
            """Append a new ``PoseParameterGroup`` and return ``self`` for chaining.

            The group's ``parameter_index`` is the current running index, which
            then advances by ``arity``. The remaining arguments are forwarded
            unchanged to ``PoseParameterGroup``.
            """
            self.pose_parameter_groups.append(
                PoseParameterGroup(
                    group_name=group_name,
                    parameter_index=self.index,
                    category=category,
                    arity=arity,
                    discrete=discrete,
                    default_value=default_value,
                    range=range))
            self.index += arity
            return self

        def build(self) -> 'PoseParameters':
            """Return a ``PoseParameters`` holding the groups added so far.

            The result gets its own copy of the group list. Groups added to this
            builder afterwards therefore do not leak into already-built objects.
            """
            return PoseParameters(list(self.pose_parameter_groups))
class Poser(ABC):
    """Abstract interface for an object that poses an input image.

    Subclasses must implement every abstract method. ``get_dtype`` has a
    concrete default of ``torch.float`` that subclasses may override.
    """

    @abstractmethod
    def get_image_size(self) -> int:
        """Return the image size this poser works with.

        NOTE(review): presumably the side length of a square image -- confirm
        against concrete implementations.
        """

    @abstractmethod
    def get_output_length(self) -> int:
        """Return the output length reported by the poser.

        NOTE(review): presumably the number of tensors produced by
        ``get_posing_outputs`` -- confirm against implementations.
        """

    @abstractmethod
    def get_pose_parameter_groups(self) -> List[PoseParameterGroup]:
        """Return the pose parameter groups this poser accepts."""

    @abstractmethod
    def get_num_parameters(self) -> int:
        """Return the number of pose parameters.

        NOTE(review): presumably the expected length of the ``pose`` tensor's
        parameter dimension -- confirm.
        """

    @abstractmethod
    def pose(self, image: Tensor, pose: Tensor, output_index: int = 0) -> Tensor:
        """Pose ``image`` according to ``pose`` and return a single tensor.

        ``output_index`` selects which output to return (presumably an index
        into ``get_posing_outputs`` -- confirm).
        """

    @abstractmethod
    def get_posing_outputs(self, image: Tensor, pose: Tensor) -> List[Tensor]:
        """Pose ``image`` according to ``pose`` and return all output tensors."""

    def get_dtype(self) -> torch.dtype:
        """Return the tensor dtype the poser uses; ``torch.float`` unless overridden."""
        return torch.float