|
import os |
|
from argparse import ArgumentParser |
|
|
|
from omegaconf import OmegaConf |
|
import torch |
|
from torch.utils.data import DataLoader |
|
from torchvision.utils import make_grid |
|
from accelerate import Accelerator |
|
from accelerate.utils import set_seed |
|
from einops import rearrange |
|
from tqdm import tqdm |
|
from torch.utils.tensorboard import SummaryWriter |
|
from PIL import Image, ImageDraw, ImageFont |
|
import numpy as np |
|
|
|
from model import ControlLDM, SwinIR, Diffusion |
|
from utils.common import instantiate_from_config |
|
from utils.sampler import SpacedSampler |
|
|
|
|
|
def log_txt_as_img(wh, xc):
    """Render each string in ``xc`` onto its own white canvas for image logging.

    Args:
        wh: Canvas size as ``(width, height)`` in pixels (passed straight to
            ``PIL.Image.new``).
        xc: Sequence of strings, one per output image. Long strings are
            hard-wrapped every ``nc`` characters.

    Returns:
        A float64 ``torch.Tensor`` of shape ``(len(xc), 3, height, width)`` with
        values in ``[-1, 1]`` (black text on a white background). An empty
        ``xc`` yields an empty batch with the same layout instead of raising.
    """
    width, height = wh
    if len(xc) == 0:
        # np.stack([]) raises; return an empty batch with the usual layout/dtype.
        return torch.zeros((0, 3, height, width), dtype=torch.float64)

    # The font does not depend on the caption, so load it once for the batch.
    font = ImageFont.load_default()
    # Characters per wrapped line: 40 at 256 px width, scaled linearly. Clamp to
    # at least 1 so the range() step below can never be zero on narrow canvases.
    nc = max(1, int(40 * (width / 256)))

    txts = []
    for caption in xc:
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc))
        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            # Best effort: keep the blank canvas rather than abort training.
            print("Cant encode string for logging. Skipping.")
        # HWC uint8 -> CHW float in [-1, 1].
        txts.append(np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0)

    return torch.tensor(np.stack(txts))
