"""Factory helpers that build models, loggers, callbacks and data loaders from a config."""
import json
import os

import torch
from diffusers import AutoencoderKL
from omegaconf import OmegaConf
from pytorch_lightning.loggers import WandbLogger

from misc_utils.model_utils import get_obj_from_str, instantiate_from_config


def get_models(args):
    """Instantiate the UNet and every optional component declared in ``args``.

    Args:
        args: Config object exposing ``unet`` and, optionally, ``vae``,
            ``text_model`` and ``ctrlnet`` entries that
            ``instantiate_from_config`` can consume.

    Returns:
        dict: Always contains ``'unet'``; contains ``'vae'``, ``'text_model'``
        and ``'ctrlnet'`` only when the matching config entry is truthy.
    """
    model_dict = {
        'unet': instantiate_from_config(args.unet),
    }
    # Optional components, created in the same order as the config sections.
    # A ctrlnet can be plugged in here as well.
    for name in ('vae', 'text_model', 'ctrlnet'):
        if args.get(name):
            model_dict[name] = instantiate_from_config(args[name])
    return model_dict


def get_text_model(args):
    """Instantiate the text model described by ``args.text_model``, if any.

    NOTE(review): the original body also looked up
    ``args.diffusion.params.base_path`` but never used the value, and the
    original indentation was lost; this version assumes the ``text_model``
    check is independent of that lookup -- confirm against the upstream file.

    Args:
        args: Config object that may contain a ``text_model`` entry.

    Returns:
        The instantiated text model, or ``None`` when ``text_model`` is not
        configured.
    """
    text_model_cfg = args.get('text_model')
    if not text_model_cfg:
        return None
    return instantiate_from_config(text_model_cfg)


def get_vae(args):
    """Load the VAE from the ``vae`` subfolder of the diffusion base path.

    Args:
        args: Config object that may contain ``diffusion.params.base_path``.

    Returns:
        An ``AutoencoderKL`` loaded with ``from_pretrained(base_path,
        subfolder="vae")``, or ``None`` when there is no ``diffusion`` section
        or no ``base_path`` in its params.
    """
    diffusion_cfg = args.get('diffusion')
    if not diffusion_cfg:
        return None
    base_path = diffusion_cfg.params.get('base_path')
    if not base_path:
        return None
    return AutoencoderKL.from_pretrained(base_path, subfolder="vae")


def get_ic_models(args):
    """Build the model dict, taking the VAE from the pretrained base path.

    Unlike ``get_models``, the VAE is loaded directly as a diffusers component
    through ``get_vae`` instead of being instantiated from ``args.vae``.

    Args:
        args: Config object exposing ``unet`` and, optionally, ``diffusion``,
            ``text_model`` and ``ctrlnet`` entries.

    Returns:
        dict: Always contains ``'unet'``; contains ``'vae'``, ``'text_model'``
        and ``'ctrlnet'`` only when the corresponding component is available.
    """
    model_dict = {
        'unet': instantiate_from_config(args.unet),
    }

    vae = get_vae(args)
    # Compare against None explicitly: a model object may define __len__ or
    # __bool__, and must not be dropped just because it evaluates as falsy.
    if vae is not None:
        model_dict['vae'] = vae

    text_model = get_text_model(args)
    if text_model is not None:
        model_dict['text_model'] = text_model

    # A ctrlnet can be plugged in here as well.
    if args.get('ctrlnet'):
        model_dict['ctrlnet'] = instantiate_from_config(args.ctrlnet)
    return model_dict


def get_DDPM(diffusion_configs, log_args=None, **models):
    """Instantiate the diffusion trainer class named in ``diffusion_configs``.

    Args:
        diffusion_configs: Mapping with a ``'target'`` dotted class path (e.g.
            ``pl_trainer.instruct_p2p_video.InstructP2PVideoTrainerTemporal``)
            and a ``'params'`` mapping of constructor keyword arguments.
        log_args: Value forwarded to the trainer as ``log_args``. Defaults to
            a fresh empty dict per call (a shared mutable default would leak
            state between calls).
        **models: Model components forwarded as keyword arguments.

    Returns:
        The constructed trainer instance.
    """
    if log_args is None:
        log_args = {}
    diffusion_model_class = diffusion_configs['target']
    diffusion_args = diffusion_configs['params']
    trainer_cls = get_obj_from_str(diffusion_model_class)
    return trainer_cls(
        log_args=log_args,
        **models,
        **diffusion_args
    )


def get_logger(args):
    """Create a ``WandbLogger`` whose project is ``args["expt_name"]``."""
    return WandbLogger(
        project=args["expt_name"],
    )


