# DI-PCG / core/utils/train_utils.py
# Author: thuzhaowang · commit b6a9b6d ("init") · 2.5 kB
import os
import torch
import numpy as np
import logging
from collections import OrderedDict
from PIL import Image
def requires_grad(model, flag=True):
    """Toggle gradient tracking for every parameter of ``model``.

    Args:
        model: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        flag: value written to each parameter's ``requires_grad`` attribute;
            ``True`` enables gradient tracking, ``False`` freezes the weights.
    """
    params = model.parameters()
    for tensor in params:
        tensor.requires_grad = flag
def create_logger(logging_dir):
    """Configure root logging to stdout plus ``<logging_dir>/log.txt``.

    Both handlers share a format whose timestamp is wrapped in ANSI colour
    escape codes. Those codes are written to the log file as well.

    Note: ``logging.basicConfig`` does nothing if the root logger already has
    handlers. In that case the handlers built here are not attached. The log
    file is still opened, because ``FileHandler`` is constructed eagerly.

    Args:
        logging_dir: existing directory that will receive ``log.txt``.

    Returns:
        The logger named after this module.
    """
    log_file = f"{logging_dir}/log.txt"
    sinks = [logging.StreamHandler(), logging.FileHandler(log_file)]
    logging.basicConfig(
        handlers=sinks,
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S',
        format='[\033[34m%(asctime)s\033[0m] %(message)s',
    )
    return logging.getLogger(__name__)
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.

    Each EMA parameter is updated in place as
    ``ema = decay * ema + (1 - decay) * param``.

    Args:
        ema_model: module holding the EMA copy of the weights.
        model: module holding the current weights. Its parameter names may
            carry a leading ``"module."`` prefix, such as the one added by
            ``DistributedDataParallel``. That prefix is stripped before names
            are matched against ``ema_model``.
        decay: EMA decay factor. Larger values make the EMA move more slowly.

    Raises:
        KeyError: if a (prefix-stripped) parameter name of ``model`` has no
            counterpart in ``ema_model``.
    """
    prefix = "module."
    ema_params = OrderedDict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        # Strip only a *leading* wrapper prefix. A global str.replace would
        # also mangle names that merely contain "module." (for example
        # "submodule.weight" -> "subweight") and make the lookup below fail.
        if name.startswith(prefix):
            name = name[len(prefix):]
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
def center_crop_arr(pil_image, image_size):
    """
    Resize and center-crop ``pil_image`` to an ``image_size`` x ``image_size`` square.

    Port of the ADM center-crop routine:
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    Steps:
        1. Repeatedly halve the image with BOX filtering while its shorter side
           is at least ``2 * image_size``.
        2. Resize with BICUBIC so the shorter side becomes ``image_size``.
        3. Crop the central ``image_size`` x ``image_size`` window.

    Args:
        pil_image: input PIL image.
        image_size: side length of the returned square image, in pixels.

    Returns:
        A new PIL image built from the cropped pixel array.
    """
    img = pil_image

    # Cheap power-of-two downsampling before the final, more expensive resize.
    while min(img.size) >= 2 * image_size:
        halved = tuple(dim // 2 for dim in img.size)
        img = img.resize(halved, resample=Image.BOX)

    ratio = image_size / min(img.size)
    target = tuple(round(dim * ratio) for dim in img.size)
    img = img.resize(target, resample=Image.BICUBIC)

    # numpy arrays are (rows, cols, ...), so axis 0 is height and axis 1 is width.
    pixels = np.array(img)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    window = pixels[top:top + image_size, left:left + image_size]
    return Image.fromarray(window)
def load_model(ckpt_name):
    """
    Load a checkpoint from a local path onto the CPU.

    No download is attempted, so ``ckpt_name`` must name an existing file.
    If the loaded object is a dict with an ``"ema"`` entry, that entry is
    returned. Otherwise the loaded object is returned unchanged, whether it is
    a plain state dict, a tensor, or anything else ``torch.load`` produced.

    Args:
        ckpt_name: path to a checkpoint file readable by ``torch.load``.

    Returns:
        ``checkpoint["ema"]`` when present, else the whole loaded object.

    Raises:
        FileNotFoundError: if ``ckpt_name`` is not an existing file.
    """
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if not os.path.isfile(ckpt_name):
        raise FileNotFoundError(f'Could not find DiT checkpoint at {ckpt_name}')
    # NOTE(security): torch.load unpickles the file. Only load trusted checkpoints.
    # map_location="cpu" keeps every tensor on the CPU, matching the former
    # ``lambda storage, loc: storage``.
    checkpoint = torch.load(ckpt_name, map_location="cpu")
    # Only dict checkpoints can carry an "ema" entry. Using `in` on a tensor or
    # module would raise instead of returning False.
    if isinstance(checkpoint, dict) and "ema" in checkpoint:  # supports checkpoints from train.py
        checkpoint = checkpoint["ema"]
    return checkpoint