Artist / utils /exp_utils.py
fffiloni's picture
Upload 20 files
e02c605 verified
raw
history blame
2.76 kB
import uuid
import os
import PIL.Image as Image
import torch
import numpy as np
from torchvision import transforms
import torch.nn.functional as F
import torchvision
def make_unique_experiment_path(base_dir: str) -> str:
"""
Create a unique directory in the base directory, named as the least unused number.
return: path to the unique directory
"""
if not os.path.exists(base_dir):
os.makedirs(base_dir)
# List all existing directories
existing_dirs = [
d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))
]
# Convert directory names to integers, filter out non-numeric names
existing_numbers = sorted([int(d) for d in existing_dirs if d.isdigit()])
# Find the least unused number
experiment_id = 1
for number in existing_numbers:
if number != experiment_id:
break
experiment_id += 1
# Create the new directory
experiment_output_path = os.path.join(base_dir, str(experiment_id))
os.makedirs(experiment_output_path)
return experiment_output_path
def get_processed_image(image_dir: str, device, resolution) -> torch.Tensor:
src_img = Image.open(image_dir)
src_img = transforms.ToTensor()(src_img).unsqueeze(0).to(device)
h, w = src_img.shape[-2:]
src_img_512 = torchvision.transforms.functional.pad(
src_img, ((resolution - w) // 2,), fill=0, padding_mode="constant"
)
input_image = F.interpolate(
src_img, (resolution, resolution), mode="bilinear", align_corners=False
)
# drop alpha channel if it exists
if input_image.shape[1] == 4:
input_image = input_image[:, :3]
return input_image
def process_image(image, device, resolution) -> torch.Tensor:
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
src_img = image
src_img = transforms.ToTensor()(src_img).unsqueeze(0).to(device)
h, w = src_img.shape[-2:]
src_img_512 = torchvision.transforms.functional.pad(
src_img, ((resolution - w) // 2,), fill=0, padding_mode="constant"
)
input_image = F.interpolate(
src_img, (resolution, resolution), mode="bilinear", align_corners=False
)
# drop alpha channel if it exists
if input_image.shape[1] == 4:
input_image = input_image[:, :3]
return input_image
def seed_all(seed: int):
torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
g_cpu = torch.Generator(device="cpu")
g_cpu.manual_seed(42)
def dump_tensor(tensor, filename):
with open(filename) as f:
torch.save(tensor, f)