# RelightVid / misc_utils/train_utils.py
# Author: aleafy — commit 2c33647 ("modify")
import torch
from omegaconf import OmegaConf
from pytorch_lightning.loggers import WandbLogger
from misc_utils.model_utils import instantiate_from_config, get_obj_from_str
from diffusers import AutoencoderKL
import os
import json
def get_models(args):
    """Instantiate every model component described in ``args``.

    ``args.unet`` is always instantiated. The optional components ``vae``,
    ``text_model`` and ``ctrlnet`` are instantiated (in that order) only when
    ``args.get(<name>)`` is truthy.

    Args:
        args: config object supporting both ``.get`` and attribute access
            (e.g. an OmegaConf DictConfig).

    Returns:
        dict mapping component name -> instantiated object; always contains
        'unet', plus whichever optional components were configured.
    """
    components = {'unet': instantiate_from_config(args.unet)}
    for name in ('vae', 'text_model', 'ctrlnet'):
        if args.get(name):
            components[name] = instantiate_from_config(getattr(args, name))
    return components
def get_text_model(args):
    """Instantiate the text encoder from ``args.text_model`` if it is configured.

    Only the ``text_model`` entry is consulted. This helper must not depend on
    the ``diffusion`` section, so a config whose ``diffusion`` entry lacks
    ``params`` does not break text-model construction.

    Args:
        args: config object supporting ``.get`` and attribute access
            (e.g. an OmegaConf DictConfig).

    Returns:
        The object built by ``instantiate_from_config(args.text_model)``, or
        ``None`` when ``args.get('text_model')`` is falsy.
    """
    if not args.get('text_model'):
        return None
    return instantiate_from_config(args.text_model)
def get_vae(args):
    """Load a diffusers VAE from ``args.diffusion.params.base_path``.

    Args:
        args: config object supporting ``.get`` and attribute access
            (e.g. an OmegaConf DictConfig).

    Returns:
        ``AutoencoderKL.from_pretrained(base_path, subfolder="vae")`` when
        ``args.diffusion`` is truthy and its ``params.base_path`` is truthy;
        otherwise ``None``.
    """
    if not args.get('diffusion'):
        return None
    params = args.diffusion.params
    if not params.get('base_path'):
        return None
    return AutoencoderKL.from_pretrained(params.base_path, subfolder="vae")
def get_ic_models(args):
    """Instantiate the model components, loading the VAE through diffusers.

    Unlike ``get_models``, the VAE comes from ``get_vae`` (a diffusers
    ``AutoencoderKL`` loaded from ``args.diffusion.params.base_path``), and
    the text model comes from ``get_text_model``.

    Args:
        args: config object supporting ``.get`` and attribute access
            (e.g. an OmegaConf DictConfig).

    Returns:
        dict with key 'unet', plus 'vae', 'text_model' and 'ctrlnet' for each
        component that was configured/loaded.
    """
    model_dict = {'unet': instantiate_from_config(args.unet)}

    # The VAE is loaded directly as a diffusers component.
    # Compare with None explicitly: a model object may define __len__/__bool__,
    # so truthiness is not a reliable "was it loaded?" check.
    vae = get_vae(args)
    if vae is not None:
        model_dict['vae'] = vae

    # Text model: same construction as get_models; mainly the
    # from_pretrained source differs.
    text_model = get_text_model(args)
    if text_model is not None:
        model_dict['text_model'] = text_model

    # Optional ControlNet.
    if args.get('ctrlnet'):
        model_dict['ctrlnet'] = instantiate_from_config(args.ctrlnet)
    return model_dict
# NOTE: commented-out duplicate of get_models, kept for reference only.
# def get_models(args):
# unet = instantiate_from_config(args.unet)
# model_dict = {
# 'unet': unet,
# }
# if args.get('vae'):
# vae = instantiate_from_config(args.vae)
# model_dict['vae'] = vae
# if args.get('text_model'):
# text_model = instantiate_from_config(args.text_model)
# model_dict['text_model'] = text_model
# if args.get('ctrlnet'): # ctrlnet can also be added here... (looks like code copied from elsewhere)
# ctrlnet = instantiate_from_config(args.ctrlnet)
# model_dict['ctrlnet'] = ctrlnet
# return model_dict
def get_DDPM(diffusion_configs, log_args=None, **models):
    """Instantiate the diffusion trainer class named in ``diffusion_configs``.

    Args:
        diffusion_configs: mapping with 'target' (dotted import path of the
            class, resolved via ``get_obj_from_str``; e.g.
            ``pl_trainer.instruct_p2p_video.InstructP2PVideoTrainerTemporal``)
            and 'params' (extra keyword arguments for the constructor).
        log_args: forwarded to the constructor as ``log_args``. Defaults to a
            new empty dict on every call.
        **models: model components (e.g. the dict from ``get_models``),
            forwarded as keyword arguments.

    Returns:
        The constructed instance.
    """
    if log_args is None:
        # A fresh dict per call: a shared ``{}`` default would carry any
        # mutation done by one constructor into every later call.
        log_args = {}
    target = diffusion_configs['target']
    params = diffusion_configs['params']
    model_cls = get_obj_from_str(target)
    return model_cls(
        log_args=log_args,
        **models,
        **params
    )
def get_logger(args):
    """Create a ``WandbLogger`` whose project name is ``args["expt_name"]``."""
    return WandbLogger(project=args["expt_name"])