|
|
|
|
|
def main(args) -> None:
    """Train the ControlNet branch of a ControlLDM, conditioned on SwinIR output.

    Reads the OmegaConf file at ``args.config`` and uses:
      * ``cfg.model.cldm`` / ``cfg.model.swinir`` / ``cfg.model.diffusion`` to build
        the models;
      * ``cfg.train.sd_path`` (pretrained SD weights, loaded via ``load_pretrained_sd``),
        ``cfg.train.resume`` (optional ControlNet checkpoint) and
        ``cfg.train.swinir_path`` (frozen SwinIR weights);
      * ``cfg.dataset.train`` plus ``cfg.train.{batch_size, num_workers,
        learning_rate, train_steps, log_every, ckpt_every, image_every, exp_dir}``.

    Only ``cldm.controlnet`` parameters are optimized. Its ``state_dict`` is
    written to ``<exp_dir>/checkpoints/<step:07d>.pt`` every ``ckpt_every`` steps.
    Scalars and sample images go to a TensorBoard writer in ``exp_dir``.

    Raises:
        ValueError: if the training dataloader yields no batches, which would
            otherwise make the step loop spin forever.
    """
    # ---- Accelerator / config ----
    accelerator = Accelerator(split_batches=True)
    set_seed(231)
    device = accelerator.device
    cfg = OmegaConf.load(args.config)

    # ---- Experiment folder (paths known everywhere, created on the main process) ----
    # NOTE(review): is_local_main_process is true once per node; on a multi-node
    # run with a shared filesystem this duplicates logs/checkpoints — confirm intent.
    exp_dir = cfg.train.exp_dir
    ckpt_dir = os.path.join(exp_dir, "checkpoints")
    if accelerator.is_local_main_process:
        os.makedirs(exp_dir, exist_ok=True)
        os.makedirs(ckpt_dir, exist_ok=True)
        print(f"Experiment directory created at {exp_dir}")

    # ---- ControlLDM: pretrained SD backbone + ControlNet ----
    # torch.load unpickles arbitrary objects; only point these paths at trusted files.
    cldm: ControlLDM = instantiate_from_config(cfg.model.cldm)
    sd = torch.load(cfg.train.sd_path, map_location="cpu")["state_dict"]
    unused = cldm.load_pretrained_sd(sd)
    del sd  # release the (large) SD state dict before loading anything else
    if accelerator.is_local_main_process:
        print(f"strictly load pretrained SD weight from {cfg.train.sd_path}\n"
              f"unused weights: {unused}")

    if cfg.train.resume:
        cldm.load_controlnet_from_ckpt(torch.load(cfg.train.resume, map_location="cpu"))
        if accelerator.is_local_main_process:
            print(f"strictly load controlnet weight from checkpoint: {cfg.train.resume}")
    else:
        init_with_new_zero, init_with_scratch = cldm.load_controlnet_from_unet()
        if accelerator.is_local_main_process:
            print(f"strictly load controlnet weight from pretrained SD\n"
                  f"weights initialized with newly added zeros: {init_with_new_zero}\n"
                  f"weights initialized from scratch: {init_with_scratch}")

    # ---- Frozen SwinIR restorer (produces the ControlNet condition image) ----
    swinir: SwinIR = instantiate_from_config(cfg.model.swinir)
    # Strip a DataParallel/DDP "module." prefix so the keys match the bare model.
    swinir_sd = {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in torch.load(cfg.train.swinir_path, map_location="cpu").items()
    }
    swinir.load_state_dict(swinir_sd, strict=True)
    del swinir_sd
    for p in swinir.parameters():
        p.requires_grad = False
    if accelerator.is_local_main_process:
        print(f"load SwinIR from {cfg.train.swinir_path}")

    diffusion: Diffusion = instantiate_from_config(cfg.model.diffusion)

    # ---- Optimizer: only the ControlNet branch is trained ----
    opt = torch.optim.AdamW(cldm.controlnet.parameters(), lr=cfg.train.learning_rate)

    # ---- Data ----
    dataset = instantiate_from_config(cfg.dataset.train)
    loader = DataLoader(
        dataset=dataset, batch_size=cfg.train.batch_size,
        num_workers=cfg.train.num_workers,
        shuffle=True, drop_last=True,
    )
    if accelerator.is_local_main_process:
        print(f"Dataset contains {len(dataset):,} images from {dataset.file_list}")

    # ---- Device placement / distributed wrapping ----
    cldm.train().to(device)
    swinir.eval().to(device)
    diffusion.to(device)
    cldm, opt, loader = accelerator.prepare(cldm, opt, loader)
    pure_cldm: ControlLDM = accelerator.unwrap_model(cldm)

    # With drop_last=True a dataset smaller than one batch gives zero batches;
    # global_step would then never advance and the loop below would hang.
    if len(loader) == 0:
        raise ValueError(
            f"training dataloader yields no batches (dataset size {len(dataset)}, "
            f"batch_size {cfg.train.batch_size}, drop_last=True)"
        )

    # ---- Training loop ----
    global_step = 0
    max_steps = cfg.train.train_steps
    step_loss = []
    epoch = 0
    epoch_loss = []
    sampler = SpacedSampler(diffusion.betas)
    writer = None
    if accelerator.is_local_main_process:
        writer = SummaryWriter(exp_dir)
        print(f"Training for {max_steps} steps...")

    while global_step < max_steps:
        pbar = tqdm(iterable=None, disable=not accelerator.is_local_main_process, unit="batch", total=len(loader))
        for gt, lq, prompt in loader:
            # Batches arrive channels-last ("b h w c"); models expect "b c h w".
            gt = rearrange(gt, "b h w c -> b c h w").contiguous().float().to(device)
            lq = rearrange(lq, "b h w c -> b c h w").contiguous().float().to(device)
            with torch.no_grad():
                z_0 = pure_cldm.vae_encode(gt)
                clean = swinir(lq)
                cond = pure_cldm.prepare_condition(clean, prompt)
            t = torch.randint(0, diffusion.num_timesteps, (z_0.shape[0],), device=device)

            loss = diffusion.p_losses(cldm, z_0, t, cond)
            opt.zero_grad()
            accelerator.backward(loss)
            opt.step()

            accelerator.wait_for_everyone()

            global_step += 1
            # One host sync per step instead of one per use.
            loss_value = loss.item()
            step_loss.append(loss_value)
            epoch_loss.append(loss_value)
            pbar.update(1)
            pbar.set_description(f"Epoch: {epoch:04d}, Global Step: {global_step:07d}, Loss: {loss_value:.6f}")

            # Scalar logging. gather() is collective, so every process must call it.
            if global_step % cfg.train.log_every == 0 and global_step > 0:
                avg_loss = accelerator.gather(torch.tensor(step_loss, device=device).unsqueeze(0)).mean().item()
                step_loss.clear()
                if accelerator.is_local_main_process:
                    writer.add_scalar("loss/loss_simple_step", avg_loss, global_step)

            # Checkpoint: only the ControlNet weights are saved.
            if global_step % cfg.train.ckpt_every == 0 and global_step > 0:
                if accelerator.is_local_main_process:
                    checkpoint = pure_cldm.controlnet.state_dict()
                    ckpt_path = f"{ckpt_dir}/{global_step:07d}.pt"
                    torch.save(checkpoint, ckpt_path)

            # Sample images (every process runs the sampler; the main one logs).
            if global_step % cfg.train.image_every == 0 or global_step == 1:
                N = 12
                log_clean = clean[:N]
                log_cond = {k: v[:N] for k, v in cond.items()}
                log_gt, log_lq = gt[:N], lq[:N]
                log_prompt = prompt[:N]
                cldm.eval()
                with torch.no_grad():
                    z = sampler.sample(
                        model=cldm, device=device, steps=50, batch_size=len(log_gt), x_size=z_0.shape[1:],
                        cond=log_cond, uncond=None, cfg_scale=1.0, x_T=None,
                        progress=accelerator.is_local_main_process, progress_leave=False,
                    )
                    if accelerator.is_local_main_process:
                        # Presumably gt/decoded latents are in [-1, 1] and lq/clean in
                        # [0, 1], hence the differing rescales — confirm against dataset.
                        for tag, image in [
                            ("image/samples", (pure_cldm.vae_decode(z) + 1) / 2),
                            ("image/gt", (log_gt + 1) / 2),
                            ("image/lq", log_lq),
                            ("image/condition", log_clean),
                            ("image/condition_decoded", (pure_cldm.vae_decode(log_cond["c_img"]) + 1) / 2),
                            ("image/prompt", (log_txt_as_img((512, 512), log_prompt) + 1) / 2),
                        ]:
                            writer.add_image(tag, make_grid(image, nrow=4), global_step)
                cldm.train()
            accelerator.wait_for_everyone()
            if global_step == max_steps:
                break

        pbar.close()
        epoch += 1
        avg_epoch_loss = accelerator.gather(torch.tensor(epoch_loss, device=device).unsqueeze(0)).mean().item()
        epoch_loss.clear()
        if accelerator.is_local_main_process:
            writer.add_scalar("loss/loss_simple_epoch", avg_epoch_loss, global_step)

    if accelerator.is_local_main_process:
        print("done!")
        writer.close()
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: the only input is the path to the OmegaConf file.
    cli = ArgumentParser()
    cli.add_argument("--config", type=str, required=True)
    main(cli.parse_args())
|
|