import math
import os
from typing import List, Optional, Tuple

import accelerate
import numpy as np
import PIL
import torch
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.state import AcceleratorState
from accelerate.utils import ProjectConfiguration
from diffusers import UNet2DConditionModel, SchedulerMixin
from packaging import version
from PIL import Image
from torch.nn import functional as F


def compute_dream_and_update_latents_for_inpaint(
    unet: UNet2DConditionModel,
    noise_scheduler: SchedulerMixin,
    timesteps: torch.Tensor,
    noise: torch.Tensor,
    noisy_latents: torch.Tensor,
    target: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    dream_detail_preservation: float = 1.0,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from
    http://arxiv.org/abs/2312.00210. DREAM aligns training with sampling, making training
    more efficient and accurate at the cost of one extra forward pass without gradients.

    Args:
        `unet`: The unet used to make the prediction.
        `noise_scheduler`: The noise scheduler used to add noise for the given timestep.
        `timesteps`: The timesteps for the noise_scheduler to use.
        `noise`: A tensor of noise matching the shape of the first four (non-condition) latent channels.
        `noisy_latents`: Previously noised latents from the training loop.
        `target`: The ground-truth tensor to predict after eps is removed.
        `encoder_hidden_states`: Text embeddings from the text model.
        `dream_detail_preservation`: A float value that indicates the detail preservation level.
            See the reference paper.

    Returns:
        `tuple[torch.Tensor, torch.Tensor]`: Adjusted noisy_latents and target.
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps, None, None, None]
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # The paper uses lambda = sqrt(1 - alpha_cumprod) ** p, with p = 1 in their experiments.
    dream_lambda = sqrt_one_minus_alphas_cumprod**dream_detail_preservation

    pred = None
    with torch.no_grad():
        pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    # Only the first 4 channels are the actual latents; the rest are inpainting conditions.
    noisy_latents_no_condition = noisy_latents[:, :4]
    _noisy_latents, _target = (None, None)
    if noise_scheduler.config.prediction_type == "epsilon":
        predicted_noise = pred
        delta_noise = (noise - predicted_noise).detach()
        delta_noise.mul_(dream_lambda)
        _noisy_latents = noisy_latents_no_condition.add(sqrt_one_minus_alphas_cumprod * delta_noise)
        _target = target.add(delta_noise)
    elif noise_scheduler.config.prediction_type == "v_prediction":
        raise NotImplementedError("DREAM has not been implemented for v-prediction")
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    # Re-attach the condition channels before returning.
    _noisy_latents = torch.cat([_noisy_latents, noisy_latents[:, 4:]], dim=1)
    return _noisy_latents, _target


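# Usage sketch (illustrative, not called in this module): mirrors how the helper above is
# expected to be wired into a training step. `unet`, `noise_scheduler`, `latents`,
# `mask_and_condition`, and `encoder_hidden_states` are assumed to come from the
# surrounding training loop; shapes and the epsilon-prediction target are placeholders.
def _example_dream_update(unet, noise_scheduler, latents, mask_and_condition, encoder_hidden_states):
    noise = torch.randn_like(latents)
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps, (latents.size(0),), device=latents.device
    )
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    # Inpainting UNets take the latents concatenated with mask / condition channels.
    unet_input = torch.cat([noisy_latents, mask_and_condition], dim=1)
    target = noise  # epsilon prediction
    return compute_dream_and_update_latents_for_inpaint(
        unet, noise_scheduler, timesteps, noise, unet_input, target,
        encoder_hidden_states, dream_detail_preservation=1.0,
    )

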
def prepare_inpainting_input(
    noisy_latents: torch.Tensor,
    mask_latents: torch.Tensor,
    condition_latents: torch.Tensor,
    enable_condition_noise: bool = True,
    condition_concat_dim: int = -1,
) -> torch.Tensor:
    """
    Prepare the input for the inpainting model.

    Args:
        noisy_latents (torch.Tensor): Noisy latents.
        mask_latents (torch.Tensor): Mask latents.
        condition_latents (torch.Tensor): Condition latents.
        enable_condition_noise (bool): Whether noise is applied to the condition latents.
        condition_concat_dim (int): Dimension along which the condition latents are concatenated.

    Returns:
        torch.Tensor: Inpainting input.
    """
    if not enable_condition_noise:
        condition_latents_ = condition_latents.chunk(2, dim=condition_concat_dim)[-1]
        noisy_latents = torch.cat([noisy_latents, condition_latents_], dim=condition_concat_dim)
    noisy_latents = torch.cat([noisy_latents, mask_latents, condition_latents], dim=1)
    return noisy_latents


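# Minimal shape check (illustrative): with condition noise enabled, the three latent
# tensors are simply stacked along the channel dimension, producing a 4 + 1 + 4 channel
# input for the inpainting UNet. The spatial sizes below are arbitrary.
def _example_prepare_inpainting_input():
    noisy_latents = torch.randn(2, 4, 64, 96)
    mask_latents = torch.randn(2, 1, 64, 96)
    condition_latents = torch.randn(2, 4, 64, 96)
    model_input = prepare_inpainting_input(noisy_latents, mask_latents, condition_latents)
    assert model_input.shape == (2, 9, 64, 96)
    return model_input

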
def compute_vae_encodings(image: torch.Tensor, vae: torch.nn.Module) -> torch.Tensor:
    """
    Encode an image batch into scaled VAE latents.

    Args:
        image (torch.Tensor): Image to be encoded.
        vae (torch.nn.Module): VAE model.

    Returns:
        torch.Tensor: Latent encoding of the image.
    """
    pixel_values = image.to(memory_format=torch.contiguous_format).float()
    pixel_values = pixel_values.to(vae.device, dtype=vae.dtype)
    with torch.no_grad():
        model_input = vae.encode(pixel_values).latent_dist.sample()
    model_input = model_input * vae.config.scaling_factor
    return model_input


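# Usage sketch (illustrative): `vae` is assumed to be a diffusers `AutoencoderKL`
# (or any module exposing `.encode(...).latent_dist` and `config.scaling_factor`),
# and `pil_image` any RGB image; `prepare_image` below converts it to a [-1, 1] batch.
def _example_vae_encode(vae, pil_image):
    image = prepare_image(pil_image)  # (1, 3, H, W), float32 in [-1, 1]
    latents = compute_vae_encodings(image, vae)
    return latents

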
def init_accelerator(config):
    accelerator_project_config = ProjectConfiguration(
        project_dir=config.project_name,
        logging_dir=os.path.join(config.project_name, "logs"),
    )
    accelerator_ddp_config = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        log_with=config.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[accelerator_ddp_config],
        gradient_accumulation_steps=config.gradient_accumulation_steps,
    )

    # Native AMP is not supported on MPS, so disable it there.
    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    if accelerator.is_main_process:
        accelerator.init_trackers(
            project_name=config.project_name,
            config={
                "learning_rate": config.learning_rate,
                "train_batch_size": config.train_batch_size,
                "image_size": f"{config.width}x{config.height}",
            },
        )

    return accelerator


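# Illustrative config sketch: the fields below are exactly the attributes that
# `init_accelerator` reads from `config`; a SimpleNamespace stands in for whatever
# argument/config object the training script actually passes, and the values are placeholders.
def _example_init_accelerator():
    from types import SimpleNamespace

    config = SimpleNamespace(
        project_name="example-project",
        mixed_precision="no",
        report_to=None,  # no experiment tracker
        gradient_accumulation_steps=1,
        learning_rate=1e-5,
        train_batch_size=2,
        width=384,
        height=512,
    )
    return init_accelerator(config)

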
def init_weight_dtype(weight_dtype):
    return {
        "no": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }[weight_dtype]


def init_add_item_id(config):
    return torch.tensor(
        [
            config.height,
            config.width * 2,
            0,
            0,
            config.height,
            config.width * 2,
        ]
    ).repeat(config.train_batch_size, 1)


def prepare_eval_data(dataset_root, dataset_name, is_pair=True):
    assert dataset_name in ["vitonhd", "dresscode", "farfetch"], "Unknown dataset name {}.".format(dataset_name)
    if dataset_name == "vitonhd":
        data_root = os.path.join(dataset_root, "VITONHD-1024", "test")
        if is_pair:
            keys = os.listdir(os.path.join(data_root, "Images"))
            cloth_image_paths = [
                os.path.join(data_root, "Images", key, key + "-0.jpg") for key in keys
            ]
            person_image_paths = [
                os.path.join(data_root, "Images", key, key + "-1.jpg") for key in keys
            ]
        else:
            cloth_image_paths = []
            person_image_paths = []
            with open(
                os.path.join(dataset_root, "VITONHD-1024", "test_pairs.txt"), "r"
            ) as f:
                lines = f.readlines()
                for line in lines:
                    cloth_image, person_image = (
                        line.replace(".jpg", "").strip().split(" ")
                    )
                    cloth_image_paths.append(
                        os.path.join(
                            data_root, "Images", cloth_image, cloth_image + "-0.jpg"
                        )
                    )
                    person_image_paths.append(
                        os.path.join(
                            data_root, "Images", person_image, person_image + "-1.jpg"
                        )
                    )
    elif dataset_name == "dresscode":
        data_root = os.path.join(dataset_root, "DressCode-1024")
        if is_pair:
            part = ["lower", "lower", "upper", "upper", "dresses", "dresses"]
            ids = ["013581", "051685", "000190", "050072", "020829", "053742"]
            cloth_image_paths = [
                os.path.join(data_root, "Images", part[i], ids[i], ids[i] + "_1.jpg")
                for i in range(len(part))
            ]
            person_image_paths = [
                os.path.join(data_root, "Images", part[i], ids[i], ids[i] + "_0.jpg")
                for i in range(len(part))
            ]
        else:
            raise ValueError("DressCode dataset does not support non-pair evaluation.")
    elif dataset_name == "farfetch":
        data_root = os.path.join(dataset_root, "FARFETCH-1024")
        # Note: os.path.join ignores `data_root` for the absolute entries below;
        # only the relative entries are resolved against `data_root`.
        cloth_image_paths = [
            "/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/women/Tops/Blouses/13732751/13732751-2.jpg",
            "/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/women/Tops/Hoodies/14661627/14661627-4.jpg",
            "/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/women/Tops/Vests & Tank Tops/16532697/16532697-4.jpg",
            "Images/men/Pants/Loose Fit Pants/14750720/14750720-6.jpg",
            "/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/women/Tops/Shirts/10889688/10889688-3.jpg",
            "/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/women/Shorts/Leather & Faux Leather Shorts/20143338/20143338-1.jpg",
            "/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/women/Jackets/Blazers/15541224/15541224-2.jpg",
            "/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/men/Polo Shirts/Polo Shirts/17652415/17652415-0.jpg",
        ]
        person_image_paths = [
            "/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/women/Tops/Blouses/13732751/13732751-0.jpg",
            "/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/women/Tops/Hoodies/14661627/14661627-2.jpg",
            "/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/women/Tops/Vests & Tank Tops/16532697/16532697-1.jpg",
            "Images/men/Pants/Loose Fit Pants/14750720/14750720-5.jpg",
            "/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/women/Tops/Shirts/10889688/10889688-1.jpg",
            "/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/women/Shorts/Leather & Faux Leather Shorts/20143338/20143338-2.jpg",
            "/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/women/Jackets/Blazers/15541224/15541224-0.jpg",
            "/home/chongzheng/Projects/hivton/Datasets/FARFETCH-1024/Images/men/Polo Shirts/Polo Shirts/17652415/17652415-4.jpg",
        ]
        cloth_image_paths = [
            os.path.join(data_root, path) for path in cloth_image_paths
        ]
        person_image_paths = [
            os.path.join(data_root, path) for path in person_image_paths
        ]
    else:
        raise ValueError(f"Unknown dataset name: {dataset_name}")

    samples = [
        {
            "folder": os.path.basename(os.path.dirname(cloth_image)),
            "cloth": cloth_image,
            "person": person_image,
        }
        for cloth_image, person_image in zip(
            cloth_image_paths, person_image_paths
        )
    ]
    return samples


def repaint_result(result, person_image, mask_image):
    """Composite the generated result onto the person image using the (single-channel) mask."""
    result, person, mask = np.array(result), np.array(person_image), np.array(mask_image)
    # Expand the mask to (H, W, 1) and normalize to [0, 1].
    mask = np.expand_dims(mask, axis=2)
    mask = mask / 255.0
    # Keep the generated pixels inside the mask and the original person pixels outside it.
    result_ = result * mask + person * (1 - mask)
    return Image.fromarray(result_.astype(np.uint8))


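# Runnable toy example: a white "result" composited onto a black "person" image inside a
# square mask. Purely illustrative of the expected input types (PIL images or arrays).
def _example_repaint_result():
    person = Image.new("RGB", (64, 64), (0, 0, 0))
    result = Image.new("RGB", (64, 64), (255, 255, 255))
    mask = np.zeros((64, 64), dtype=np.uint8)
    mask[16:48, 16:48] = 255
    return repaint_result(result, person, Image.fromarray(mask))

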
def sobel(batch_image, mask=None, scale=4.0):
    """
    Compute the Sobel gradient magnitude of a batch of images.

    batch_image: input image tensor of shape [batch, channels, height, width].
    mask: optional mask multiplied into the gradient magnitude.
    scale: maximum value of the normalized gradient map.
    """
    w, h = batch_image.size(3), batch_image.size(2)
    # Smoothing kernel size proportional to the image resolution (always odd).
    pool_kernel = (max(w, h) // 16) * 2 + 1

    kernel_x = (
        torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32)
        .view(1, 1, 3, 3)
        .to(batch_image.device)
        .repeat(1, batch_image.size(1), 1, 1)
    )
    kernel_y = (
        torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32)
        .view(1, 1, 3, 3)
        .to(batch_image.device)
        .repeat(1, batch_image.size(1), 1, 1)
    )

    batch_image = F.pad(batch_image, (1, 1, 1, 1), mode="reflect")

    grad_x = F.conv2d(batch_image, kernel_x, padding=0)
    grad_y = F.conv2d(batch_image, kernel_y, padding=0)

    grad_magnitude = torch.sqrt(grad_x.pow(2) + grad_y.pow(2))

    if mask is not None:
        grad_magnitude = grad_magnitude * mask

    # Clamp, smooth, and normalize the gradient map to at most `scale`.
    grad_magnitude = torch.clamp(grad_magnitude, 0.2, 2.5)
    grad_magnitude = F.avg_pool2d(
        grad_magnitude, kernel_size=pool_kernel, stride=1, padding=pool_kernel // 2
    )
    grad_magnitude = (grad_magnitude / grad_magnitude.max()) * scale
    return grad_magnitude


def sobel_aug_squared_error(x, y, reference, mask=None, reduction="mean"):
    """
    Compute the element-wise squared error between x and y, weighted by the Sobel
    gradient magnitude of `reference` (optionally restricted to `mask`).

    x: Tensor, shape [batch, channels, height, width]
    y: Tensor, shape [batch, channels, height, width]
    """
    ref_sobel = sobel(reference, mask=mask)
    if ref_sobel.isnan().any():
        # Fall back to a plain MSE loss if the Sobel weights are invalid.
        print("Error: NaN Sobel Gradient")
        loss = F.mse_loss(x, y, reduction=reduction)
    else:
        squared_error = (x - y).pow(2)
        weighted_squared_error = squared_error * ref_sobel
        if reduction == "mean":
            loss = weighted_squared_error.mean()
        elif reduction == "sum":
            loss = weighted_squared_error.sum()
        elif reduction == "none":
            loss = weighted_squared_error
        else:
            raise ValueError(f"Unknown reduction: {reduction}")

    return loss


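# Runnable toy example: the Sobel-weighted loss emphasizes errors near edges of the
# reference image. Random tensors stand in for predictions and ground truth.
def _example_sobel_weighted_loss():
    x = torch.rand(1, 3, 64, 64)
    y = torch.rand(1, 3, 64, 64)
    reference = torch.rand(1, 3, 64, 64)
    weights = sobel(reference)  # (1, 1, 64, 64) weight map, scaled to at most `scale`
    loss = sobel_aug_squared_error(x, y, reference)
    return weights, loss

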
def prepare_image(image):
    if isinstance(image, torch.Tensor):
        # Batch single images and make sure we work with float tensors.
        if image.ndim == 3:
            image = image.unsqueeze(0)
        image = image.to(dtype=torch.float32)
    else:
        # Preprocess a PIL image or numpy array (or a list of them).
        if isinstance(image, (PIL.Image.Image, np.ndarray)):
            image = [image]
        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
            image = [np.array(i.convert("RGB"))[None, :] for i in image]
            image = np.concatenate(image, axis=0)
        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
            image = np.concatenate([i[None, :] for i in image], axis=0)
        # HWC -> CHW, then scale from [0, 255] to [-1, 1].
        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
    return image


def prepare_mask_image(mask_image):
    if isinstance(mask_image, torch.Tensor):
        if mask_image.ndim == 2:
            # Single mask without batch or channel dimension.
            mask_image = mask_image.unsqueeze(0).unsqueeze(0)
        elif mask_image.ndim == 3 and mask_image.shape[0] == 1:
            # (1, H, W): treat it as a single mask and add the batch dimension.
            mask_image = mask_image.unsqueeze(0)
        elif mask_image.ndim == 3 and mask_image.shape[0] != 1:
            # Batch of masks without channel dimension.
            mask_image = mask_image.unsqueeze(1)

        # Binarize.
        mask_image[mask_image < 0.5] = 0
        mask_image[mask_image >= 0.5] = 1
    else:
        # Preprocess a PIL mask or numpy array (or a list of them).
        if isinstance(mask_image, (PIL.Image.Image, np.ndarray)):
            mask_image = [mask_image]

        if isinstance(mask_image, list) and isinstance(mask_image[0], PIL.Image.Image):
            mask_image = np.concatenate(
                [np.array(m.convert("L"))[None, None, :] for m in mask_image], axis=0
            )
            mask_image = mask_image.astype(np.float32) / 255.0
        elif isinstance(mask_image, list) and isinstance(mask_image[0], np.ndarray):
            mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0)

        # Binarize and convert to a torch tensor.
        mask_image[mask_image < 0.5] = 0
        mask_image[mask_image >= 0.5] = 1
        mask_image = torch.from_numpy(mask_image)

    return mask_image


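# Runnable toy example: shows the shapes and value ranges produced by the two
# preprocessing helpers for a synthetic PIL image and mask.
def _example_prepare_inputs():
    rgb = Image.new("RGB", (96, 128), (128, 64, 32))
    mask = Image.new("L", (96, 128), 255)
    image = prepare_image(rgb)              # (1, 3, 128, 96), float32 in [-1, 1]
    mask_tensor = prepare_mask_image(mask)  # (1, 1, 128, 96), float32 in {0, 1}
    return image, mask_tensor

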
def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # Special case for single-channel (grayscale) images.
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]

    return pil_images


def load_eval_image_pairs(root, mode="logo"):
    test_name = "test"
    person_image_paths = [
        os.path.join(root, test_name, "image", _)
        for _ in os.listdir(os.path.join(root, test_name, "image"))
        if _.endswith(".jpg")
    ]
    cloth_image_paths = [
        person_image_path.replace("image", "cloth")
        for person_image_path in person_image_paths
    ]

    if mode == "logo":
        filter_pairs = [
            6648, 6744, 6967, 6985, 14031, 12358, 4963, 4680, 499, 396, 345,
            6648, 6744, 6967, 6985, 7510, 8205, 8254, 10545, 11485, 11632,
            12354, 13144, 14112, 12570, 11766,
        ]
        filter_pairs.sort()
        filter_pairs = [f"{_:05d}_00.jpg" for _ in filter_pairs]
        cloth_image_paths = [
            cloth_image_paths[i]
            for i in range(len(cloth_image_paths))
            if os.path.basename(cloth_image_paths[i]) in filter_pairs
        ]
        person_image_paths = [
            person_image_paths[i]
            for i in range(len(person_image_paths))
            if os.path.basename(person_image_paths[i]) in filter_pairs
        ]
    return cloth_image_paths, person_image_paths


def tensor_to_image(tensor: torch.Tensor):
    """
    Converts a torch tensor to a PIL Image.
    """
    assert tensor.dim() == 3, "Input tensor should be 3-dimensional."
    assert tensor.dtype == torch.float32, "Input tensor should be float32."
    assert (
        tensor.min() >= 0 and tensor.max() <= 1
    ), "Input tensor should be in range [0, 1]."
    tensor = tensor.cpu()
    tensor = tensor * 255
    tensor = tensor.permute(1, 2, 0)
    tensor = tensor.numpy().astype(np.uint8)
    image = Image.fromarray(tensor)
    return image


def concat_images(images: List[Image.Image], divider: int = 4, cols: int = 4):
    """
    Concatenates images into a grid with `cols` columns, separated by `divider` pixels.
    """
    widths = [image.size[0] for image in images]
    heights = [image.size[1] for image in images]
    total_width = cols * max(widths)
    total_width += divider * (cols - 1)

    rows = math.ceil(len(images) / cols)
    total_height = max(heights) * rows
    total_height += divider * (rows - 1)

    concat_image = Image.new("RGB", (total_width, total_height), (0, 0, 0))

    x_offset = 0
    y_offset = 0
    for i, image in enumerate(images):
        concat_image.paste(image, (x_offset, y_offset))
        x_offset += image.size[0] + divider
        if (i + 1) % cols == 0:
            x_offset = 0
            y_offset += image.size[1] + divider

    return concat_image


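# Runnable toy example: arranges six solid-color tiles into a 3-column grid with a
# 4-pixel black divider between them.
def _example_concat_images():
    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255), (255, 0, 255)]
    tiles = [Image.new("RGB", (64, 64), c) for c in colors]
    grid = concat_images(tiles, divider=4, cols=3)
    assert grid.size == (3 * 64 + 2 * 4, 2 * 64 + 4)
    return grid

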
def read_prompt_file(prompt_file: str):
    if prompt_file is not None and os.path.isfile(prompt_file):
        with open(prompt_file, "r") as sample_prompt_file:
            sample_prompts = sample_prompt_file.readlines()
            sample_prompts = [sample_prompt.strip() for sample_prompt in sample_prompts]
    else:
        sample_prompts = []
    return sample_prompts


def save_tensors_to_npz(tensors: torch.Tensor, paths: List[str]):
    assert len(tensors) == len(paths), "Length of tensors and paths should be the same!"
    for tensor, path in zip(tensors, paths):
        np.savez_compressed(path, latent=tensor.cpu().numpy())


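# Usage sketch (illustrative): saves each latent in a batch to its own compressed
# .npz file; the output directory and file-name pattern below are hypothetical.
def _example_save_latents(latents, output_dir="latents_out"):
    os.makedirs(output_dir, exist_ok=True)
    paths = [os.path.join(output_dir, f"{i:05d}.npz") for i in range(len(latents))]
    save_tensors_to_npz(latents, paths)
    return paths

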
def deepspeed_zero_init_disabled_context_manager():
    """
    Returns either a list containing a single context manager that disables zero.Init,
    or an empty list if no DeepSpeed plugin is configured.
    """
    deepspeed_plugin = (
        AcceleratorState().deepspeed_plugin
        if accelerate.state.is_initialized()
        else None
    )
    if deepspeed_plugin is None:
        return []

    return [deepspeed_plugin.zero3_init_context_manager(enable=False)]


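# Usage sketch (illustrative): frozen models (e.g. a text encoder or VAE) are loaded
# under these context managers so DeepSpeed ZeRO-3 does not shard their weights.
# contextlib.ExitStack enters however many managers are returned; `load_frozen_model`
# is a hypothetical callable standing in for the actual `from_pretrained` call.
def _example_load_frozen_model(load_frozen_model):
    import contextlib

    with contextlib.ExitStack() as stack:
        for manager in deepspeed_zero_init_disabled_context_manager():
            stack.enter_context(manager)
        return load_frozen_model()

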
def is_xformers_available():
    try:
        import xformers

        xformers_version = version.parse(xformers.__version__)
        if xformers_version == version.parse("0.0.16"):
            print(
                "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
                "please update xFormers to at least 0.0.17. "
                "See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
            )
        return True
    except ImportError:
        raise ValueError(
            "xformers is not available. Make sure it is installed correctly"
        )


def resize_and_crop(image, size):
    """Center-crop `image` to the target aspect ratio, then resize it to `size`."""
    w, h = image.size
    target_w, target_h = size
    if w / h < target_w / target_h:
        # Image is taller than the target: keep the full width and crop the height.
        new_w = w
        new_h = w * target_h // target_w
    else:
        # Image is wider than the target: keep the full height and crop the width.
        new_h = h
        new_w = h * target_w // target_h
    image = image.crop(
        ((w - new_w) // 2, (h - new_h) // 2, (w + new_w) // 2, (h + new_h) // 2)
    )
    image = image.resize(size, Image.LANCZOS)
    return image


def resize_and_padding(image, size):
    """Resize `image` to fit inside `size`, then pad it to `size` with white."""
    w, h = image.size
    target_w, target_h = size
    if w / h < target_w / target_h:
        # Image is taller than the target: fit the height.
        new_h = target_h
        new_w = w * target_h // h
    else:
        # Image is wider than the target: fit the width.
        new_w = target_w
        new_h = h * target_w // w
    image = image.resize((new_w, new_h), Image.LANCZOS)

    padding = Image.new("RGB", size, (255, 255, 255))
    padding.paste(image, ((target_w - new_w) // 2, (target_h - new_h) // 2))
    return padding


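# Runnable toy example: contrasts the two resizing strategies on a wide synthetic
# image targeted at a portrait 384x512 canvas.
def _example_resize_helpers():
    wide = Image.new("RGB", (800, 400), (200, 50, 50))
    cropped = resize_and_crop(wide, (384, 512))    # crops the sides, fills the canvas
    padded = resize_and_padding(wide, (384, 512))  # keeps everything, pads top and bottom
    assert cropped.size == (384, 512) and padded.size == (384, 512)
    return cropped, padded

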
if __name__ == "__main__":
    pass