import copy
import glob
import os
import random

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from argparse import ArgumentParser
from diffusers import StableDiffusionPipeline
from diffusers.models.autoencoders.vae import Decoder
from omegaconf import OmegaConf
from time import time
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

from dataset import RealESRGANDataset, RealESRGANDegrader
from model import Net
from ram.models.ram_lora import ram
from utils import add_lora_to_unet

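# Distributed setup: one process per GPU, NCCL backend, env:// rendezvous
# (e.g. launched with torchrun).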
dist.init_process_group(backend="nccl", init_method="env://")
rank = dist.get_rank()
world_size = dist.get_world_size()

parser = ArgumentParser()
parser.add_argument("--epoch", type=int, default=200)
parser.add_argument("--batch_size", type=int, default=12)
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--model_dir", type=str, default="weight")
parser.add_argument("--log_dir", type=str, default="log")
parser.add_argument("--save_interval", type=int, default=10)
args = parser.parse_args()

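# Seed every library with the process rank, so random operations (notably the
# on-the-fly degradations) differ across GPUs instead of being replicated.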
seed = rank
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

config = OmegaConf.load("config.yml")

epoch = args.epoch
learning_rate = args.learning_rate
bsz = args.batch_size

device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

if rank == 0:
    print("batch size per gpu =", bsz)

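# Stable Diffusion 2.1-base provides the VAE, tokenizer, UNet and text encoder
# from which the student, the teachers and the discriminator below are derived.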
model_id = "stabilityai/stable-diffusion-2-1-base"
pipe = StableDiffusionPipeline.from_pretrained(model_id).to(device)

vae = pipe.vae
tokenizer = pipe.tokenizer
unet = pipe.unet
text_encoder = pipe.text_encoder

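# Discriminator backbone: a copy of the SD UNet whose conv_in is widened from 4 to
# 256 input channels (the original 4-channel kernel tiled 64x along the input
# dimension and divided by 64), then equipped with LoRA adapters.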
unet_D = copy.deepcopy(unet)
new_conv_in = torch.nn.Conv2d(256, 320, 3, padding=1).to(device)
new_conv_in.weight.data = unet_D.conv_in.weight.data.repeat(1, 64, 1, 1) / 64
new_conv_in.bias.data = unet_D.conv_in.bias.data
unet_D.conv_in = new_conv_in
unet_D = add_lora_to_unet(unet_D)
unet_D.set_adapters(["default_encoder", "default_decoder", "default_others"])

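# Frozen teacher pair (VAE + UNet), initialised from the pretrained OSEDiff checkpoint.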
vae_teacher = copy.deepcopy(vae)
unet_teacher = copy.deepcopy(unet)

osediff = torch.load("./weight/pretrained/osediff.pkl", weights_only=False)
vae_teacher.load_state_dict(osediff["vae"])
unet_teacher.load_state_dict(osediff["unet"])

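# "Half" VAE decoder: only its conv_in and mid_block are used below, as a fixed
# feature extractor in which the distillation and adversarial losses are computed.
# Weights come from the halfDecoder checkpoint (only the decoder.* entries are kept).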
ckpt_halfdecoder = torch.load("./weight/pretrained/halfDecoder.ckpt", weights_only=False)
decoder = Decoder(
    in_channels=4,
    out_channels=3,
    up_block_types=["UpDecoderBlock2D"] * 4,
    block_out_channels=[64, 128, 256, 256],
    layers_per_block=2,
    norm_num_groups=32,
    act_fn="silu",
    norm_type="group",
    mid_block_add_attention=True,
).to(device)

decoder_ckpt = {}
for k, v in ckpt_halfdecoder["state_dict"].items():
    if "decoder" in k:
        decoder_ckpt[k.replace("decoder.", "")] = v
decoder.load_state_dict(decoder_ckpt, strict=True)

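# RAM (with DAPE-finetuned weights) produces tag prompts from the low-res input;
# its preprocessing expects 384x384 ImageNet-normalised images.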
ram_transforms = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

DAPE = ram(
    pretrained="./weight/pretrained/ram_swin_large_14m.pth",
    pretrained_condition="./weight/pretrained/DAPE.pth",
    image_size=384,
    vit="swin_l",
).eval().to(device)

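# Freeze everything that only supplies targets or prompts; trainable parameters are
# re-enabled explicitly below.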
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder.requires_grad_(False)
vae_teacher.requires_grad_(False)
unet_teacher.requires_grad_(False)
decoder.requires_grad_(False)
DAPE.requires_grad_(False)

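# Wrap the student (Net) and the discriminator in DDP. The student is fully trainable;
# in the discriminator only the LoRA layers and the widened conv_in are optimised.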
model = DDP(Net(unet, copy.deepcopy(decoder)).to(device), device_ids=[rank])
model_D = DDP(unet_D.to(device), device_ids=[rank])
model.requires_grad_(True)
model_D.requires_grad_(False)

params_to_opt = []
for n, p in model_D.named_parameters():
    if "lora" in n or "conv_in" in n:
        p.requires_grad = True
        params_to_opt.append(p)

if rank == 0:
    param_cnt = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("#Param.", param_cnt / 1e6, "M")

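# Data pipeline, separate Adam optimisers for generator and discriminator, a step LR
# schedule for the generator, and one AMP GradScaler shared by both updates.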
dataset = RealESRGANDataset(config, bsz)
degrader = RealESRGANDegrader(config, device)
# drop_last keeps every batch at exactly bsz samples, matching the fixed-size
# `timesteps` tensor built before the training loop.
dataloader = DataLoader(dataset, batch_size=bsz, num_workers=8, drop_last=True)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
optimizer_D = torch.optim.Adam(params_to_opt, lr=1e-6)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100], gamma=0.5)
scaler = torch.cuda.amp.GradScaler()

model_dir = args.model_dir
log_path = os.path.join(args.log_dir, "log.txt")
os.makedirs(model_dir, exist_ok=True)
os.makedirs(args.log_dir, exist_ok=True)

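# One-step training: every UNet call uses the fixed timestep t = 999; `alpha` is the
# corresponding cumulative alpha-bar, used to turn the teacher's noise prediction into
# a clean-latent estimate.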
if rank == 0:
    print("start training...")
timesteps = torch.tensor([999], device=device).long().expand(bsz)
alpha = pipe.scheduler.alphas_cumprod[999]

for epoch_i in range(1, epoch + 1):
    start_time = time()
    loss_avg = 0.0
    loss_distil_avg = 0.0
    loss_adv_avg = 0.0
    loss_D_avg = 0.0
    iter_num = 0
    dist.barrier()

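    # Generator step: build the teacher target and the GT reference in the half-decoder
    # feature space (post_quant_conv -> conv_in -> mid_block), then update the student
    # with an L1 distillation loss plus a non-saturating adversarial loss.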
    for batch in tqdm(dataloader, disable=rank != 0):
        with torch.cuda.amp.autocast(enabled=True):
            with torch.no_grad():
                LR, HR = degrader.degrade(batch)
                text_input = tokenizer(
                    DAPE.generate_tag(ram_transforms(LR))[0],
                    max_length=tokenizer.model_max_length,
                    padding="max_length",
                    return_tensors="pt",
                ).to(device)
                encoder_hidden_states = text_encoder(text_input.input_ids, return_dict=False)[0]
                LR, HR = LR * 2 - 1, HR * 2 - 1
                LR_ = F.interpolate(LR, scale_factor=4, mode="bicubic")
                LR_latents = vae_teacher.encode(LR_).latent_dist.mean * vae_teacher.config.scaling_factor
                HR_latents = vae.encode(HR).latent_dist.mean
                pred_teacher = unet_teacher(
                    LR_latents,
                    timesteps,
                    encoder_hidden_states=encoder_hidden_states,
                    return_dict=False,
                )[0]
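                # Recover the teacher's clean-latent estimate via
                # z0 = (z_t - sqrt(1 - alpha_bar) * eps_pred) / sqrt(alpha_bar), undo the
                # latent scaling, and map both it and the GT latent through the frozen half
                # decoder (post_quant_conv -> conv_in -> mid_block) to get feature-space targets.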
                z0_teacher = (LR_latents - ((1 - alpha) ** 0.5) * pred_teacher) / (alpha ** 0.5)
                z0_teacher = vae_teacher.post_quant_conv(z0_teacher / vae_teacher.config.scaling_factor)
                z0_teacher = decoder.conv_in(z0_teacher)
                z0_teacher = decoder.mid_block(z0_teacher)

                z0_gt = vae.post_quant_conv(HR_latents)
                z0_gt = decoder.conv_in(z0_gt)
                z0_gt = decoder.mid_block(z0_gt)

            z0_student = model(LR)
            loss_distil = (z0_student - z0_teacher).abs().mean()
            loss_adv = F.softplus(-model_D(
                z0_student,
                timesteps,
                encoder_hidden_states=encoder_hidden_states,
                return_dict=False,
            )[0]).mean()
            loss = loss_distil + loss_adv

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

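        # Discriminator step: non-saturating logistic (softplus) loss with the GT features
        # as "real" and the detached student features as "fake".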
        with torch.cuda.amp.autocast(enabled=True):
            pred_real = model_D(
                z0_gt.detach(),
                timesteps,
                encoder_hidden_states=encoder_hidden_states,
                return_dict=False,
            )[0]
            pred_fake = model_D(
                z0_student.detach(),
                timesteps,
                encoder_hidden_states=encoder_hidden_states,
                return_dict=False,
            )[0]
            loss_D = F.softplus(pred_fake).mean() + F.softplus(-pred_real).mean()

        optimizer_D.zero_grad(set_to_none=True)
        scaler.scale(loss_D).backward()
        scaler.step(optimizer_D)
        scaler.update()

        loss_avg += loss.item()
        loss_distil_avg += loss_distil.item()
        loss_adv_avg += loss_adv.item()
        loss_D_avg += loss_D.item()
        iter_num += 1

    scheduler.step()
    loss_avg /= iter_num
    loss_distil_avg /= iter_num
    loss_adv_avg /= iter_num
    loss_D_avg /= iter_num

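    # Rank-0 logging and periodic checkpointing (the saved state_dict is the DDP
    # wrapper's, i.e. keys carry the "module." prefix).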
    log_data = (
        "[%d/%d] Average loss: %f, distil loss: %f, adv loss: %f, D loss: %f, "
        "time cost: %.2fs, cur lr is %f."
        % (epoch_i, epoch, loss_avg, loss_distil_avg, loss_adv_avg, loss_D_avg,
           time() - start_time, scheduler.get_last_lr()[0])
    )
    if rank == 0:
        print(log_data)
        with open(log_path, "a") as log_file:
            log_file.write(log_data + "\n")
        if epoch_i % args.save_interval == 0:
            torch.save(model.state_dict(), os.path.join(model_dir, "net_params_%d.pkl" % epoch_i))