Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,723 Bytes
0a63786 32f018b 0a63786 32f018b 0a63786 2c33647 0a63786 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
import torch
from omegaconf import OmegaConf
from pytorch_lightning.loggers import WandbLogger
from misc_utils.model_utils import instantiate_from_config, get_obj_from_str
from diffusers import AutoencoderKL
import os
import json
def get_models(args):
    """Instantiate the UNet plus any optional components present in ``args``.

    ``args.unet`` is always instantiated. Each of ``vae``, ``text_model`` and
    ``ctrlnet`` is instantiated (in that order) only when ``args.get(name)``
    is truthy.

    Returns:
        dict mapping component name -> object built by
        ``instantiate_from_config``; always contains ``'unet'``.
    """
    components = {'unet': instantiate_from_config(args.unet)}
    for optional_name in ('vae', 'text_model', 'ctrlnet'):
        if not args.get(optional_name):
            continue
        components[optional_name] = instantiate_from_config(
            getattr(args, optional_name)
        )
    return components
def get_text_model(args):
    """Instantiate the text model described by ``args.text_model``.

    The text model does not depend on ``args.diffusion``. Its config entry
    itself decides where pretrained weights come from. For that reason this
    function does not read ``args.diffusion.params.base_path``: that value
    was previously looked up here but never used, and the lookup could raise
    on a ``diffusion`` section without ``params``.

    Args:
        args: config object supporting ``.get(key)`` and attribute access
            (e.g. an OmegaConf ``DictConfig``).

    Returns:
        The object built by ``instantiate_from_config(args.text_model)``, or
        ``None`` when no (truthy) ``text_model`` entry is configured.
    """
    if not args.get('text_model'):
        return None
    return instantiate_from_config(args.text_model)
def get_vae(args):
    """Load the VAE directly as a diffusers ``AutoencoderKL`` component.

    Weights are read from the ``vae`` subfolder of
    ``args.diffusion.params.base_path``.

    Args:
        args: config object supporting ``.get(key)`` and attribute access
            (e.g. an OmegaConf ``DictConfig``).

    Returns:
        The loaded ``AutoencoderKL``, or ``None`` when there is no
        ``diffusion`` section or it has no truthy ``base_path``. The explicit
        guards ensure ``from_pretrained`` is never called with ``None`` as
        the model path.
    """
    diffusion_cfg = args.get('diffusion')
    if not diffusion_cfg:
        return None
    base_path = diffusion_cfg.params.get('base_path')
    if not base_path:
        return None
    return AutoencoderKL.from_pretrained(base_path, subfolder="vae")
def get_ic_models(args):
    """Build the model components for the IC pipeline.

    Differs from ``get_models`` in how the VAE and text model are obtained:
    the VAE is loaded directly as a diffusers component via ``get_vae``, and
    the text model via ``get_text_model``.

    Args:
        args: config object supporting ``.get(key)`` and attribute access.

    Returns:
        dict with ``'unet'`` always present, plus ``'vae'``,
        ``'text_model'`` and ``'ctrlnet'`` when available.
    """
    model_dict = {'unet': instantiate_from_config(args.unet)}

    # Both helpers return None to mean "not configured". Test identity rather
    # than truthiness, so a model object that happens to define __len__ or
    # __bool__ is never silently dropped.
    vae = get_vae(args)
    if vae is not None:
        model_dict['vae'] = vae

    text_model = get_text_model(args)
    if text_model is not None:
        model_dict['text_model'] = text_model

    # Optional ControlNet-style conditioning network.
    if args.get('ctrlnet'):
        model_dict['ctrlnet'] = instantiate_from_config(args.ctrlnet)
    return model_dict
# def get_models(args):
# unet = instantiate_from_config(args.unet)
# model_dict = {
# 'unet': unet,
# }
# if args.get('vae'):
# vae = instantiate_from_config(args.vae)
# model_dict['vae'] = vae
# if args.get('text_model'):
# text_model = instantiate_from_config(args.text_model)
# model_dict['text_model'] = text_model
# if args.get('ctrlnet'): # 这边还可以加ctrlnet... (感觉是哪个地方搬来的代码)
# ctrlnet = instantiate_from_config(args.ctrlnet)
# model_dict['ctrlnet'] = ctrlnet
# return model_dict
def get_DDPM(diffusion_configs, log_args=None, **models):
    """Instantiate the diffusion (DDPM) wrapper class named in the config.

    Args:
        diffusion_configs: mapping with ``'target'`` (dotted import path of
            the class, resolved with ``get_obj_from_str``) and ``'params'``
            (keyword arguments forwarded to the constructor).
        log_args: forwarded to the constructor as ``log_args``. Defaults to
            a fresh empty dict on every call. A mutable ``{}`` default would
            be shared across calls, so any mutation by one model would leak
            into later ones.
        **models: model components (e.g. the dict from ``get_ic_models``),
            forwarded as keyword arguments.

    Returns:
        The constructed instance of the target class.
    """
    if log_args is None:
        log_args = {}
    diffusion_model_class = diffusion_configs['target']
    diffusion_args = diffusion_configs['params']
    # e.g. pl_trainer.instruct_p2p_video.InstructP2PVideoTrainerTemporal
    DDPM_model = get_obj_from_str(diffusion_model_class)
    return DDPM_model(
        log_args=log_args,
        **models,
        **diffusion_args,
    )
def get_logger(args):
    """Create a Weights & Biases logger whose project is ``args["expt_name"]``."""
    project_name = args["expt_name"]
    return WandbLogger(project=project_name)
