PAID / utils.py
wjh
init
67e6974
raw
history blame contribute delete
No virus
6.09 kB
import os
from typing import Optional
import matplotlib.pyplot as plt
import numpy as np
import torch
from lpips import LPIPS
from PIL import Image
from torchvision.transforms import Normalize
def show_images_horizontally(
    list_of_files: np.array, output_file: Optional[str] = None, interact: bool = False
) -> None:
    """
    Visualize the list of images horizontally and either show or save the figure.

    Each image gets its own column, 8 inches wide. The figure height is chosen so that
    the image with the largest height/width ratio fits without distortion.

    Args:
        list_of_files: The images as a numpy array with shape (N, H, W, C), or any
            sequence of N image arrays of shape (H, W[, C]).
        output_file: The output file path to save the figure to (the format follows
            the extension, e.g. ``.png``). Required when ``interact`` is False.
        interact: Whether to show the figure interactively (e.g. in a Jupyter
            Notebook) instead of saving it to ``output_file``.

    Raises:
        ValueError: If no images are given, or if ``interact`` is False and
            ``output_file`` is None.
    """
    number_of_files = len(list_of_files)
    if number_of_files == 0:
        raise ValueError("show_images_horizontally needs at least one image.")
    if not interact and output_file is None:
        raise ValueError("output_file must be given when interact is False.")

    fig_width = 8.0  # inches per image column
    # Use each image's own (H, W). The previous code indexed into the first image's
    # rows/pixels (yielding W and C), which gave a meaningless aspect ratio.
    aspect = max(np.shape(img)[0] / np.shape(img)[1] for img in list_of_files)
    fig_height = fig_width * aspect

    # squeeze=False keeps the axes array 2-D even for a single image; plain
    # plt.subplots(1, 1) returns a bare Axes that cannot be indexed.
    fig, axs = plt.subplots(
        1,
        number_of_files,
        figsize=(fig_width * number_of_files, fig_height),
        squeeze=False,
    )
    for ax, image in zip(axs[0], list_of_files):
        ax.imshow(image)
        ax.axis("off")
    # Lay out after the images are drawn so the spacing accounts for them.
    fig.tight_layout()

    if interact:
        plt.show()
    else:
        fig.savefig(output_file, bbox_inches="tight", pad_inches=0.25)
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
def save_image(image: np.array, file_name: str) -> None:
    """
    Write an image array to ``file_name``.

    PIL picks the output format from the file extension (e.g. ``.jpg`` -> JPEG),
    so the name controls the encoding.

    Args:
        image: The image as a numpy array with shape (H, W, C) or (H, W), in a dtype
            that ``PIL.Image.fromarray`` accepts (typically uint8).
        file_name: Destination path, including the extension.
    """
    Image.fromarray(image).save(file_name)
def load_and_process_images(load_dir: str) -> list:
    """
    Load the numerically named ``.jpg`` images from a directory.

    Only files ending in ``.jpg`` are loaded. They are expected to be named
    ``<int>.jpg`` (e.g. ``0.jpg``, ``1.jpg``, ..., ``10.jpg``) and are returned in
    numeric order, so ``10.jpg`` comes after ``9.jpg``. Any other files in the
    directory are ignored.

    Args:
        load_dir: The directory to load the images from.

    Returns:
        images: A list with one numpy array per image. Each array has shape (H, W, C)
            for multi-channel images or (H, W) for single-channel ones. Pixel values
            are divided by 255.0, i.e. in [0, 1] for 8-bit images. Use ``np.stack``
            if a single (N, H, W, C) array is needed.

    Raises:
        ValueError: If a ``.jpg`` file name (before the first dot) is not an integer.
    """
    print(load_dir)
    # Filter before sorting. Previously every directory entry (e.g. ".DS_Store" or a
    # saved PNG grid) was passed to int() and raised ValueError.
    filenames = sorted(
        (name for name in os.listdir(load_dir) if name.endswith(".jpg")),
        key=lambda name: int(name.split(".")[0]),
    )  # Ensure the files are sorted numerically
    images = []
    for filename in filenames:
        # The context manager closes the underlying file handle promptly.
        with Image.open(os.path.join(load_dir, filename)) as img:
            # Convert to numpy array and scale pixel values to [0, 1]
            images.append(np.asarray(img) / 255.0)
    return images
