# --------------------------------------------------------
# InstructDiffusion
# Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix)
# Removed Pytorch-lightning and supported deepspeed by Zigang Geng (zigang@mail.ustc.edu.cn)
# --------------------------------------------------------

import argparse, os, sys, datetime, glob
import numpy as np
import time
import json
import pickle
import wandb
import deepspeed

from packaging import version
from omegaconf import OmegaConf
from functools import partial
from PIL import Image
from timm.utils import AverageMeter

import torch
import torchvision
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader, Dataset, ConcatDataset

sys.path.append("./stable_diffusion")

from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config
from ldm.modules.ema import LitEma

from utils.logger import create_logger
from utils.utils import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper
from utils.deepspeed import create_ds_config
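
# NOTE: `worker_init_fn` is referenced by the dataloader builders below but is not
# defined in this excerpt. The per-worker seeding below is a minimal placeholder
# sketch (an assumption, not the original implementation); replace it with the
# project's own worker_init_fn if one exists elsewhere.
def worker_init_fn(worker_id):
    # give each dataloader worker a distinct, reproducible numpy seed
    np.random.seed((torch.initial_seed() + worker_id) % 2**32)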


def wandb_log(*args, **kwargs):
    if dist.get_rank() == 0:
        wandb.log(*args, **kwargs)


def get_parser(**parser_kwargs):
    def str2bool(v):
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        elif v.lower() in ("no", "false", "f", "n", "0"):
            return False
        else:
            raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument(
        "-n",
        "--name",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="postfix for logdir",
    )
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="resume from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
             "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=list(),
    )
    parser.add_argument(
        "-t",
        "--train",
        type=str2bool,
        const=True,
        default=False,
        nargs="?",
        help="train",
    )
    parser.add_argument(
        "--no-test",
        type=str2bool,
        const=True,
        default=False,
        nargs="?",
        help="disable test",
    )
    parser.add_argument(
        "-p",
        "--project",
        help="name of new or path to existing project"
    )
    parser.add_argument(
        "-d",
        "--debug",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="enable post-mortem debugging",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=23,
        help="seed for seed_everything",
    )
    parser.add_argument(
        "-f",
        "--postfix",
        type=str,
        default="",
        help="post-postfix for default name",
    )
    parser.add_argument(
        "-l",
        "--logdir",
        type=str,
        default="logs",
        help="directory for logs and checkpoints",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="scale base-lr by ngpu * batch_size * n_accumulate",
    )
    parser.add_argument(
        "--amd",
        action="store_true",
        default=False,
        help="run on AMD devices (pins CUDA_VISIBLE_DEVICES and uses the gloo backend)",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        # required=False,
        default=int(os.environ.get('LOCAL_RANK', 0)),
        help="local rank for DistributedDataParallel",
    )
    return parser
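
# Typical multi-GPU launch (illustrative only; the script name and config path are
# assumptions, not taken from this repository):
#   torchrun --nproc_per_node=8 main.py --name my_run \
#       --base configs/instruct_diffusion.yaml --logdir logs
# RANK, WORLD_SIZE and LOCAL_RANK are read from the environment set by the launcher.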


class WrappedDataset(Dataset):
    """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""

    def __init__(self, dataset):
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class DataModuleFromConfig():
    def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
                 wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
                 shuffle_val_dataloader=False):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        if train is not None:
            if "target" in train:
                # a single dataset config
                self.dataset_configs["train"] = train
                self.train_dataloader = self._train_dataloader
            else:
                # a list of named dataset configs that will be concatenated
                for ds in train:
                    ds_name = str([key for key in ds.keys()][0])
                    self.dataset_configs[ds_name] = ds
                self.train_dataloader = self._train_concat_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader
        self.wrap = wrap

    def prepare_data(self):
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        self.datasets = dict(
            (k, instantiate_from_config(self.dataset_configs[k]))
            for k in self.dataset_configs)
        if self.wrap:
            for k in self.datasets:
                self.datasets[k] = WrappedDataset(self.datasets[k])

    def _train_concat_dataloader(self):
        # assumes the first concatenated dataset config is named 'ds1'
        is_iterable_dataset = isinstance(self.datasets['ds1'], Txt2ImgIterableBaseDataset)
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None

        concat_dataset = []
        for ds in self.datasets.keys():
            concat_dataset.append(self.datasets[ds])
        concat_dataset = ConcatDataset(concat_dataset)

        sampler_train = torch.utils.data.DistributedSampler(
            concat_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True
        )
        return DataLoader(concat_dataset, batch_size=self.batch_size, sampler=sampler_train,
                          num_workers=self.num_workers, worker_init_fn=init_fn, persistent_workers=True)

    def _train_dataloader(self):
        is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        sampler_train = torch.utils.data.DistributedSampler(
            self.datasets["train"], num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True
        )
        return DataLoader(self.datasets["train"], batch_size=self.batch_size, sampler=sampler_train,
                          num_workers=self.num_workers, worker_init_fn=init_fn, persistent_workers=True)

    def _val_dataloader(self, shuffle=False):
        if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(self.datasets["validation"],
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          worker_init_fn=init_fn,
                          shuffle=shuffle, persistent_workers=True)

    def _test_dataloader(self, shuffle=False):
        is_iterable_dataset = isinstance(self.datasets['test'], Txt2ImgIterableBaseDataset)
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None

        # do not shuffle dataloader for iterable dataset
        shuffle = shuffle and (not is_iterable_dataset)
        return DataLoader(self.datasets["test"], batch_size=self.batch_size,
                          num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle, persistent_workers=True)

    def _predict_dataloader(self, shuffle=False):
        if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(self.datasets["predict"], batch_size=self.batch_size,
                          num_workers=self.num_workers, worker_init_fn=init_fn, persistent_workers=True)
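
# For reference: `config.data` is built with instantiate_from_config, so it follows the
# usual "target"/"params" layout from ldm. A hedged sketch of the shape this module
# accepts (dataset names and targets below are illustrative, not from a shipped config):
#
#   data:
#     target: main.DataModuleFromConfig
#     params:
#       batch_size: 2
#       train:            # either a single {target, params} dataset config ...
#         - ds1:          # ... or a list of single-key entries that get concatenated;
#             target: ... #     _train_concat_dataloader expects one entry named 'ds1'
#             params: {}
#       validation:
#         target: ...
#         params: {}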


def train_one_epoch(config, model, model_ema, data_loader, val_data_loader, optimizer, epoch, lr_scheduler, scaler):
    model.train()
    optimizer.zero_grad()

    num_steps = len(data_loader)
    accumul_steps = config.trainer.accumulate_grad_batches

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    val_loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    loss_scale_meter = AverageMeter()
    loss_scale_meter_min = AverageMeter()

    start = time.time()
    end = time.time()
    for idx, batch in enumerate(data_loader):
        batch_size = batch['edited'].shape[0]

        if config.model.params.deepspeed != '':
            # DeepSpeed handles loss scaling, gradient accumulation and the optimizer step internally.
            loss, _ = model(batch, idx, accumul_steps)
            model.backward(loss)
            model.step()
            loss_scale = optimizer.cur_scale
            grad_norm = model.get_global_grad_norm()

            with torch.no_grad():
                if idx % config.trainer.accumulate_grad_batches == 0:
                    model_ema(model)
            loss_number = loss.item()
        else:
            # Manual AMP path; assumes fp16 is enabled so that `scaler` is a GradScaler.
            with amp.autocast(enabled=config.model.params.fp16):
                loss, _ = model(batch, idx, accumul_steps)

            if config.trainer.accumulate_grad_batches > 1:
                loss = loss / config.trainer.accumulate_grad_batches
                scaler.scale(loss).backward()
                # loss.backward()
                if config.trainer.clip_grad > 0.0:
                    scaler.unscale_(optimizer)
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.trainer.clip_grad)
                else:
                    grad_norm = get_grad_norm(model.parameters())
                if (idx + 1) % config.trainer.accumulate_grad_batches == 0:
                    scaler.step(optimizer)
                    optimizer.zero_grad()
                    scaler.update()
                    # scaler.unscale_grads()
                    # optimizer.step()
                    # optimizer.zero_grad()
                    # lr_scheduler.step_update(epoch * num_steps + idx)
            else:
                optimizer.zero_grad()
                scaler.scale(loss).backward()
                if config.trainer.clip_grad > 0.0:
                    scaler.unscale_(optimizer)
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.trainer.clip_grad)
                else:
                    grad_norm = get_grad_norm(model.parameters())
                scaler.step(optimizer)
                scaler.update()
                # lr_scheduler.step_update(epoch * num_steps + idx)

            loss_scale = scaler.get_scale()
            loss_number = loss.item() * config.trainer.accumulate_grad_batches

        torch.cuda.synchronize()

        loss_meter.update(loss_number, batch_size)
        if grad_norm is not None:
            norm_meter.update(grad_norm)
        else:
            norm_meter.update(0.0)
        loss_scale_meter.update(loss_scale)
        # loss_scale_meter.update(0.0)
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % 100 == 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'loss_scale {loss_scale_meter.val:.4f} ({loss_scale_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')

        if (epoch * num_steps + idx) % 100 == 0:
            log_message = dict(
                lr=optimizer.param_groups[0]['lr'],
                time=batch_time.val,
                epoch=epoch,
                iter=idx,
                loss=loss_meter.val,
                grad_norm=norm_meter.val,
                loss_scale=loss_scale_meter.val,
                memory=torch.cuda.max_memory_allocated() / (1024.0 * 1024.0),
                global_iter=epoch * num_steps + idx)
            # log_message.update({'ref_img': wandb.Image(unnormalize(img[:8].cpu().float())), 'mask': wandb.Image(mask[:8].cpu().float().unsqueeze(1))})
            # if x_rec is not None:
            #     log_message.update({'rec_img': wandb.Image(unnormalize(x_rec[:8].cpu().float()))})
            wandb_log(
                data=log_message,
                step=epoch * num_steps + idx,
            )

        if idx == num_steps - 1:
            # at the end of the epoch, validate with the EMA weights, then restore the raw weights
            with torch.no_grad():
                model_ema.store(model.parameters())
                model_ema.copy_to(model)
                for val_idx, batch in enumerate(val_data_loader):
                    batch_size = batch['edited'].shape[0]
                    loss, _ = model(batch, -1, 1)
                    loss_number = loss.item()
                    val_loss_meter.update(loss_number, batch_size)
                    if val_idx % 10 == 0:
                        logger.info(
                            f'Val: [{val_idx}/{len(val_data_loader)}]\t'
                            f'loss {val_loss_meter.val:.4f} ({val_loss_meter.avg:.4f})\t')
                    if val_idx == 50:
                        break
                model_ema.restore(model.parameters())

    epoch_time = time.time() - start
    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")


if __name__ == "__main__":
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    # add cwd for convenience and to make classes in this file available when
    # running as `python main.py`
    # (in particular `main.DataModuleFromConfig`)
    sys.path.append(os.getcwd())

    parser = get_parser()
    opt, unknown = parser.parse_known_args()
    assert opt.name
    cfg_fname = os.path.split(opt.base[0])[-1]
    cfg_name = os.path.splitext(cfg_fname)[0]
    nowname = f"{cfg_name}_{opt.name}"
    logdir = os.path.join(opt.logdir, nowname)

    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ['WORLD_SIZE'])
        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
    else:
        rank = -1
        world_size = -1

    if opt.amd:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.local_rank)
        torch.distributed.init_process_group(backend='gloo', init_method='env://', world_size=world_size, rank=rank)
    else:
        torch.cuda.set_device(opt.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
    torch.distributed.barrier()

    seed = opt.seed + dist.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    os.makedirs(logdir, exist_ok=True)
    os.makedirs(ckptdir, exist_ok=True)
    os.makedirs(cfgdir, exist_ok=True)

    # init and save configs
    # config: the configs in the config file
    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)

    if config.model.params.deepspeed != '':
        create_ds_config(opt, config, cfgdir)

    if dist.get_rank() == 0:
        run = wandb.init(
            id=nowname,
            name=nowname,
            project='readoutpose',
            config=OmegaConf.to_container(config, resolve=True),
        )

    logger = create_logger(output_dir=logdir, dist_rank=dist.get_rank(), name=f"{nowname}")

    resume_file = auto_resume_helper(config, ckptdir)
    if resume_file:
        resume = True
        logger.info(f'resume checkpoint in {resume_file}')
    else:
        resume = False
        logger.info(f'no checkpoint found in {ckptdir}, ignoring auto resume')

    # model
    model = instantiate_from_config(config.model)
    model_ema = LitEma(model, decay_resume=config.model.params.get('ema_resume', 0.9999))

    # data
    data = instantiate_from_config(config.data)
    # NOTE: this follows the pytorch-lightning DataModule interface
    # (https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html);
    # since Lightning has been removed here, prepare_data/setup are called explicitly.
    data.prepare_data()
    data.setup()
    data_loader_train = data.train_dataloader()
    data_loader_val = data.val_dataloader()
    print("#### Data ####")
    for k in data.datasets:
        print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")

    # configure learning rate
    bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
    ngpu = dist.get_world_size()
    if 'accumulate_grad_batches' in config.trainer:
        accumulate_grad_batches = config.trainer.accumulate_grad_batches
    else:
        accumulate_grad_batches = 1
    print(f"accumulate_grad_batches = {accumulate_grad_batches}")
    if opt.scale_lr:
        model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
        print(
            "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
                model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
    else:
        model.learning_rate = base_lr
        print("++++ NOT USING LR SCALING ++++")
        print(f"Setting learning rate to {model.learning_rate:.2e}")

    if not opt.amd:
        model.cuda()

    # Without DeepSpeed, fp16 training relies on a GradScaler; otherwise `scaler` stays None
    # (the manual AMP path in train_one_epoch expects fp16 to be enabled in that case).
    if config.model.params.fp16 and config.model.params.deepspeed == '':
        scaler = amp.GradScaler()
        param_groups = model.parameters()
    else:
        scaler = None
        param_groups = model.parameters()

    if config.model.params.deepspeed != '':
        model, optimizer, _, _ = deepspeed.initialize(
            args=config,
            model=model,
            model_parameters=param_groups,
            dist_init_required=False,
        )
        for name, param in model.named_parameters():
            param.global_name = name
        model_without_ddp = model
        lr_scheduler = None
        model_ema = model_ema.to(next(model.parameters()).device)
    else:
        optimizer, lr_scheduler = model.configure_optimizers()
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[opt.local_rank], broadcast_buffers=False)
        model_without_ddp = model.module
        # print(optimizer.param_groups[1])

    if opt.resume != '':
        resume_file = opt.resume
    if resume_file:
        _, start_epoch = load_checkpoint(resume_file, config, model_without_ddp, model_ema, optimizer, lr_scheduler, scaler, logger)
    else:
        start_epoch = 0

    logger.info("Start training")
    start_time = time.time()
    for epoch in range(start_epoch, config.trainer.max_epochs):
        data_loader_train.sampler.set_epoch(epoch)

        train_one_epoch(config, model, model_ema, data_loader_train, data_loader_val, optimizer, epoch, lr_scheduler, scaler)
        if epoch % config.trainer.save_freq == 0:
            save_checkpoint(ckptdir, config, epoch, model_without_ddp, model_ema, 0., optimizer, lr_scheduler, scaler, logger)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))