# Diffusion-DDIM / train.py
# Yash Nagraj
# Sampled image using DDIM in half the number of timesteps
# 00315bf
import os
from typing import Dict
import torch
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.utils import save_image
from diffusion import GaussianDiffusionSampler, GaussianDiffusionTrainer, DDIMSampler
from model import UNet
from scheduler import GradualWarmupScheduler
def train(modelConfig: Dict):
    """Train the diffusion UNet on CIFAR-10, saving a checkpoint after every epoch.

    Args:
        modelConfig: Configuration dictionary. Keys read here: ``device``,
            ``batch_size``, ``T``, ``channel``, ``ch_mult``, ``attn``,
            ``num_res_blocks``, ``dropout``, ``lr``, ``epochs``, ``multiplier``,
            ``beta_1``, ``beta_T``, ``grad_clip`` and ``checkpoint_dir``.

    Side effects:
        Downloads CIFAR-10 into the current directory if it is not already
        there. After each epoch, writes the whole pickled model (via
        ``torch.save``) to ``<checkpoint_dir>/ckpt_<epoch>.pth``, creating
        ``checkpoint_dir`` first if it does not exist.
    """
    device = torch.device(modelConfig['device'])
    checkpoint_dir = modelConfig['checkpoint_dir']
    # Create the output directory up front so a missing folder is handled
    # before training starts instead of crashing after the first epoch.
    # An empty string means "current directory", which needs no creation.
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    dataset = CIFAR10(
        "./",
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # (x - 0.5) / 0.5 maps ToTensor's [0, 1] range to [-1, 1].
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]),
    )
    dataloader = DataLoader(
        dataset,
        batch_size=modelConfig['batch_size'],
        shuffle=True,
        num_workers=4,
        drop_last=True,
        pin_memory=True,
    )

    net_model = UNet(
        modelConfig['T'], modelConfig['channel'], modelConfig['ch_mult'],
        modelConfig['attn'], modelConfig['num_res_blocks'], modelConfig['dropout'],
    )
    optimizer = optim.AdamW(net_model.parameters(), modelConfig['lr'], weight_decay=1e-4)
    cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, modelConfig['epochs'], eta_min=0, last_epoch=-1
    )
    # Presumably warms the LR up over the first epochs // 10 epochs and then
    # defers to cosineScheduler -- see scheduler.GradualWarmupScheduler to confirm.
    warmupScheduler = GradualWarmupScheduler(
        optimizer, modelConfig['multiplier'], modelConfig['epochs'] // 10,
        cosineScheduler,
    )
    # NOTE(review): assumes GaussianDiffusionTrainer registers net_model as a
    # submodule, so .to(device) also moves the UNet's parameters.
    trainer = GaussianDiffusionTrainer(
        modelConfig['beta_1'],
        modelConfig['beta_T'],
        modelConfig['T'],
        net_model,
    ).to(device)

    for epoch in range(modelConfig['epochs']):
        with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
            for images, _ in tqdmDataLoader:
                optimizer.zero_grad()
                x_0 = images.to(device)
                # Reduce the trainer's output to a scalar by summing, then
                # scale by a fixed factor of 1/1000.
                loss = trainer(x_0).sum() / 1000
                loss.backward()
                # Clip the total gradient norm to keep updates bounded.
                torch.nn.utils.clip_grad_norm_(
                    net_model.parameters(), modelConfig['grad_clip']
                )
                optimizer.step()
                tqdmDataLoader.set_postfix(ordered_dict={
                    "epoch": epoch,
                    "loss: ": loss.item(),
                    "img shape: ": x_0.shape,
                    # Read the live param group directly. state_dict() would
                    # rebuild the whole optimizer-state dict on every step.
                    "LR": optimizer.param_groups[0]["lr"],
                })
        # Both schedulers are sized in epochs, so step once per epoch.
        warmupScheduler.step()
        # Join the directory and file name as separate components. The old
        # string concatenation dropped the separator when checkpoint_dir had
        # no trailing slash.
        torch.save(net_model, os.path.join(checkpoint_dir, f"ckpt_{epoch}.pth"))
