import argparse
import contextlib
import time
import gc
import logging
import math
import os
import random
import jsonlines
import functools
import shutil
import pyrallis
import itertools
from pathlib import Path
from collections import namedtuple, OrderedDict

import accelerate
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from datasets import load_dataset
from packaging import version
from PIL import Image
from losses.losses import *
from torchvision import transforms
from torchvision.transforms.functional import crop
from tqdm.auto import tqdm

# Assumed module-level logger: the dataset helpers below call logger.info/logger.warning.
logger = get_logger(__name__)


def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    from transformers import PretrainedConfig

    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection
    else:
        raise ValueError(f"{model_class} is not supported.")
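

# Illustrative sketch (not part of the original script): how the helper above is
# typically used to resolve SDXL's two text-encoder classes before loading weights.
# The model id below is only an example.
def _example_load_text_encoders(
    pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0", revision=None
):
    text_encoder_cls_one = import_model_class_from_model_name_or_path(
        pretrained_model_name_or_path, revision, subfolder="text_encoder"
    )
    text_encoder_cls_two = import_model_class_from_model_name_or_path(
        pretrained_model_name_or_path, revision, subfolder="text_encoder_2"
    )
    text_encoder_one = text_encoder_cls_one.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
    )
    text_encoder_two = text_encoder_cls_two.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder_2", revision=revision
    )
    return text_encoder_one, text_encoder_two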


def get_train_dataset(dataset_name, dataset_dir, args, accelerator):
    # Load the training split from the local data directory.
    dataset = load_dataset(
        dataset_name,
        data_dir=dataset_dir,
        cache_dir=os.path.join(dataset_dir, ".cache"),
        num_proc=4,
        split="train",
    )

    # Resolve the image / caption / conditioning-image columns, falling back to the
    # first three dataset columns when they are not specified.
    column_names = dataset.column_names

    if args.image_column is None:
        args.image_column = column_names[0]
        logger.info(f"image column defaulting to {column_names[0]}")
    else:
        image_column = args.image_column
        if image_column not in column_names:
            logger.warning(f"dataset {dataset_name} has no column {image_column}")

    if args.caption_column is None:
        args.caption_column = column_names[1]
        logger.info(f"caption column defaulting to {column_names[1]}")
    else:
        caption_column = args.caption_column
        if caption_column not in column_names:
            logger.warning(f"dataset {dataset_name} has no column {caption_column}")

    if args.conditioning_image_column is None:
        args.conditioning_image_column = column_names[2]
        logger.info(f"conditioning image column defaulting to {column_names[2]}")
    else:
        conditioning_image_column = args.conditioning_image_column
        if conditioning_image_column not in column_names:
            logger.warning(f"dataset {dataset_name} has no column {conditioning_image_column}")

    with accelerator.main_process_first():
        train_dataset = dataset.shuffle(seed=args.seed)
        if args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(args.max_train_samples))
    return train_dataset
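

# Illustrative sketch (an assumption, not copied from the original config): the subset
# of fields that get_train_dataset() reads off `args`. Since this codebase builds its
# config with pyrallis, a dataclass of this shape is the natural counterpart.
from dataclasses import dataclass
from typing import Optional


@dataclass
class _ExampleDataConfig:
    image_column: Optional[str] = None               # defaults to the dataset's first column
    caption_column: Optional[str] = None             # defaults to the second column
    conditioning_image_column: Optional[str] = None  # defaults to the third column
    max_train_samples: Optional[int] = None          # truncate the shuffled dataset if set
    seed: int = 42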


def prepare_train_dataset(dataset, accelerator, deg_pipeline, centralize=False):
    # Geometric augmentations are sampled once per call and then applied to every image.
    hflip = deg_pipeline.augment_opt['use_hflip'] and random.random() < 0.5
    vflip = deg_pipeline.augment_opt['use_rot'] and random.random() < 0.5
    rot90 = deg_pipeline.augment_opt['use_rot'] and random.random() < 0.5
    augment_transforms = []
    if hflip:
        augment_transforms.append(transforms.RandomHorizontalFlip(p=1.0))
    if vflip:
        augment_transforms.append(transforms.RandomVerticalFlip(p=1.0))
    if rot90:
        augment_transforms.append(transforms.RandomRotation(degrees=(90, 90)))

    torch_transforms = [transforms.ToTensor()]
    if centralize:
        # Map pixel values from [0, 1] to [-1, 1].
        torch_transforms.append(transforms.Normalize([0.5], [0.5]))

    training_size = deg_pipeline.degrade_opt['gt_size']
    image_transforms = transforms.Compose(augment_transforms)
    train_transforms = transforms.Compose(torch_transforms)
    train_resize = transforms.Resize(training_size, interpolation=transforms.InterpolationMode.BILINEAR)
    train_crop = transforms.RandomCrop(training_size)

    def preprocess_train(examples):
        raw_images = []
        for img_data in examples[args.image_column]:
            raw_images.append(Image.open(img_data).convert("RGB"))

        # Image processing.
        images = []
        original_sizes = []
        crop_top_lefts = []
        # Degradation kernels.
        kernel = []
        kernel2 = []
        sinc_kernel = []

        for raw_image in raw_images:
            raw_image = image_transforms(raw_image)
            original_sizes.append((raw_image.height, raw_image.width))

            # Resize, then random-crop to the training resolution.
            raw_image = train_resize(raw_image)
            y1, x1, h, w = train_crop.get_params(raw_image, (training_size, training_size))
            raw_image = crop(raw_image, y1, x1, h, w)
            crop_top_left = (y1, x1)
            crop_top_lefts.append(crop_top_left)
            image = train_transforms(raw_image)

            images.append(image)
            k, k2, sk = deg_pipeline.get_kernel()
            kernel.append(k)
            kernel2.append(k2)
            sinc_kernel.append(sk)

        examples["images"] = images
        examples["original_sizes"] = original_sizes
        examples["crop_top_lefts"] = crop_top_lefts
        examples["kernel"] = kernel
        examples["kernel2"] = kernel2
        examples["sinc_kernel"] = sinc_kernel

        return examples

    with accelerator.main_process_first():
        dataset = dataset.with_transform(preprocess_train)

    return dataset
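

# Illustrative sketch (an assumption about the interface, not the real class): the
# attributes and method that prepare_train_dataset() expects the degradation pipeline
# to expose. Option names mirror the Real-ESRGAN-style keys used above.
class _ExampleDegradationPipeline:
    def __init__(self):
        self.augment_opt = {"use_hflip": True, "use_rot": True}
        self.degrade_opt = {"gt_size": 512}

    def get_kernel(self):
        # Return (first blur kernel, second blur kernel, final sinc kernel) as tensors;
        # these are later consumed by the on-GPU degradation step.
        k = torch.ones(21, 21) / (21 * 21)
        k2 = torch.ones(21, 21) / (21 * 21)
        sk = torch.ones(21, 21) / (21 * 21)
        return k, k2, sk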


def collate_fn(examples):
    images = torch.stack([example["images"] for example in examples])
    images = images.to(memory_format=torch.contiguous_format).float()
    kernel = torch.stack([example["kernel"] for example in examples])
    kernel = kernel.to(memory_format=torch.contiguous_format).float()
    kernel2 = torch.stack([example["kernel2"] for example in examples])
    kernel2 = kernel2.to(memory_format=torch.contiguous_format).float()
    sinc_kernel = torch.stack([example["sinc_kernel"] for example in examples])
    sinc_kernel = sinc_kernel.to(memory_format=torch.contiguous_format).float()
    original_sizes = [example["original_sizes"] for example in examples]
    crop_top_lefts = [example["crop_top_lefts"] for example in examples]

    # Fall back to an empty prompt when the caption column is missing from an example.
    prompts = []
    for example in examples:
        prompts.append(example[args.caption_column] if args.caption_column in example else "")

    return {
        "images": images,
        "text": prompts,
        "kernel": kernel,
        "kernel2": kernel2,
        "sinc_kernel": sinc_kernel,
        "original_sizes": original_sizes,
        "crop_top_lefts": crop_top_lefts,
    }
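

# Illustrative sketch (not from the original script): how the dataset helpers and
# collate_fn above are typically wired together. The `args.dataset_name` /
# `args.dataset_dir` field names, batch size, and worker count are placeholder assumptions.
def _example_build_dataloader(args, accelerator, deg_pipeline):
    train_dataset = get_train_dataset(args.dataset_name, args.dataset_dir, args, accelerator)
    train_dataset = prepare_train_dataset(train_dataset, accelerator, deg_pipeline, centralize=True)
    return torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=4,
        num_workers=4,
    )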


def encode_prompt(prompt_batch, text_encoders, tokenizers, is_train=True):
    prompt_embeds_list = []

    captions = []
    for caption in prompt_batch:
        if isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # Take a random caption when there are multiple per image (the first one at eval time).
            captions.append(random.choice(caption) if is_train else caption[0])

    with torch.no_grad():
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_inputs = tokenizer(
                captions,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = text_encoder(
                text_input_ids.to(text_encoder.device),
                output_hidden_states=True,
            )

            # Only the pooled output of the final text encoder is kept;
            # the penultimate hidden states of every encoder are concatenated.
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]
            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds_list.append(prompt_embeds)

    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
    prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
    return prompt_embeds, pooled_prompt_embeds
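

# Illustrative sketch (not from the original script): encode_prompt() expects SDXL's two
# tokenizers and text encoders in matching order. The concatenated embeddings feed the
# UNet cross-attention; the pooled embeddings feed the added-conditioning path.
def _example_encode(prompts, tokenizer_one, tokenizer_two, text_encoder_one, text_encoder_two):
    prompt_embeds, pooled_prompt_embeds = encode_prompt(
        prompts,
        text_encoders=[text_encoder_one, text_encoder_two],
        tokenizers=[tokenizer_one, tokenizer_two],
        is_train=True,
    )
    # prompt_embeds: (batch, seq_len, dim_1 + dim_2); pooled_prompt_embeds: (batch, dim_2)
    return prompt_embeds, pooled_prompt_embeds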


def importance_sampling_fn(t, max_t, alpha):
    """Importance Sampling Function f(t)"""
    return 1 / max_t * (1 - alpha * np.cos(np.pi * t / max_t))
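

# Illustrative sketch (an assumption about how f(t) is consumed, not the original training
# loop): normalize f(t) into a categorical distribution over the scheduler's timesteps and
# draw a batch of timesteps from it.
def _example_sample_timesteps(batch_size, num_train_timesteps=1000, alpha=0.5):
    t = np.arange(num_train_timesteps)
    weights = importance_sampling_fn(t, num_train_timesteps, alpha)
    probs = weights / weights.sum()
    timesteps = np.random.choice(num_train_timesteps, size=batch_size, p=probs)
    return torch.from_numpy(timesteps).long()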


def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
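

# Illustrative sketch (not from the original script): gather a per-timestep coefficient such
# as sqrt(alpha_cumprod) and broadcast it over an image batch, which is the usual role of
# extract_into_tensor in a diffusion forward process.
def _example_add_noise(latents, noise, timesteps, alphas_cumprod):
    sqrt_alpha_prod = extract_into_tensor(alphas_cumprod.sqrt(), timesteps, latents.shape)
    sqrt_one_minus_alpha_prod = extract_into_tensor((1.0 - alphas_cumprod).sqrt(), timesteps, latents.shape)
    # Coefficients of shape (B, 1, 1, 1) broadcast against latents of shape (B, C, H, W).
    return sqrt_alpha_prod * latents + sqrt_one_minus_alpha_prod * noise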


def tensor_to_pil(images):
    """
    Convert image tensor or a batch of image tensors to PIL image(s).
    """
    images = (images + 1) / 2
    images_np = images.detach().cpu().numpy()
    if images_np.ndim == 4:
        images_np = np.transpose(images_np, (0, 2, 3, 1))
    elif images_np.ndim == 3:
        images_np = np.transpose(images_np, (1, 2, 0))
        images_np = images_np[None, ...]
    images_np = (images_np * 255).round().astype("uint8")
    if images_np.shape[-1] == 1:
        # Special case for single-channel (grayscale) images.
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images_np]
    else:
        pil_images = [Image.fromarray(image[:, :, :3]) for image in images_np]

    return pil_images


def save_np_to_image(img_np, save_dir):
    img_np = np.transpose(img_np, (0, 2, 3, 1))
    img_np = (img_np * 255).astype(np.uint8)
    img_np = Image.fromarray(img_np[0])
    img_np.save(save_dir)


def seperate_SFT_params_from_unet(unet):
    params = []
    non_params = []
    for name, param in unet.named_parameters():
        if "SFT" in name:
            params.append(param)
        else:
            non_params.append(param)
    return params, non_params


def seperate_lora_params_from_unet(unet):
    lora_params = []
    non_lora_params = []
    for name, param in unet.named_parameters():
        if "lora" in name:
            lora_params.append(param)
        else:
            non_lora_params.append(param)
    return lora_params, non_lora_params


def seperate_ip_params_from_unet(unet):
    ip_params = []
    non_ip_params = []
    for name, param in unet.named_parameters():
        if "encoder_hid_proj." in name or "_ip." in name:
            ip_params.append(param)
        elif "attn" in name and "processor" in name:
            if "ip" in name or "ln" in name:
                ip_params.append(param)
        else:
            non_ip_params.append(param)
    return ip_params, non_ip_params
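

# Illustrative sketch (not from the original script): the typical use of the split above is
# to train only the IP-Adapter-related parameters while freezing the rest of the UNet.
# Optimizer choice and learning rate here are placeholders.
def _example_build_optimizer(unet, learning_rate=1e-4):
    ip_params, non_ip_params = seperate_ip_params_from_unet(unet)
    for p in non_ip_params:
        p.requires_grad_(False)
    for p in ip_params:
        p.requires_grad_(True)
    return torch.optim.AdamW(ip_params, lr=learning_rate)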


def seperate_ref_params_from_unet(unet):
    ip_params = []
    non_ip_params = []
    for name, param in unet.named_parameters():
        if "encoder_hid_proj." in name or "_ip." in name:
            ip_params.append(param)
        elif "attn" in name and "processor" in name:
            if "ip" in name or "ln" in name:
                ip_params.append(param)
        elif "extract" in name:
            ip_params.append(param)
        else:
            non_ip_params.append(param)
    return ip_params, non_ip_params


def seperate_ip_modules_from_unet(unet):
    ip_modules = []
    non_ip_modules = []
    for name, module in unet.named_modules():
        if "encoder_hid_proj" in name or "attn2.processor" in name:
            ip_modules.append(module)
        else:
            non_ip_modules.append(module)
    return ip_modules, non_ip_modules


def seperate_SFT_keys_from_unet(unet):
    keys = []
    non_keys = []
    for name, param in unet.named_parameters():
        if "SFT" in name:
            keys.append(name)
        else:
            non_keys.append(name)
    return keys, non_keys


def seperate_ip_keys_from_unet(unet):
    keys = []
    non_keys = []
    for name, param in unet.named_parameters():
        if "encoder_hid_proj." in name or "_ip." in name:
            keys.append(name)
        elif "attn" in name and "processor" in name:
            if "ip" in name or "ln" in name:
                keys.append(name)
        else:
            non_keys.append(name)
    return keys, non_keys
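

# Illustrative sketch (not from the original script): the key-level variants above make it
# easy to extract just the adapter weights from a full UNet state_dict for checkpointing.
def _example_extract_ip_state_dict(unet):
    ip_keys = set(seperate_ip_keys_from_unet(unet)[0])
    full_state_dict = unet.state_dict()
    return {k: v for k, v in full_state_dict.items() if k in ip_keys}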