| | import logging |
| |
|
| | from typing import Any, Dict |
| |
|
| | import numpy as np |
| | import torch |
| | from diffusers.image_processor import VaeImageProcessor |
| | from PIL import Image |
| | from torch import nn |
| |
|
# Module-level logger named after this module (via __name__).
logger: logging.Logger = logging.getLogger(__name__)
| |
|
| |
|
class LeffaTransform(nn.Module):
    """Convert a batch of per-sample inputs into batched float tensors.

    ``forward`` reads the ``"src_image"``, ``"ref_image"``, ``"mask"`` and
    ``"densepose"`` entries of a batch dict, runs every sample through the
    matching ``VaeImageProcessor.preprocess`` / ``prepare_*`` step with the
    configured ``height`` and ``width``, and concatenates the per-sample
    results along dim 0.
    """

    def __init__(
        self,
        height: int = 1024,
        width: int = 768,
        dataset: str = "virtual_tryon",
    ):
        """Store the target size and build the two image processors.

        Args:
            height: Height passed to every ``preprocess`` / ``resize`` call.
            width: Width passed to every ``preprocess`` / ``resize`` call.
            dataset: ``"pose_transfer"`` makes ``forward`` resize densepose
                maps with nearest-neighbour sampling and normalise them with
                ``prepare_densepose``; any other value sends densepose maps
                through the same path as ordinary images.
        """
        super().__init__()

        self.height = height
        self.width = width
        self.dataset = dataset

        self.vae_processor = VaeImageProcessor(vae_scale_factor=8)
        # Masks are converted to grayscale and binarised, and are not
        # normalised (see the flags below).
        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=8,
            do_normalize=False,
            do_binarize=True,
            do_convert_grayscale=True,
        )

    def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Preprocess every sample of ``batch`` and stack the results.

        Args:
            batch: Mapping with indexable ``"src_image"``, ``"ref_image"``,
                ``"mask"`` and ``"densepose"`` entries of equal length.
                In ``"pose_transfer"`` mode each densepose entry must offer
                a PIL-style ``resize(size, resample)`` method.

        Returns:
            The same ``batch`` object, with those four entries replaced by
            tensors concatenated along dim 0. The dict is updated in place.
        """
        # Loop-invariant values, looked up once.
        pose_transfer = self.dataset == "pose_transfer"
        height, width = self.height, self.width

        src_images = []
        ref_images = []
        masks = []
        denseposes = []
        for i in range(len(batch["src_image"])):
            # Index all four entries first so a too-short entry raises
            # before any processing is done for this sample.
            src_image = batch["src_image"][i]
            ref_image = batch["ref_image"][i]
            mask = batch["mask"][i]
            densepose = batch["densepose"][i]

            # preprocess(...)[0] keeps only the first element of the result
            # (presumably dropping a leading batch axis of size 1 -- confirm
            # against the diffusers documentation); prepare_* adds a batch
            # axis back for 3-D tensors.
            src_image = self.vae_processor.preprocess(
                src_image, height, width)[0]
            ref_image = self.vae_processor.preprocess(
                ref_image, height, width)[0]
            mask = self.mask_processor.preprocess(mask, height, width)[0]

            if pose_transfer:
                # Nearest-neighbour resampling introduces no interpolated
                # values; PIL's resize expects the size as (width, height).
                densepose = densepose.resize((width, height), Image.NEAREST)
                densepose = self.prepare_densepose(densepose)
            else:
                densepose = self.vae_processor.preprocess(
                    densepose, height, width)[0]
                densepose = self.prepare_image(densepose)

            src_images.append(self.prepare_image(src_image))
            ref_images.append(self.prepare_image(ref_image))
            masks.append(self.prepare_mask(mask))
            denseposes.append(densepose)

        batch["src_image"] = torch.cat(src_images, dim=0)
        batch["ref_image"] = torch.cat(ref_images, dim=0)
        batch["mask"] = torch.cat(masks, dim=0)
        batch["densepose"] = torch.cat(denseposes, dim=0)

        return batch

    @staticmethod
    def _stack_as_nchw(images, what):
        """Stack PIL images or channel-last arrays into one ndarray.

        Args:
            images: A PIL image, an ndarray, or a non-empty list of either
                (the type of the first element decides how the list is
                handled). Arrays are expected to be (H, W, C).
            what: Name used in the error message.

        Returns:
            ndarray with a new leading axis per item, transposed with
            ``(0, 3, 1, 2)`` (channel-last to channel-first).

        Raises:
            TypeError: If ``images`` is none of the supported inputs,
                including an empty list.
        """
        if isinstance(images, (Image.Image, np.ndarray)):
            images = [images]

        frames = None
        if isinstance(images, list) and images:
            if isinstance(images[0], Image.Image):
                frames = [
                    np.array(img.convert("RGB"))[None, :] for img in images
                ]
            elif isinstance(images[0], np.ndarray):
                frames = [arr[None, :] for arr in images]

        if frames is None:
            raise TypeError(
                f"{what} must be a torch.Tensor, PIL image, numpy array, or "
                f"a non-empty list of PIL images / numpy arrays; "
                f"got {type(images).__name__}"
            )
        return np.concatenate(frames, axis=0).transpose(0, 3, 1, 2)

    @staticmethod
    def prepare_image(image):
        """Return ``image`` as a 4-D float32 tensor.

        Tensors are only given a batch axis (when 3-D) and cast to float32;
        their values are not rescaled. PIL images and (H, W, C) arrays are
        stacked, moved to channel-first layout, and mapped with
        ``x / 127.5 - 1.0``.

        Raises:
            TypeError: If ``image`` is not a supported input type.
        """
        if isinstance(image, torch.Tensor):
            if image.ndim == 3:
                image = image.unsqueeze(0)
            return image.to(dtype=torch.float32)

        stacked = LeffaTransform._stack_as_nchw(image, "image")
        return torch.from_numpy(stacked).to(dtype=torch.float32) / 127.5 - 1.0

    @staticmethod
    def prepare_mask(mask):
        """Return ``mask`` as a binarised 4-D tensor with one channel.

        Tensor inputs are reshaped as follows: 2-D gets batch and channel
        axes, 3-D with a leading size of 1 gets a batch axis, other 3-D
        input is treated as a batch and gets a channel axis. PIL images are
        converted to mode ``"L"`` and scaled to [0, 1]; arrays are used
        as-is. Values below 0.5 become 0 and values >= 0.5 become 1 (NaN
        becomes 0). The input dtype is kept for tensors and arrays.

        The input is never modified; a new tensor is returned.

        Raises:
            TypeError: If ``mask`` is not a supported input type.
        """
        if isinstance(mask, torch.Tensor):
            if mask.ndim == 2:
                mask = mask.unsqueeze(0).unsqueeze(0)
            elif mask.ndim == 3:
                if mask.shape[0] == 1:
                    mask = mask.unsqueeze(0)
                else:
                    mask = mask.unsqueeze(1)
            # unsqueeze returns a view, so binarise out of place to avoid
            # overwriting the caller's tensor.
            return (mask >= 0.5).to(mask.dtype)

        if isinstance(mask, (Image.Image, np.ndarray)):
            mask = [mask]

        if isinstance(mask, list) and mask and isinstance(mask[0], Image.Image):
            stacked = np.concatenate(
                [np.array(m.convert("L"))[None, None, :] for m in mask],
                axis=0,
            )
            stacked = stacked.astype(np.float32) / 255.0
        elif isinstance(mask, list) and mask and isinstance(mask[0], np.ndarray):
            stacked = np.concatenate(
                [m[None, None, :] for m in mask], axis=0)
        else:
            raise TypeError(
                "mask must be a torch.Tensor, PIL image, numpy array, or a "
                "non-empty list of PIL images / numpy arrays; "
                f"got {type(mask).__name__}"
            )

        binary = (stacked >= 0.5).astype(stacked.dtype)
        return torch.from_numpy(binary)

    @staticmethod
    def prepare_densepose(densepose):
        """Return ``densepose`` as a 4-D float32 tensor.

        For internal (meta) densepose maps given as PIL images or
        (H, W, C) arrays, the first and second channels are normalised to
        0~1 by dividing by 255.0 and the third channel by dividing by 24.0;
        the result is then mapped with ``x * 2.0 - 1.0``.

        Tensors are only given a batch axis (when 3-D) and cast to float32;
        their values are not rescaled.

        Raises:
            TypeError: If ``densepose`` is not a supported input type.
        """
        if isinstance(densepose, torch.Tensor):
            if densepose.ndim == 3:
                densepose = densepose.unsqueeze(0)
            return densepose.to(dtype=torch.float32)

        stacked = LeffaTransform._stack_as_nchw(densepose, "densepose")
        # astype copies, so the in-place divisions below cannot touch the
        # caller's array.
        stacked = stacked.astype(np.float32)
        stacked[:, 0:2, :, :] /= 255.0
        stacked[:, 2:3, :, :] /= 24.0
        return torch.from_numpy(stacked).to(dtype=torch.float32) * 2.0 - 1.0
| |
|