def get_callbacks(args, wandb_logger):
    """Build the list of callbacks configured in ``args['callbacks']``.

    Entries with a truthy ``require_wandb`` are built by resolving
    ``entry.target`` and calling it with ``wandb_logger=wandb_logger`` plus
    ``**entry.params``. All other entries go through
    ``instantiate_from_config``. Order is preserved.
    """
    def build(entry):
        if not entry.get('require_wandb', False):
            return instantiate_from_config(entry)
        # This callback needs the shared wandb logger injected.
        callback_cls = get_obj_from_str(entry.target)
        return callback_cls(wandb_logger=wandb_logger, **entry.params)

    return [build(entry) for entry in args['callbacks']]
def get_dataset(args):
    """Instantiate the train/val datasets and wrap them in DataLoaders.

    Reads ``args['data']`` for the 'train'/'val' dataset configs and the
    'batch_size'/'val_batch_size' values. Worker counts scale with
    ``len(args['trainer_args']['devices'])``: 4 workers per device for
    training, 1 per device for validation. (Assumes ``devices`` is a
    sized collection such as a list of device ids — TODO confirm for
    int/"auto" settings.)

    Returns:
        (train_loader, val_loader, train_set, val_set)
    """
    from torch.utils.data import DataLoader

    data_cfg = args['data']
    train_set = instantiate_from_config(data_cfg['train'])
    val_set = instantiate_from_config(data_cfg['val'])

    train_batch_size = data_cfg['batch_size']
    num_devices = len(args['trainer_args']['devices'])
    train_loader = DataLoader(
        train_set,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=4 * num_devices,
        pin_memory=True,
    )
    # Validation data is deliberately left unshuffled.
    val_loader = DataLoader(
        val_set,
        batch_size=data_cfg['val_batch_size'],
        num_workers=num_devices,
        pin_memory=True,
    )
    return train_loader, val_loader, train_set, val_set
def unit_test_create_model(config_path):
    """Smoke-test helper: build the diffusion model described by a config file.

    Loads ``config_path`` with ``OmegaConf.load``, builds the components with
    ``get_ic_models``, and wraps them with ``get_DDPM``, passing the whole
    config as ``log_args``. The model is then moved to CUDA when available,
    otherwise to CPU.
    """
    target_device = 'cpu'
    if torch.cuda.is_available():
        target_device = 'cuda'
    config = OmegaConf.load(config_path)
    components = get_ic_models(config)
    model = get_DDPM(config['diffusion'], log_args=config, **components)
    return model.to(target_device)
def unit_test_create_dataset(config_path, split='train'):
    """Smoke-test helper: fetch one batch from the configured data loaders.

    ``split == 'train'`` draws from the training loader; any other value
    draws from the validation loader. Tensor values in the batch are moved to
    CUDA when available (otherwise CPU). Non-tensor values are left as-is.
    Assumes the collated batch is a mutable mapping — TODO confirm for
    all datasets.
    """
    target_device = 'cpu'
    if torch.cuda.is_available():
        target_device = 'cuda'
    config = OmegaConf.load(config_path)
    train_loader, val_loader, _train_set, _val_set = get_dataset(config)

    loader = train_loader if split == 'train' else val_loader
    batch = next(iter(loader))

    for key, value in batch.items():
        if not isinstance(value, torch.Tensor):
            continue
        batch[key] = value.to(target_device)
    return batch
def unit_test_training_step(config_path):
    """Smoke-test helper: run one ``training_step`` on the first training batch.

    Returns whatever the model's ``training_step(batch, 0)`` returns.
    """
    model = unit_test_create_model(config_path)
    sample = unit_test_create_dataset(config_path, split='train')
    return model.training_step(sample, 0)
def unit_test_val_step(config_path):
    """Smoke-test helper: run one ``validation_step`` on the first validation batch.

    Returns whatever the model's ``validation_step(batch, 0)`` returns.
    """
    model = unit_test_create_model(config_path)
    sample = unit_test_create_dataset(config_path, split='val')
    return model.validation_step(sample, 0)
# Comma-separated negative prompt terms. The literals below are implicitly
# concatenated into one string; parenthesised terms with ":<float>" use the
# "(term:weight)" syntax common in Stable Diffusion prompt tooling.
NEGATIVE_PROMPTS = (
    "(((deformed))), blurry, bad anatomy, disfigured, poorly drawn face, "
    "mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), "
    "fused fingers, messy drawing, broken legs censor, censored, censor_bar, "
    "multiple breasts, (mutated hands and fingers:1.5), (long body :1.3), "
    "(mutation, poorly drawn :1.2), black-white, bad anatomy, liquid body, "
    "liquidtongue, disfigured, malformed, mutated, anatomical nonsense, "
    "text font ui, error, malformed hands, long neck, blurred, lowers, "
    "low res, bad anatomy, bad proportions, bad shadow, uncoordinated body, "
    "unnatural body, fused breasts, bad breasts, huge breasts, "
    "poorly drawn breasts, extra breasts, liquid breasts, heavy breasts, "
    "missingbreasts, huge haunch, huge thighs, huge calf, bad hands, "
    "fused hand, missing hand, disappearing arms, disappearing thigh, "
    "disappearing calf, disappearing legs, fusedears, bad ears, "
    "poorly drawn ears, extra ears, liquid ears, heavy ears, missing ears, "
    "old photo, low res, black and white, black and white filter, colorless"
)