import os
import random
import argparse
import json

import torch
from PIL import Image
from torchvision import transforms
import torchvision.transforms.functional as F
from glob import glob


def parse_args_paired_training(input_args=None):
    """
    Parses command-line arguments used for configuring a paired training session (pix2pix-Turbo).
    This function sets up an argument parser to handle the various training options.

    Returns:
        argparse.Namespace: The parsed command-line arguments.
    """
    parser = argparse.ArgumentParser()
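    # loss weights and GAN options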
parser.add_argument("--gan_disc_type", default="vagan_clip") |
|
parser.add_argument("--gan_loss_type", default="multilevel_sigmoid_s") |
|
parser.add_argument("--lambda_gan", default=0.5, type=float) |
|
parser.add_argument("--lambda_lpips", default=5, type=float) |
|
parser.add_argument("--lambda_l2", default=1.0, type=float) |
|
parser.add_argument("--lambda_clipsim", default=5.0, type=float) |
|
|
|
|
|
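    # dataset options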
parser.add_argument("--dataset_folder", required=True, type=str) |
|
parser.add_argument("--train_image_prep", default="resized_crop_512", type=str) |
|
parser.add_argument("--test_image_prep", default="resized_crop_512", type=str) |
|
|
|
|
|
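    # validation and logging options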
parser.add_argument("--eval_freq", default=100, type=int) |
|
parser.add_argument("--track_val_fid", default=False, action="store_true") |
|
parser.add_argument("--num_samples_eval", type=int, default=100, help="Number of samples to use for all evaluation") |
|
|
|
parser.add_argument("--viz_freq", type=int, default=100, help="Frequency of visualizing the outputs.") |
|
parser.add_argument("--tracker_project_name", type=str, default="train_pix2pix_turbo", help="The name of the wandb project to log to.") |
|
|
|
|
|
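    # model and LoRA options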
parser.add_argument("--pretrained_model_name_or_path") |
|
parser.add_argument("--revision", type=str, default=None,) |
|
parser.add_argument("--variant", type=str, default=None,) |
|
parser.add_argument("--tokenizer_name", type=str, default=None) |
|
parser.add_argument("--lora_rank_unet", default=8, type=int) |
|
parser.add_argument("--lora_rank_vae", default=4, type=int) |
|
|
|
|
|
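    # training details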
parser.add_argument("--output_dir", required=True) |
|
parser.add_argument("--cache_dir", default=None,) |
|
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") |
|
parser.add_argument("--resolution", type=int, default=512,) |
|
parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.") |
|
parser.add_argument("--num_training_epochs", type=int, default=10) |
|
parser.add_argument("--max_train_steps", type=int, default=10_000,) |
|
parser.add_argument("--checkpointing_steps", type=int, default=500,) |
|
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.",) |
|
parser.add_argument("--gradient_checkpointing", action="store_true",) |
|
parser.add_argument("--learning_rate", type=float, default=5e-6) |
|
parser.add_argument("--lr_scheduler", type=str, default="constant", |
|
help=( |
|
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' |
|
' "constant", "constant_with_warmup"]' |
|
), |
|
) |
|
parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.") |
|
parser.add_argument("--lr_num_cycles", type=int, default=1, |
|
help="Number of hard resets of the lr in cosine_with_restarts scheduler.", |
|
) |
|
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") |
|
|
|
parser.add_argument("--dataloader_num_workers", type=int, default=0,) |
|
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") |
|
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") |
|
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") |
|
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") |
|
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") |
|
parser.add_argument("--allow_tf32", action="store_true", |
|
help=( |
|
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" |
|
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" |
|
), |
|
) |
|
parser.add_argument("--report_to", type=str, default="wandb", |
|
help=( |
|
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' |
|
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' |
|
), |
|
) |
|
parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"],) |
|
parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.") |
|
parser.add_argument("--set_grads_to_none", action="store_true",) |
|
|
|
    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    return args
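
# Example usage (a minimal sketch; only the two required arguments are passed,
# and both paths below are placeholders):
#   args = parse_args_paired_training([
#       "--dataset_folder", "data/my_paired_dataset",
#       "--output_dir", "outputs/pix2pix_turbo",
#   ])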


def parse_args_unpaired_training():
    """
    Parses command-line arguments used for configuring an unpaired training session (CycleGAN-Turbo).
    This function sets up an argument parser to handle the various training options.

    Returns:
        argparse.Namespace: The parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(description="Training options for an unpaired session (CycleGAN-Turbo).")

    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")

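    # loss weights and GAN options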
parser.add_argument("--gan_disc_type", default="vagan_clip") |
|
parser.add_argument("--gan_loss_type", default="multilevel_sigmoid") |
|
parser.add_argument("--lambda_gan", default=0.5, type=float) |
|
parser.add_argument("--lambda_idt", default=1, type=float) |
|
parser.add_argument("--lambda_cycle", default=1, type=float) |
|
parser.add_argument("--lambda_cycle_lpips", default=10.0, type=float) |
|
parser.add_argument("--lambda_idt_lpips", default=1.0, type=float) |
|
|
|
|
|
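    # dataset options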
parser.add_argument("--dataset_folder", required=True, type=str) |
|
parser.add_argument("--train_img_prep", required=True) |
|
parser.add_argument("--val_img_prep", required=True) |
|
parser.add_argument("--dataloader_num_workers", type=int, default=0) |
|
parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.") |
|
parser.add_argument("--max_train_epochs", type=int, default=100) |
|
parser.add_argument("--max_train_steps", type=int, default=None) |
|
|
|
|
|
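    # model and LoRA options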
parser.add_argument("--pretrained_model_name_or_path", default="stabilityai/sd-turbo") |
|
parser.add_argument("--revision", default=None, type=str) |
|
parser.add_argument("--variant", default=None, type=str) |
|
parser.add_argument("--lora_rank_unet", default=128, type=int) |
|
parser.add_argument("--lora_rank_vae", default=4, type=int) |
|
|
|
|
|
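    # validation, logging, and checkpointing options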
parser.add_argument("--viz_freq", type=int, default=20) |
|
parser.add_argument("--output_dir", type=str, required=True) |
|
parser.add_argument("--report_to", type=str, default="wandb") |
|
parser.add_argument("--tracker_project_name", type=str, required=True) |
|
parser.add_argument("--validation_steps", type=int, default=500,) |
|
parser.add_argument("--validation_num_images", type=int, default=-1, help="Number of images to use for validation. -1 to use all images.") |
|
parser.add_argument("--checkpointing_steps", type=int, default=500) |
|
|
|
|
|
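    # optimizer and lr scheduler options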
parser.add_argument("--learning_rate", type=float, default=5e-6,) |
|
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") |
|
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") |
|
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") |
|
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") |
|
parser.add_argument("--max_grad_norm", default=10.0, type=float, help="Max gradient norm.") |
|
parser.add_argument("--lr_scheduler", type=str, default="constant", help=( |
|
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' |
|
' "constant", "constant_with_warmup"]' |
|
), |
|
) |
|
parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.") |
|
parser.add_argument("--lr_num_cycles", type=int, default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.",) |
|
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") |
|
parser.add_argument("--gradient_accumulation_steps", type=int, default=1) |
|
|
|
|
|
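    # performance and memory options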
parser.add_argument("--allow_tf32", action="store_true", |
|
help=( |
|
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" |
|
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" |
|
), |
|
) |
|
parser.add_argument("--gradient_checkpointing", action="store_true", |
|
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.") |
|
parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.") |
|
|
|
args = parser.parse_args() |
|
return args |
|


def build_transform(image_prep):
    """
    Constructs a transformation pipeline based on the specified image preparation method.

    Parameters:
    - image_prep (str): A string describing the desired image preparation.

    Returns:
    - A torchvision transform (e.g., transforms.Compose) to be applied to images.
    """
    if image_prep == "resized_crop_512":
        T = transforms.Compose([
            transforms.Resize(512, interpolation=transforms.InterpolationMode.LANCZOS),
            transforms.CenterCrop(512),
        ])
    elif image_prep == "resize_286_randomcrop_256x256_hflip":
        T = transforms.Compose([
            transforms.Resize((286, 286), interpolation=transforms.InterpolationMode.LANCZOS),
            transforms.RandomCrop((256, 256)),
            transforms.RandomHorizontalFlip(),
        ])
    elif image_prep in ["resize_256", "resize_256x256"]:
        T = transforms.Compose([
            transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.LANCZOS),
        ])
    elif image_prep in ["resize_512", "resize_512x512"]:
        T = transforms.Compose([
            transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.LANCZOS),
        ])
    elif image_prep == "no_resize":
        T = transforms.Lambda(lambda x: x)
    else:
        # previously an unknown image_prep fell through and raised a NameError
        # on the return below; fail explicitly instead
        raise ValueError(f"Unknown image_prep: {image_prep}")
    return T
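
# Example (a minimal sketch; "example.png" is a placeholder path):
#   T = build_transform("resized_crop_512")
#   img = T(Image.open("example.png"))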


class PairedDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_folder, split, image_prep, tokenizer):
        """
        Initialize the paired dataset object for loading and transforming paired data samples
        from specified dataset folders.

        This constructor sets up the paths to input and output folders based on the specified 'split',
        loads the captions (or prompts) for the input images, and prepares the transformations and
        tokenizer to be applied on the data.

        Parameters:
        - dataset_folder (str): The root folder containing the dataset, expected to include
            sub-folders for different splits (e.g., 'train_A', 'train_B').
        - split (str): The dataset split to use ('train' or 'test'), used to select the appropriate
            sub-folders and caption files within the dataset folder.
        - image_prep (str): The image preprocessing transformation to apply to each image.
        - tokenizer: The tokenizer used for tokenizing the captions (or prompts).
        """
        super().__init__()
        if split == "train":
            self.input_folder = os.path.join(dataset_folder, "train_A")
            self.output_folder = os.path.join(dataset_folder, "train_B")
            captions = os.path.join(dataset_folder, "train_prompts.json")
        elif split == "test":
            self.input_folder = os.path.join(dataset_folder, "test_A")
            self.output_folder = os.path.join(dataset_folder, "test_B")
            captions = os.path.join(dataset_folder, "test_prompts.json")
        else:
            raise ValueError(f"Unknown split: {split}; expected 'train' or 'test'.")
        with open(captions, "r") as f:
            self.captions = json.load(f)
        self.img_names = list(self.captions.keys())
        self.T = build_transform(image_prep)
        self.tokenizer = tokenizer

    def __len__(self):
        """
        Returns:
            int: The total number of items in the dataset.
        """
        return len(self.captions)

    def __getitem__(self, idx):
        """
        Retrieves a dataset item given its index. Each item consists of an input image,
        its corresponding output image, the caption associated with the input image,
        and the tokenized form of this caption.

        This method performs the necessary preprocessing on both the input and output images,
        including scaling and normalization, as well as tokenizing the caption using a provided tokenizer.

        Parameters:
        - idx (int): The index of the item to retrieve.

        Returns:
        dict: A dictionary containing the following key-value pairs:
            - "output_pixel_values": a tensor of the preprocessed output image with pixel values
                scaled to [-1, 1].
            - "conditioning_pixel_values": a tensor of the preprocessed input image with pixel values
                scaled to [0, 1].
            - "caption": the text caption.
            - "input_ids": a tensor of the tokenized caption.

        Note:
            The actual preprocessing steps (scaling and normalization) for images are defined externally
            and passed to this class through the `image_prep` parameter during initialization. The
            tokenization process relies on the `tokenizer` also provided at initialization, which
            should be compatible with the models intended to be used with this dataset.
        """
        img_name = self.img_names[idx]
        input_img = Image.open(os.path.join(self.input_folder, img_name))
        output_img = Image.open(os.path.join(self.output_folder, img_name))
        caption = self.captions[img_name]

        # the conditioning image stays in [0, 1]
        img_t = self.T(input_img)
        img_t = F.to_tensor(img_t)

        # the target image is normalized to [-1, 1]
        output_t = self.T(output_img)
        output_t = F.to_tensor(output_t)
        output_t = F.normalize(output_t, mean=[0.5], std=[0.5])

        input_ids = self.tokenizer(
            caption, max_length=self.tokenizer.model_max_length,
            padding="max_length", truncation=True, return_tensors="pt"
        ).input_ids

        return {
            "output_pixel_values": output_t,
            "conditioning_pixel_values": img_t,
            "caption": caption,
            "input_ids": input_ids,
        }
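
# Example (a minimal sketch; assumes a tokenizer compatible with the chosen model,
# e.g. loaded via CLIPTokenizer.from_pretrained, and a dataset folder laid out as
# train_A/, train_B/, train_prompts.json; the path is a placeholder):
#   dataset = PairedDataset("data/my_paired_dataset", "train", "resized_crop_512", tokenizer)
#   sample = dataset[0]
#   sample["output_pixel_values"]  # tensor with values in [-1, 1]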


class UnpairedDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_folder, split, image_prep, tokenizer):
        """
        A dataset class for loading unpaired data samples from two distinct domains (source and target),
        typically used in unsupervised learning tasks like image-to-image translation.

        The class supports loading images from specified dataset folders, applying predefined image
        preprocessing transformations, and utilizing fixed textual prompts (captions) for each domain,
        tokenized using a provided tokenizer.

        Parameters:
        - dataset_folder (str): Base directory of the dataset containing subdirectories (train_A, train_B, test_A, test_B).
        - split (str): Indicates the dataset split to use. Expected values are 'train' or 'test'.
        - image_prep (str): The image preprocessing transformation to apply to each image.
        - tokenizer: The tokenizer used for tokenizing the captions (or prompts).
        """
        super().__init__()
        if split == "train":
            self.source_folder = os.path.join(dataset_folder, "train_A")
            self.target_folder = os.path.join(dataset_folder, "train_B")
        elif split == "test":
            self.source_folder = os.path.join(dataset_folder, "test_A")
            self.target_folder = os.path.join(dataset_folder, "test_B")
        else:
            raise ValueError(f"Unknown split: {split}; expected 'train' or 'test'.")
        self.tokenizer = tokenizer
        # each domain uses a single fixed caption, stored alongside the images
        with open(os.path.join(dataset_folder, "fixed_prompt_a.txt"), "r") as f:
            self.fixed_caption_src = f.read().strip()
        self.input_ids_src = self.tokenizer(
            self.fixed_caption_src, max_length=self.tokenizer.model_max_length,
            padding="max_length", truncation=True, return_tensors="pt"
        ).input_ids

        with open(os.path.join(dataset_folder, "fixed_prompt_b.txt"), "r") as f:
            self.fixed_caption_tgt = f.read().strip()
        self.input_ids_tgt = self.tokenizer(
            self.fixed_caption_tgt, max_length=self.tokenizer.model_max_length,
            padding="max_length", truncation=True, return_tensors="pt"
        ).input_ids

        # collect all image paths in both domains
        self.l_imgs_src = []
        for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif"]:
            self.l_imgs_src.extend(glob(os.path.join(self.source_folder, ext)))
        self.l_imgs_tgt = []
        for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif"]:
            self.l_imgs_tgt.extend(glob(os.path.join(self.target_folder, ext)))
        self.T = build_transform(image_prep)

    def __len__(self):
        """
        Returns:
            int: The combined number of source and target images, so that one epoch
                can cover every image from both domains.
        """
        return len(self.l_imgs_src) + len(self.l_imgs_tgt)

    def __getitem__(self, index):
        """
        Fetches a pair of unaligned images from the source and target domains along with their
        corresponding tokenized captions.

        For the source domain, if the requested index is within the range of available images,
        the specific image at that index is chosen. If the index exceeds the number of source
        images, a random source image is selected. For the target domain, an image is always
        randomly selected, irrespective of the index, to maintain the unpaired nature of the dataset.

        Both images are preprocessed according to the specified image transformation `T` and
        normalized. The fixed captions for both domains are included along with their tokenized forms.

        Parameters:
        - index (int): The index of the source image to retrieve.

        Returns:
        dict: A dictionary containing processed data for a single training example, with the following keys:
            - "pixel_values_src": The processed source image.
            - "pixel_values_tgt": The processed target image.
            - "caption_src": The fixed caption of the source domain.
            - "caption_tgt": The fixed caption of the target domain.
            - "input_ids_src": The source domain's fixed caption, tokenized.
            - "input_ids_tgt": The target domain's fixed caption, tokenized.
        """
        if index < len(self.l_imgs_src):
            img_path_src = self.l_imgs_src[index]
        else:
            img_path_src = random.choice(self.l_imgs_src)
        img_path_tgt = random.choice(self.l_imgs_tgt)
        img_pil_src = Image.open(img_path_src).convert("RGB")
        img_pil_tgt = Image.open(img_path_tgt).convert("RGB")
        # both domains are normalized to [-1, 1]
        img_t_src = F.to_tensor(self.T(img_pil_src))
        img_t_tgt = F.to_tensor(self.T(img_pil_tgt))
        img_t_src = F.normalize(img_t_src, mean=[0.5], std=[0.5])
        img_t_tgt = F.normalize(img_t_tgt, mean=[0.5], std=[0.5])
        return {
            "pixel_values_src": img_t_src,
            "pixel_values_tgt": img_t_tgt,
            "caption_src": self.fixed_caption_src,
            "caption_tgt": self.fixed_caption_tgt,
            "input_ids_src": self.input_ids_src,
            "input_ids_tgt": self.input_ids_tgt,
        }
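
# Example (a minimal sketch; assumes the folder layout train_A/, train_B/,
# fixed_prompt_a.txt, fixed_prompt_b.txt, and "data/horse2zebra" is a placeholder path):
#   dataset = UnpairedDataset("data/horse2zebra", "train", "resize_286_randomcrop_256x256_hflip", tokenizer)
#   sample = dataset[0]
#   sample["pixel_values_src"]  # e.g. a (3, 256, 256) tensor with values in [-1, 1]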