def compute_lpips(images: np.array, lpips_model: "LPIPS") -> np.array:
    """
    Compute the LPIPS distance between each pair of adjacent input images.

    Args:
        images: The input images as a numpy array (or sequence of equally shaped
            arrays) with shape (N, H, W, C). Presumably RGB values in [0, 1], as
            produced by ``load_and_process_images`` — confirm for other callers.
        lpips_model: The LPIPS model used to compute perceptual distances. Inputs are
            moved to the device of its first parameter.

    Returns:
        distances: 1-D float array of length max(N - 1, 0). Entry i is the LPIPS
            between images i and i + 1.
    """
    if len(images) < 2:
        # No adjacent pairs. This also avoids permuting an empty, non-4-D tensor.
        return np.array([])

    # Get device of lpips_model
    device = next(lpips_model.parameters()).device
    # Stack into one ndarray first, because torch.tensor on a list of ndarrays is
    # very slow. Casting straight to float32 matches the old .float() conversion.
    batch = torch.as_tensor(np.asarray(images), dtype=torch.float32, device=device)
    batch = batch.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

    # ImageNet channel normalization, numerically identical to torchvision's
    # Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).
    # NOTE(review): the lpips package documents inputs scaled to [-1, 1]. This
    # normalization is kept so the metric stays comparable with earlier results.
    mean = torch.tensor([0.485, 0.456, 0.406], dtype=batch.dtype, device=device)
    std = torch.tensor([0.229, 0.224, 0.225], dtype=batch.dtype, device=device)
    batch = (batch - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)

    # Distances are only a metric, so skip autograd bookkeeping (saves memory).
    with torch.no_grad():
        distances = [
            lpips_model(prev.unsqueeze(0), nxt.unsqueeze(0)).item()
            for prev, nxt in zip(batch[:-1], batch[1:])
        ]
    return np.array(distances)
def compute_gini(distances: np.array) -> float:
    """
    Compute the Gini index of the input distances.

    For values sorted ascending, x_(0) <= ... <= x_(n-1), this uses the closed form

        G = sum_i (2i - n + 1) * x_(i) / (n * sum_i x_i)

    It equals sum_{i,j} |x_i - x_j| / (2 * n^2 * mean) but runs in O(n log n)
    instead of O(n^2).

    Args:
        distances: The input distances as a 1-D numpy array or sequence of numbers.
            They are assumed non-negative (e.g. LPIPS distances).

    Returns:
        gini: The Gini index as a Python float. It is 0.0 for fewer than two
            elements and for all-zero input (a perfectly even distribution).
    """
    if len(distances) < 2:
        return 0.0  # Gini index is 0 for less than two elements
    values = np.sort(np.asarray(distances, dtype=np.float64))
    n = len(values)
    total = values.sum()
    if total == 0:
        # All distances are zero (e.g. identical frames). The pairwise formula
        # would be 0/0, but an even distribution has Gini index 0.
        return 0.0
    ranks = np.arange(n)
    weighted = np.sum((2 * ranks - n + 1) * values)
    return float(weighted / (n * total))
def compute_smoothness_and_consistency(images: np.array, lpips_model: LPIPS) -> tuple:
    """
    Compute the smoothness and consistency of a sequence of images.

    Both metrics are derived from the LPIPS distances between consecutive images.

    Args:
        images: The input images as numpy array with shape (N, H, W, C), N >= 2.
        lpips_model: The LPIPS model used to compute perceptual distances.

    Returns:
        A 3-tuple ``(smoothness, consistency, max_inception_distance)``:
        smoothness: One minus the Gini index of the LPIPS of consecutive images.
        consistency: The mean LPIPS of consecutive images.
        max_inception_distance: The maximum LPIPS of consecutive images.

    Raises:
        ValueError: If fewer than two images are given, since there are then no
            consecutive pairs to measure.
    """
    # Without this guard, one image gave a nan mean plus an obscure np.max error,
    # and zero images failed inside compute_lpips.
    if len(images) < 2:
        raise ValueError(
            "At least two images are needed to compute smoothness and consistency."
        )
    distances = compute_lpips(images, lpips_model)
    smoothness = 1 - compute_gini(distances)
    consistency = np.mean(distances)
    max_inception_distance = np.max(distances)
    return smoothness, consistency, max_inception_distance
def separate_source_and_interpolated_images(images: np.array) -> tuple:
    """
    Split a sequence of images into its two endpoints and everything in between.

    The first and last entries are the source images. All entries strictly between
    them are treated as the interpolated images.

    Args:
        images: The input images as numpy array with shape (N, H, W, C), N >= 2.

    Returns:
        A pair ``(source, interpolation)``:
        source: numpy array with shape (2, H, W, C) holding the first and last image.
        interpolation: the slice ``images[1:-1]``, of shape (N-2, H, W, C) for array
            input.

    Raises:
        ValueError: If fewer than two images are given.
    """
    total = len(images)
    if total < 2:
        raise ValueError("The input array should have at least two elements.")
    endpoints = np.array([images[0], images[-1]])
    in_between = images[1:-1]
    return endpoints, in_between