def eval(modelConfig: Dict):
    """Sample a batch of images with GaussianDiffusionSampler and save them as a grid.

    The function name shadows the built-in ``eval`` inside this module; it is
    kept as-is so existing callers keep working.

    Args:
        modelConfig: Configuration dictionary. Keys read here: ``device``,
            ``checkpoint_dir``, ``test_load_weight``, ``beta_1``, ``beta_T``,
            ``T``, ``batch_size``, ``sample_dir``, ``sampledImgName`` and
            ``nrow``.

    Side effects:
        Prints "Model loaded". Writes the image grid to
        ``<sample_dir>/<sampledImgName>``, creating ``sample_dir`` if needed.
    """
    with torch.no_grad():
        device = torch.device(modelConfig['device'])
        # NOTE(review): train() saves the whole pickled module. Newer PyTorch
        # releases default torch.load to weights_only=True, which refuses to
        # unpickle arbitrary modules. If loading fails, pass weights_only=False
        # (only for trusted checkpoints) -- confirm against the installed torch.
        model = torch.load(
            os.path.join(modelConfig['checkpoint_dir'], modelConfig['test_load_weight']),
            map_location=device,
        )
        print("Model loaded")
        model.eval()

        # Make sure the output directory exists before the (slow) sampling
        # run, so save_image cannot fail at the very end.
        sample_dir = modelConfig['sample_dir']
        if sample_dir:
            os.makedirs(sample_dir, exist_ok=True)

        sampler = GaussianDiffusionSampler(
            modelConfig['beta_1'], modelConfig['beta_T'],
            model, modelConfig['T'],
        )
        # Start from pure Gaussian noise shaped like a batch of 3x32x32
        # images (the CIFAR-10 resolution used in train()).
        noisyImage = torch.randn(
            size=[modelConfig['batch_size'], 3, 32, 32],
            device=device,
        )
        sampledImgs = sampler(noisyImage)
        # Assumes the sampler returns values in the training range [-1, 1].
        # Map them back to [0, 1] for saving.
        sampledImgs = sampledImgs * 0.5 + 0.5
        save_image(
            sampledImgs,
            os.path.join(sample_dir, modelConfig['sampledImgName']),
            nrow=modelConfig['nrow'],
        )
def eval_ddim(modelConfig: Dict):
    """Sample a batch of images with DDIMSampler and save them as a grid.

    Args:
        modelConfig: Configuration dictionary. Keys read here: ``device``,
            ``checkpoint_dir``, ``test_load_weight``, ``beta_1``, ``beta_T``,
            ``T``, ``batch_size``, ``sample_dir``, ``sampledImgName`` and
            ``nrow``.

    Side effects:
        Prints "Model loaded". Writes the image grid to
        ``<sample_dir>/<sampledImgName>``, creating ``sample_dir`` if needed.
        This is the same config key used by ``eval``, so running both with
        one config overwrites the earlier image.
    """
    with torch.no_grad():
        device = torch.device(modelConfig['device'])
        # NOTE(review): train() saves the whole pickled module. Newer PyTorch
        # releases default torch.load to weights_only=True, which refuses to
        # unpickle arbitrary modules. If loading fails, pass weights_only=False
        # (only for trusted checkpoints) -- confirm against the installed torch.
        model = torch.load(
            os.path.join(modelConfig['checkpoint_dir'], modelConfig['test_load_weight']),
            map_location=device,
        )
        print("Model loaded")
        model.eval()

        # Make sure the output directory exists before the (slow) sampling
        # run, so save_image cannot fail at the very end.
        sample_dir = modelConfig['sample_dir']
        if sample_dir:
            os.makedirs(sample_dir, exist_ok=True)

        sampler = DDIMSampler(
            modelConfig['beta_1'], modelConfig['beta_T'],
            model, modelConfig['T'],
        )
        # Start from pure Gaussian noise shaped like a batch of 3x32x32
        # images (the CIFAR-10 resolution used in train()).
        noisyImage = torch.randn(
            size=[modelConfig['batch_size'], 3, 32, 32],
            device=device,
        )
        sampledImgs = sampler(noisyImage)
        # Assumes the sampler returns values in the training range [-1, 1].
        # Map them back to [0, 1] for saving.
        sampledImgs = sampledImgs * 0.5 + 0.5
        save_image(
            sampledImgs,
            os.path.join(sample_dir, modelConfig['sampledImgName']),
            nrow=modelConfig['nrow'],
        )