# newTryOn/utils/image_utils.py
import os
import subprocess
import tempfile
from pathlib import Path
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
from models.Net import get_segmentation


def equal_replacer(images: list[torch.Tensor]) -> list[torch.Tensor]:
    """Normalize uint8 images to floats in [0, 1] and make identical images share one tensor."""
    # Convert integer images to float in [0, 1].
    for i in range(len(images)):
        if images[i].dtype == torch.uint8:
            images[i] = images[i] / 255
    # Deduplicate: numerically identical images end up pointing at the same tensor.
    for i in range(len(images)):
        for j in range(i + 1, len(images)):
            if torch.allclose(images[i], images[j]):
                images[j] = images[i]
    return images
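

# Illustrative sketch (not part of the original module): how equal_replacer is
# expected to behave on a small list of dummy uint8 tensors.
def _equal_replacer_example():
    imgs = [torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8) for _ in range(2)]
    imgs.append(imgs[0].clone())        # exact duplicate of the first image
    imgs = equal_replacer(imgs)
    assert imgs[0].max() <= 1           # values rescaled to [0, 1]
    assert imgs[2] is imgs[0]           # the duplicate now shares the first tensor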


class DilateErosion:
    """Dilate and erode binary masks by repeated convolution with a 3x3 cross-shaped kernel."""

    def __init__(self, dilate_erosion=5, device='cuda'):
        self.dilate_erosion = dilate_erosion
        # Cross-shaped structuring element, shaped (out_channels, in_channels, 3, 3).
        self.weight = torch.Tensor([
            [False, True, False],
            [True, True, True],
            [False, True, False]
        ]).float()[None, None, ...].to(device)

    def hair_from_mask(self, mask):
        # Keep only the hair class (label 13 in this project's segmentation maps).
        mask = torch.where(mask == 13, torch.ones_like(mask), torch.zeros_like(mask))
        mask = F.interpolate(mask, size=(256, 256), mode='nearest')
        dilate, erosion = self.mask(mask)
        return dilate, erosion

    def mask(self, mask):
        # Stack two copies along the batch dimension so dilation (first half) and
        # erosion (second half) share the same convolution passes.
        masks = mask.clone().repeat(*([2] + [1] * (len(mask.shape) - 1))).float()
        sum_w = self.weight.sum().item()
        n = len(mask)
        for _ in range(self.dilate_erosion):
            masks = F.conv2d(masks, self.weight,
                             bias=None, stride=1, padding='same', dilation=1, groups=1)
            masks[:n] = (masks[:n] > 0).float()        # dilation: on if any cross neighbour is on
            masks[n:] = (masks[n:] == sum_w).float()   # erosion: on only if the whole cross is on
        hair_mask_dilate, hair_mask_erode = masks[:n], masks[n:]
        return hair_mask_dilate, hair_mask_erode
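

# Illustrative sketch (not part of the original module): one dilation/erosion step
# on a tiny synthetic mask. Running on CPU is an assumption; the class defaults to 'cuda'.
def _dilate_erosion_example():
    de = DilateErosion(dilate_erosion=1, device='cpu')
    mask = torch.zeros(1, 1, 7, 7)
    mask[:, :, 3, 3] = 1.0                  # single foreground pixel
    dilated, eroded = de.mask(mask)
    assert dilated.sum() == 5               # the pixel grows into a 3x3 cross
    assert eroded.sum() == 0                # a lone pixel is eroded away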


def poisson_image_blending(final_image, face_image, dilate_erosion=30, maxn=115):
    """Blend the non-hair region of face_image into final_image with Poisson editing (fpie CLI)."""
    dilate_erosion = DilateErosion(dilate_erosion=dilate_erosion)
    transform = transforms.ToTensor()
    if isinstance(face_image, str):
        face_image = transform(Image.open(face_image))
    elif not isinstance(face_image, torch.Tensor):
        face_image = transform(face_image)

    # Segment both images and build a mask of everything that is not hair (label 13) in either.
    final_mask = get_segmentation(final_image.cuda().unsqueeze(0), resize=False)
    face_mask = get_segmentation(face_image.cuda().unsqueeze(0), resize=False)
    hair_target = torch.where(final_mask == 13, torch.ones_like(final_mask),
                              torch.zeros_like(final_mask))
    hair_face = torch.where(face_mask == 13, torch.ones_like(face_mask),
                            torch.zeros_like(face_mask))
    final_mask = F.interpolate(((1 - hair_target) * (1 - hair_face)).float(), size=(1024, 1024), mode='bicubic')

    # Dilate the hair region, then invert it to get the region fpie is allowed to blend.
    dilation, _ = dilate_erosion.mask(1 - final_mask)
    mask_save = 1 - dilation[0]

    with tempfile.TemporaryDirectory() as temp_dir:
        final_image_path = os.path.join(temp_dir, 'final_image.png')
        face_image_path = os.path.join(temp_dir, 'face_image.png')
        mask_path = os.path.join(temp_dir, 'mask_save.png')
        save_image(final_image, final_image_path)
        save_image(face_image, face_image_path)
        save_image(mask_save, mask_path)

        out_image_path = os.path.join(temp_dir, 'out_image.png')
        subprocess.run(
            ["fpie", "-s", face_image_path, "-m", mask_path, "-t", final_image_path, "-o", out_image_path, "-n",
             str(maxn), "-b", "taichi-gpu", "-g", "max"],
            check=True
        )

        # PIL reads lazily, so force pixel data into memory before the temp dir is removed.
        out_image = Image.open(out_image_path)
        out_image.load()
        mask_image = Image.open(mask_path)
        mask_image.load()
    return out_image, mask_image
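

# Illustrative sketch (not part of the original module): blending a generated
# result with an original face photo. Requires the fpie CLI on PATH, a CUDA
# device, and the segmentation model from models.Net; the file paths below are
# placeholders.
def _poisson_blending_example():
    to_tensor = transforms.ToTensor()
    final = to_tensor(Image.open('results/final_1024.png'))    # hypothetical path
    blended, used_mask = poisson_image_blending(final, 'inputs/face_1024.png')
    blended.save('results/blended.png')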


def list_image_files(directory):
    """Return the sorted file names of .jpg/.jpeg/.png images directly inside directory."""
    image_extensions = ['.jpg', '.jpeg', '.png']
    image_files = []
    for entry in sorted(os.listdir(directory)):
        file_path = os.path.join(directory, entry)
        if os.path.isfile(file_path):
            file_extension = Path(file_path).suffix.lower()
            if file_extension in image_extensions:
                image_files.append(entry)
    return image_files
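

# Illustrative sketch (not part of the original module): loading every image in a
# folder as tensors and deduplicating them. The 'inputs' directory is a placeholder.
def _list_image_files_example():
    to_tensor = transforms.ToTensor()
    images = [to_tensor(Image.open(os.path.join('inputs', name)))
              for name in list_image_files('inputs')]
    images = equal_replacer(images)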