def get_callbacks(args, wandb_logger):
    """Instantiate every callback listed under ``args['callbacks']``.

    Entries whose ``require_wandb`` flag is truthy are built by resolving
    ``target`` and calling it with ``wandb_logger=wandb_logger`` plus
    ``**params``. All other entries go through ``instantiate_from_config``.

    Returns:
        list of callback objects, in config order.
    """
    def _build_callback(cb_cfg):
        if not cb_cfg.get('require_wandb', False):
            return instantiate_from_config(cb_cfg)
        # This callback needs the wandb logger handed to it explicitly.
        callback_cls = get_obj_from_str(cb_cfg.target)
        return callback_cls(wandb_logger=wandb_logger, **cb_cfg.params)

    return [_build_callback(cb_cfg) for cb_cfg in args['callbacks']]
def _num_devices(devices):
    """Return how many devices a trainer ``devices`` setting refers to.

    A sequence of device ids is counted with ``len`` (the original behavior).
    A plain int is taken as the device count, with non-positive values
    (e.g. ``-1``) treated as 1 so derived worker counts stay positive.
    """
    if isinstance(devices, int) and not isinstance(devices, bool):
        return max(devices, 1)
    return len(devices)


def get_dataset(args):
    """Build train/val datasets and their DataLoaders from ``args['data']``.

    ``args['data']['train']`` and ``args['data']['val']`` are instantiated
    with ``instantiate_from_config``. The train loader shuffles and uses
    4 workers per device. The val loader keeps a fixed order and uses
    1 worker per device. The device count comes from
    ``args['trainer_args']['devices']``, which may be a list of ids or an
    int count.

    Returns:
        tuple ``(train_loader, val_loader, train_set, val_set)``.
    """
    from torch.utils.data import DataLoader

    data_args = args['data']
    train_set = instantiate_from_config(data_args['train'])
    val_set = instantiate_from_config(data_args['val'])

    n_devices = _num_devices(args['trainer_args']['devices'])
    train_loader = DataLoader(
        train_set, batch_size=data_args['batch_size'], shuffle=True,
        num_workers=4 * n_devices, pin_memory=True,
    )
    # Validation is intentionally not shuffled.
    val_loader = DataLoader(
        val_set, batch_size=data_args['val_batch_size'],
        num_workers=n_devices, pin_memory=True,
    )
    return train_loader, val_loader, train_set, val_set
def unit_test_create_model(config_path):
    """Smoke-test helper: build the DDPM wrapper described by a config file.

    Loads ``config_path`` with OmegaConf, builds the components with
    ``get_ic_models``, wraps them with ``get_DDPM`` (passing the whole config
    as ``log_args``), and moves the result to CUDA if available, else CPU.
    """
    target_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    config = OmegaConf.load(config_path)
    components = get_ic_models(config)
    model = get_DDPM(config['diffusion'], log_args=config, **components)
    return model.to(target_device)
def unit_test_create_dataset(config_path, split='train'):
    """Smoke-test helper: fetch the first batch of a split, moved to device.

    ``split == 'train'`` reads from the train loader; any other value reads
    from the val loader. Tensor values in the batch are replaced in place by
    copies on CUDA (if available) or CPU; other values are left untouched.
    """
    target_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    config = OmegaConf.load(config_path)
    loaders_and_sets = get_dataset(config)
    chosen_loader = loaders_and_sets[0] if split == 'train' else loaders_and_sets[1]
    batch = next(iter(chosen_loader))

    tensor_keys = [key for key, value in batch.items() if isinstance(value, torch.Tensor)]
    for key in tensor_keys:
        batch[key] = batch[key].to(target_device)
    return batch
def unit_test_training_step(config_path):
    """Smoke-test one ``training_step`` (batch_idx 0) on the first train batch."""
    model = unit_test_create_model(config_path)
    first_batch = unit_test_create_dataset(config_path)
    return model.training_step(first_batch, 0)
def unit_test_val_step(config_path):
    """Smoke-test one ``validation_step`` (batch_idx 0) on the first val batch."""
    model = unit_test_create_model(config_path)
    first_batch = unit_test_create_dataset(config_path, split='val')
    return model.validation_step(first_batch, 0)
# Default negative-prompt text: a single comma-separated string of tags, some
# wrapped in parentheses or carrying ":<weight>" suffixes.
# NOTE(review): not referenced anywhere in this module. It is presumably
# imported by sampling/inference code as the negative prompt, and the
# parenthesis/weight syntax presumably targets a prompt-weighting parser.
# Confirm against the callers.
NEGATIVE_PROMPTS = "(((deformed))), blurry, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar, multiple breasts, (mutated hands and fingers:1.5), (long body :1.3), (mutation, poorly drawn :1.2), black-white, bad anatomy, liquid body, liquidtongue, disfigured, malformed, mutated, anatomical nonsense, text font ui, error, malformed hands, long neck, blurred, lowers, low res, bad anatomy, bad proportions, bad shadow, uncoordinated body, unnatural body, fused breasts, bad breasts, huge breasts, poorly drawn breasts, extra breasts, liquid breasts, heavy breasts, missingbreasts, huge haunch, huge thighs, huge calf, bad hands, fused hand, missing hand, disappearing arms, disappearing thigh, disappearing calf, disappearing legs, fusedears, bad ears, poorly drawn ears, extra ears, liquid ears, heavy ears, missing ears, old photo, low res, black and white, black and white filter, colorless"
|