from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import tops
import torch
import torchvision
import torchvision.transforms.functional as F
from PIL import Image, ImageDraw

from dp2.utils.vis_utils import get_coco_keypoints

from .functional import hflip
class RandomHorizontalFlip(torch.nn.Module):
def __init__(self, p: float, flip_map=None, **kwargs):
super().__init__()
self.flip_ratio = p
self.flip_map = flip_map
if self.flip_ratio is None:
self.flip_ratio = 0.5
assert 0 <= self.flip_ratio <= 1
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
if torch.rand(1) > self.flip_ratio:
return container
return hflip(container, self.flip_map)
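
# Example usage (an illustrative sketch; the tensor shape is an assumption):
#   flip = RandomHorizontalFlip(p=0.5)
#   batch = {"img": torch.zeros(3, 288, 160, dtype=torch.uint8)}
#   batch = flip(batch)  # delegates to hflip from .functional when the coin flip hits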
class CenterCrop(torch.nn.Module):
"""
Performs the transform on the image.
NOTE: Does not transform the mask to improve runtime.
"""
def __init__(self, size: List[int]):
super().__init__()
self.size = tuple(size)
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
min_size = min(container["img"].shape[1], container["img"].shape[2])
if min_size < self.size[0]:
container["img"] = F.center_crop(container["img"], min_size)
container["img"] = F.resize(container["img"], self.size)
return container
container["img"] = F.center_crop(container["img"], self.size)
return container
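
# Example usage (an illustrative sketch; sizes are assumptions):
#   crop = CenterCrop(size=[288, 160])
#   batch = crop({"img": torch.zeros(3, 300, 200)})
#   # the 200px shortest side is center-cropped, then resized up to (288, 160)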
class Resize(torch.nn.Module):
"""
Performs the transform on the image.
NOTE: Does not transform the mask to improve runtime.
"""
def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR):
super().__init__()
self.size = tuple(size)
self.interpolation = interpolation
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
container["img"] = F.resize(container["img"], self.size, self.interpolation, antialias=True)
if "semantic_mask" in container:
container["semantic_mask"] = F.resize(
container["semantic_mask"], self.size, F.InterpolationMode.NEAREST)
if "embedding" in container:
container["embedding"] = F.resize(
container["embedding"], self.size, self.interpolation)
if "mask" in container:
container["mask"] = F.resize(
container["mask"], self.size, F.InterpolationMode.NEAREST)
if "E_mask" in container:
container["E_mask"] = F.resize(
container["E_mask"], self.size, F.InterpolationMode.NEAREST)
if "maskrcnn_mask" in container:
container["maskrcnn_mask"] = F.resize(
container["maskrcnn_mask"], self.size, F.InterpolationMode.NEAREST)
if "vertices" in container:
container["vertices"] = F.resize(
container["vertices"], self.size, F.InterpolationMode.NEAREST)
return container
    def __repr__(self):
        repr_str = super().__repr__()
        vars_ = dict(size=self.size, interpolation=self.interpolation)
        return repr_str + " " + " ".join(f"{k}: {v}" for k, v in vars_.items())
class Normalize(torch.nn.Module):
"""
Performs the transform on the image.
NOTE: Does not transform the mask to improve runtime.
"""
def __init__(self, mean, std, inplace, keys=["img"]):
super().__init__()
self.mean = mean
self.std = std
self.inplace = inplace
self.keys = keys
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
for key in self.keys:
container[key] = F.normalize(container[key], self.mean, self.std, self.inplace)
return container
    def __repr__(self):
        repr_str = super().__repr__()
        vars_ = dict(mean=self.mean, std=self.std, inplace=self.inplace)
        return repr_str + " " + " ".join(f"{k}: {v}" for k, v in vars_.items())
class ToFloat(torch.nn.Module):
def __init__(self, keys=["img"], norm=True) -> None:
super().__init__()
self.keys = keys
self.gain = 255 if norm else 1
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
for key in self.keys:
container[key] = container[key].float() / self.gain
return container
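
# Example usage (an illustrative sketch):
#   to_float = ToFloat(keys=["img"])
#   batch = to_float({"img": torch.zeros(3, 288, 160, dtype=torch.uint8)})
#   # uint8 in [0, 255] -> float32 in [0, 1]; with norm=False only the dtype changes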
class RandomCrop(torchvision.transforms.RandomCrop):
"""
Performs the transform on the image.
NOTE: Does not transform the mask to improve runtime.
"""
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
container["img"] = super().forward(container["img"])
return container
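
# Example usage (an illustrative sketch; sizes are assumptions):
#   crop = RandomCrop(size=(288, 160))
#   batch = crop({"img": torch.zeros(3, 400, 300)})  # only "img" is cropped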
class CreateCondition(torch.nn.Module):
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
if container["img"].dtype == torch.uint8:
container["condition"] = container["img"] * container["mask"].byte() + (1-container["mask"].byte()) * 127
return container
container["condition"] = container["img"] * container["mask"]
return container
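
# Example usage (an illustrative sketch; shapes are assumptions):
#   cond = CreateCondition()
#   batch = {"img": torch.zeros(3, 288, 160, dtype=torch.uint8),
#            "mask": torch.ones(1, 288, 160)}
#   batch = cond(batch)  # masked-out pixels (mask == 0) are filled with gray (127)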
class CreateEmbedding(torch.nn.Module):
def __init__(self, embed_path: Path, cuda=True) -> None:
super().__init__()
self.embed_map = torch.load(embed_path, map_location=torch.device("cpu"))
if cuda:
self.embed_map = tops.to_cuda(self.embed_map)
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
vertices = container["vertices"]
        if vertices.ndim == 3:
            embedding = self.embed_map[vertices.long()].squeeze(dim=0)
            embedding = embedding.permute(2, 0, 1) * container["E_mask"]
else:
assert vertices.ndim == 4
embedding = self.embed_map[vertices.long()].squeeze(dim=1)
embedding = embedding.permute(0, 3, 1, 2) * container["E_mask"]
container["embedding"] = embedding
container["embed_map"] = self.embed_map.clone()
return container
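
# Example usage (an illustrative sketch; the embedding file name and shapes
# are assumptions):
#   embed = CreateEmbedding(embed_path=Path("embed_map.torch"), cuda=False)
#   batch = {"vertices": torch.zeros(1, 288, 160, dtype=torch.long),
#            "E_mask": torch.ones(1, 288, 160)}
#   batch = embed(batch)  # adds "embedding" (C, H, W) and a clone of "embed_map"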
class InsertJointMap(torch.nn.Module):
    def __init__(self, imsize: Tuple[int, int]) -> None:
super().__init__()
self.imsize = imsize
knames = get_coco_keypoints()[0]
knames = knames + ["neck", "mid_hip"]
connectivity = {
"nose": ["left_eye", "right_eye", "neck"],
"left_eye": ["right_eye", "left_ear"],
"right_eye": ["right_ear"],
"left_shoulder": ["right_shoulder", "left_elbow", "left_hip"],
"right_shoulder": ["right_elbow", "right_hip"],
"left_elbow": ["left_wrist"],
"right_elbow": ["right_wrist"],
"left_hip": ["right_hip", "left_knee"],
"right_hip": ["right_knee"],
"left_knee": ["left_ankle"],
"right_knee": ["right_ankle"],
"neck": ["mid_hip", "nose"],
}
category = {
("nose", "left_eye"): 0, # head
("nose", "right_eye"): 0, # head
("nose", "neck"): 0, # head
("left_eye", "right_eye"): 0, # head
("left_eye", "left_ear"): 0, # head
("right_eye", "right_ear"): 0, # head
("left_shoulder", "left_elbow"): 1, # left arm
("left_elbow", "left_wrist"): 1, # left arm
("right_shoulder", "right_elbow"): 2, # right arm
("right_elbow", "right_wrist"): 2, # right arm
("left_shoulder", "right_shoulder"): 3, # body
("left_shoulder", "left_hip"): 3, # body
("right_shoulder", "right_hip"): 3, # body
("left_hip", "right_hip"): 3, # body
("left_hip", "left_knee"): 4, # left leg
("left_knee", "left_ankle"): 4, # left leg
("right_hip", "right_knee"): 5, # right leg
("right_knee", "right_ankle"): 5, # right leg
("neck", "mid_hip"): 3, # body
("neck", "nose"): 0, # head
}
self.indices2category = {
tuple([knames.index(n) for n in k]): v for k, v in category.items()
}
self.connectivity_indices = {
knames.index(k): [knames.index(v_) for v_ in v]
for k, v in connectivity.items()
}
self.l_shoulder = knames.index("left_shoulder")
self.r_shoulder = knames.index("right_shoulder")
self.l_hip = knames.index("left_hip")
self.r_hip = knames.index("right_hip")
self.l_eye = knames.index("left_eye")
self.r_eye = knames.index("right_eye")
self.nose = knames.index("nose")
self.neck = knames.index("neck")
    def create_joint_map(self, N, H, W, keypoints_batch):
        joint_maps = np.zeros((N, H, W), dtype=np.uint8)
        for bidx, keypoints in enumerate(keypoints_batch):
assert keypoints.shape == (17, 3), keypoints.shape
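            # append two placeholder rows for the derived "neck" and "mid_hip" joints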
keypoints = torch.cat((keypoints, torch.zeros(2, 3)))
visible = keypoints[:, -1] > 0
if visible[self.l_shoulder] and visible[self.r_shoulder]:
neck = (keypoints[self.l_shoulder]
+ (keypoints[self.r_shoulder] - keypoints[self.l_shoulder]) / 2)
keypoints[-2] = neck
visible[-2] = 1
if visible[self.l_hip] and visible[self.r_hip]:
mhip = (keypoints[self.l_hip]
+ (keypoints[self.r_hip] - keypoints[self.l_hip]) / 2
)
keypoints[-1] = mhip
visible[-1] = 1
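            # keypoint coordinates are assumed normalized to [0, 1]; scale to pixel units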
keypoints[:, 0] *= W
keypoints[:, 1] *= H
joint_map = Image.fromarray(np.zeros((H, W), dtype=np.uint8))
draw = ImageDraw.Draw(joint_map)
for fidx in self.connectivity_indices.keys():
for tidx in self.connectivity_indices[fidx]:
if visible[fidx] == 0 or visible[tidx] == 0:
continue
c = self.indices2category[(fidx, tidx)]
s = tuple(keypoints[fidx, :2].round().long().numpy().tolist())
e = tuple(keypoints[tidx, :2].round().long().numpy().tolist())
draw.line((s, e), width=1, fill=c + 1)
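            # if the nose is occluded but the neck is visible, approximate the head
            # line from the midpoint of the eyes to the neck instead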
if visible[self.nose] == 0 and visible[self.neck] == 1:
m_eye = (
keypoints[self.l_eye]
+ (keypoints[self.r_eye] - keypoints[self.l_eye]) / 2
)
s = tuple(m_eye[:2].round().long().numpy().tolist())
e = tuple(keypoints[self.neck, :2].round().long().numpy().tolist())
c = self.indices2category[(self.nose, self.neck)]
draw.line((s, e), width=1, fill=c + 1)
            joint_maps[bidx] = np.array(joint_map)
return joint_maps[:, None]
def forward(self, batch):
batch["joint_map"] = torch.from_numpy(self.create_joint_map(
batch["img"].shape[0], *self.imsize, batch["keypoints"]))
return batch
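
# Example pipeline (an illustrative sketch; the composition order, sizes, and
# input shapes are assumptions about typical use):
#   transform = torch.nn.Sequential(
#       ToFloat(keys=["img"]),
#       RandomHorizontalFlip(p=0.5),
#       Resize(size=[288, 160]),
#       CreateCondition(),
#       Normalize(mean=[0.5] * 3, std=[0.5] * 3, inplace=True),
#       InsertJointMap(imsize=(288, 160)),
#   )
#   # imgs: (N, 3, H, W) uint8, masks: (N, 1, H, W),
#   # keypoints: N tensors of shape (17, 3) with normalized x, y, visibility
#   batch = transform({"img": imgs, "mask": masks, "keypoints": keypoints})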