import logging
from typing import Any, Dict

import numpy as np
import torch
from diffusers.image_processor import VaeImageProcessor
from PIL import Image
from torch import nn

logger: logging.Logger = logging.getLogger(__name__)


class LeffaTransform(nn.Module):
    def __init__(
        self,
        height: int = 1024,
        width: int = 768,
        dataset: str = "virtual_tryon",  # "virtual_tryon" or "pose_transfer"
    ):
        super().__init__()
        self.height = height
        self.width = width
        self.dataset = dataset
        self.vae_processor = VaeImageProcessor(vae_scale_factor=8)
        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=8,
            do_normalize=False,
            do_binarize=True,
            do_convert_grayscale=True,
        )

    def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        batch_size = len(batch["src_image"])
        src_image_list = []
        ref_image_list = []
        mask_list = []
        densepose_list = []
        for i in range(batch_size):
            # 1. Fetch the raw per-sample data.
            src_image = batch["src_image"][i]
            ref_image = batch["ref_image"][i]
            mask = batch["mask"][i]
            densepose = batch["densepose"][i]

            # 2. Resize to the target resolution.
            src_image = self.vae_processor.preprocess(
                src_image, self.height, self.width)[0]
            ref_image = self.vae_processor.preprocess(
                ref_image, self.height, self.width)[0]
            mask = self.mask_processor.preprocess(
                mask, self.height, self.width)[0]
            if self.dataset == "pose_transfer":
                # Nearest-neighbor resizing keeps the discrete DensePose
                # labels intact.
                densepose = densepose.resize(
                    (self.width, self.height), Image.NEAREST)
            else:
                densepose = self.vae_processor.preprocess(
                    densepose, self.height, self.width)[0]

            # 3. Convert to normalized, batched tensors.
            src_image = self.prepare_image(src_image)
            ref_image = self.prepare_image(ref_image)
            mask = self.prepare_mask(mask)
            if self.dataset == "pose_transfer":
                densepose = self.prepare_densepose(densepose)
            else:
                densepose = self.prepare_image(densepose)

            src_image_list.append(src_image)
            ref_image_list.append(ref_image)
            mask_list.append(mask)
            densepose_list.append(densepose)

        batch["src_image"] = torch.cat(src_image_list, dim=0)
        batch["ref_image"] = torch.cat(ref_image_list, dim=0)
        batch["mask"] = torch.cat(mask_list, dim=0)
        batch["densepose"] = torch.cat(densepose_list, dim=0)
        return batch

    @staticmethod
    def prepare_image(image):
        if isinstance(image, torch.Tensor):
            # Batch a single image.
            if image.ndim == 3:
                image = image.unsqueeze(0)
            image = image.to(dtype=torch.float32)
        else:
            # Preprocess PIL images or numpy arrays.
            if isinstance(image, (Image.Image, np.ndarray)):
                image = [image]
            if isinstance(image, list) and isinstance(image[0], Image.Image):
                image = [np.array(i.convert("RGB"))[None, :] for i in image]
                image = np.concatenate(image, axis=0)
            elif isinstance(image, list) and isinstance(image[0], np.ndarray):
                image = np.concatenate([i[None, :] for i in image], axis=0)
            # NHWC -> NCHW, then scale uint8 [0, 255] to [-1, 1].
            image = image.transpose(0, 3, 1, 2)
            image = torch.from_numpy(image).to(
                dtype=torch.float32) / 127.5 - 1.0
        return image

    @staticmethod
    def prepare_mask(mask):
        if isinstance(mask, torch.Tensor):
            if mask.ndim == 2:
                # Add batch and channel dims for a single mask.
                mask = mask.unsqueeze(0).unsqueeze(0)
            elif mask.ndim == 3 and mask.shape[0] == 1:
                # Single mask: the 0th dimension is considered to be the
                # existing batch size of 1.
                mask = mask.unsqueeze(0)
            elif mask.ndim == 3 and mask.shape[0] != 1:
                # Batch of masks: the 0th dimension is considered to be the
                # batching dimension, so add a channel dimension.
                mask = mask.unsqueeze(1)
            # Binarize the mask.
            mask[mask < 0.5] = 0
            mask[mask >= 0.5] = 1
        else:
            # Preprocess PIL masks or numpy arrays.
            if isinstance(mask, (Image.Image, np.ndarray)):
                mask = [mask]
            if isinstance(mask, list) and isinstance(mask[0], Image.Image):
                mask = np.concatenate(
                    [np.array(m.convert("L"))[None, None, :] for m in mask],
                    axis=0,
                )
                mask = mask.astype(np.float32) / 255.0
            elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
                # numpy masks are expected to already be in [0, 1].
                mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
            # Binarize the mask.
            mask[mask < 0.5] = 0
            mask[mask >= 0.5] = 1
            mask = torch.from_numpy(mask)
        return mask

    @staticmethod
    def prepare_densepose(densepose):
        """
        For the internal (Meta) DensePose format, the first and second
        channels are normalized to [0, 1] by dividing by 255.0, and the
        third channel is normalized to [0, 1] by dividing by 24.0.
        """
        if isinstance(densepose, torch.Tensor):
            # Batch a single densepose map.
            if densepose.ndim == 3:
                densepose = densepose.unsqueeze(0)
            densepose = densepose.to(dtype=torch.float32)
        else:
            # Preprocess PIL images or numpy arrays.
            if isinstance(densepose, (Image.Image, np.ndarray)):
                densepose = [densepose]
            if isinstance(densepose, list) and isinstance(
                densepose[0], Image.Image
            ):
                densepose = [np.array(i.convert("RGB"))[None, :]
                             for i in densepose]
                densepose = np.concatenate(densepose, axis=0)
            elif isinstance(densepose, list) and isinstance(
                densepose[0], np.ndarray
            ):
                densepose = np.concatenate(
                    [i[None, :] for i in densepose], axis=0)
            # NHWC -> NCHW, channel-wise normalization to [0, 1],
            # then rescale to [-1, 1].
            densepose = densepose.transpose(0, 3, 1, 2)
            densepose = densepose.astype(np.float32)
            densepose[:, 0:2, :, :] /= 255.0
            densepose[:, 2:3, :, :] /= 24.0
            densepose = torch.from_numpy(densepose).to(
                dtype=torch.float32) * 2.0 - 1.0
        return densepose
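

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): shows how a
# batch of PIL inputs might be passed through LeffaTransform. The dict keys
# ("src_image", "ref_image", "mask", "densepose") come from forward() above;
# the blank Image.new() inputs are hypothetical stand-ins for real data.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    transform = LeffaTransform(height=1024, width=768, dataset="virtual_tryon")
    src = Image.new("RGB", (768, 1024))        # stand-in for a person image
    ref = Image.new("RGB", (768, 1024))        # stand-in for a garment image
    msk = Image.new("L", (768, 1024), 255)     # stand-in for an inpainting mask
    dp = Image.new("RGB", (768, 1024))         # stand-in for a densepose map
    batch = {
        "src_image": [src],
        "ref_image": [ref],
        "mask": [msk],
        "densepose": [dp],
    }
    batch = transform(batch)
    # Each entry is now a float32 tensor of shape (1, C, 1024, 768); images
    # and densepose are normalized to [-1, 1], the mask is binarized to {0, 1}.
    print({k: tuple(v.shape) for k, v in batch.items()})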