def get_callbacks(args, wandb_logger):
    """Instantiate every callback listed under ``args['callbacks']``.

    Args:
        args: Config whose ``'callbacks'`` entry is an iterable of callback
            configs with ``target`` and ``params`` fields.
        wandb_logger: Logger passed to callbacks flagged ``require_wandb``.

    Returns:
        list: The instantiated callbacks, in config order.
    """
    callbacks = []
    for callback in args['callbacks']:
        if callback.get('require_wandb', False):
            # This callback needs the wandb logger, so build it by hand
            # instead of going through instantiate_from_config.
            callback_cls = get_obj_from_str(callback.target)
            instance = callback_cls(wandb_logger=wandb_logger, **callback.params)
        else:
            instance = instantiate_from_config(callback)
        callbacks.append(instance)
    return callbacks


def _count_devices(devices):
    """Return the device count used to scale dataloader workers.

    ``devices`` may be a sequence of device ids (counted with ``len``) or a
    plain int, which previously crashed with ``TypeError``. A non-positive int
    cannot be resolved to a real count here, so it falls back to 1.
    """
    if isinstance(devices, int):
        return devices if devices > 0 else 1
    return len(devices)


def get_dataset(args):
    """Build the train/val datasets and their data loaders.

    Args:
        args: Config with ``data.train``, ``data.val``, ``data.batch_size``,
            ``data.val_batch_size`` and ``trainer_args.devices``.

    Returns:
        tuple: ``(train_loader, val_loader, train_set, val_set)``.
    """
    from torch.utils.data import DataLoader

    data_args = args['data']
    train_set = instantiate_from_config(data_args['train'])
    val_set = instantiate_from_config(data_args['val'])

    n_devices = _count_devices(args['trainer_args']['devices'])
    train_loader = DataLoader(
        train_set,
        batch_size=data_args['batch_size'],
        shuffle=True,
        num_workers=4 * n_devices,
        pin_memory=True,
    )
    # The validation loader is not shuffled.
    val_loader = DataLoader(
        val_set,
        batch_size=data_args['val_batch_size'],
        num_workers=n_devices,
        pin_memory=True,
    )
    return train_loader, val_loader, train_set, val_set


def unit_test_create_model(config_path):
    """Load the config at ``config_path`` and build the trainer on the available device.

    Returns:
        The trainer from ``get_DDPM``, moved to ``'cuda'`` when available and
        to ``'cpu'`` otherwise.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    conf = OmegaConf.load(config_path)
    models = get_ic_models(conf)
    ddpm = get_DDPM(conf['diffusion'], log_args=conf, **models)
    return ddpm.to(device)


def unit_test_create_dataset(config_path, split='train'):
    """Fetch the first batch of a split and move its tensors to the device.

    Args:
        config_path: Path to the config file loaded with ``OmegaConf.load``.
        split: ``'train'`` selects the training loader; any other value
            selects the validation loader.

    Returns:
        The first batch, with every ``torch.Tensor`` value moved to ``'cuda'``
        when available and to ``'cpu'`` otherwise.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    conf = OmegaConf.load(config_path)
    train_loader, val_loader, _train_set, _val_set = get_dataset(conf)
    loader = train_loader if split == 'train' else val_loader
    batch = next(iter(loader))
    # Only existing keys are reassigned, so updating while iterating is safe.
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            batch[k] = v.to(device)
    return batch


def unit_test_training_step(config_path):
    """Run one ``training_step`` on the first training batch and return its result."""
    ddpm = unit_test_create_model(config_path)
    batch = unit_test_create_dataset(config_path)
    return ddpm.training_step(batch, 0)


def unit_test_val_step(config_path):
    """Run one ``validation_step`` on the first validation batch and return its result."""
    ddpm = unit_test_create_model(config_path)
    batch = unit_test_create_dataset(config_path, split='val')
    return ddpm.validation_step(batch, 0)


# Kept on a single line so the prompt text stays byte-identical.
NEGATIVE_PROMPTS = "(((deformed))), blurry, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar, multiple breasts, (mutated hands and fingers:1.5), (long body :1.3), (mutation, poorly drawn :1.2), black-white, bad anatomy, liquid body, liquidtongue, disfigured, malformed, mutated, anatomical nonsense, text font ui, error, malformed hands, long neck, blurred, lowers, low res, bad anatomy, bad proportions, bad shadow, uncoordinated body, unnatural body, fused breasts, bad breasts, huge breasts, poorly drawn breasts, extra breasts, liquid breasts, heavy breasts, missingbreasts, huge haunch, huge thighs, huge calf, bad hands, fused hand, missing hand, disappearing arms, disappearing thigh, disappearing calf, disappearing legs, fusedears, bad ears, poorly drawn ears, extra ears, liquid ears, heavy ears, missing ears, old photo, low res, black and white, black and white filter, colorless"  # noqa: E501