File size: 6,723 Bytes
0a63786
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32f018b
0a63786
 
 
 
 
 
 
 
 
 
32f018b
0a63786
 
 
 
2c33647
0a63786
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import torch
from omegaconf import OmegaConf
from pytorch_lightning.loggers import WandbLogger
from misc_utils.model_utils import instantiate_from_config, get_obj_from_str

from diffusers import AutoencoderKL
import os
import json

def get_models(args):
    """Instantiate the model components declared in the config.

    The unet is mandatory; vae, text_model, and ctrlnet are instantiated
    only when the config declares them.

    Returns a dict mapping component name -> instantiated module.
    """
    model_dict = {'unet': instantiate_from_config(args.unet)}

    # Optional components: add each one only if its config entry exists.
    for component in ('vae', 'text_model', 'ctrlnet'):
        component_cfg = args.get(component)
        if component_cfg:
            model_dict[component] = instantiate_from_config(component_cfg)

    return model_dict

def get_text_model(args):
    """Instantiate the text encoder from the config, if one is declared.

    Returns the instantiated text model, or ``None`` when the config has
    no ``text_model`` entry.
    """
    # NOTE(review): the original also read diffusion.params.base_path into a
    # local that was never used (dead code; when base_path is set the weights
    # are presumably loaded elsewhere — see get_vae). The dead lookup is
    # removed here.
    if args.get('text_model'):
        return instantiate_from_config(args.text_model)
    return None

def get_vae(args):
    """Load a pretrained AutoencoderKL from ``diffusion.params.base_path``.

    Returns the diffusers VAE loaded from the checkpoint's ``vae`` subfolder,
    or ``None`` when the config has no ``diffusion`` section or no
    ``base_path``. (The original raised ``UnboundLocalError`` when a
    ``diffusion`` section existed without ``base_path`` — ``return vae``
    referenced a never-assigned local; fixed here.)
    """
    if args.get('diffusion') and args.diffusion.params.get('base_path'):
        base_path = args.diffusion.params.base_path
        # VAE weights come straight from the diffusers checkpoint layout.
        return AutoencoderKL.from_pretrained(base_path, subfolder="vae")
    return None

def get_ic_models(args):
    """Assemble model components, loading vae/text_model via helpers.

    Unlike ``get_models``, the vae is fetched directly as a diffusers
    component (see ``get_vae``) and the text model via ``get_text_model``;
    each is added only when the helper returns something. An optional
    ctrlnet is instantiated straight from the config.
    """
    model_dict = {'unet': instantiate_from_config(args.unet)}

    # vae is loaded directly from a diffusers checkpoint, not from config.
    vae = get_vae(args)
    if vae:
        model_dict['vae'] = vae

    # text model is largely unchanged; only the from_pretrained source differs.
    text_model = get_text_model(args)
    if text_model:
        model_dict['text_model'] = text_model

    if args.get('ctrlnet'):
        model_dict['ctrlnet'] = instantiate_from_config(args.ctrlnet)

    return model_dict

# def get_models(args):
#     unet = instantiate_from_config(args.unet)
#     model_dict = {
#         'unet': unet,
#     }

#     if args.get('vae'):
#         vae = instantiate_from_config(args.vae)
#         model_dict['vae'] = vae

#     if args.get('text_model'):
#         text_model = instantiate_from_config(args.text_model)
#         model_dict['text_model'] = text_model

#     if args.get('ctrlnet'): # 这边还可以加ctrlnet... (感觉是哪个地方搬来的代码)
#         ctrlnet = instantiate_from_config(args.ctrlnet)
#         model_dict['ctrlnet'] = ctrlnet

#     return model_dict

def get_DDPM(diffusion_configs, log_args={}, **models):
    """Build the diffusion trainer described by ``diffusion_configs``.

    Resolves the ``target`` dotted path to a class (e.g.
    pl_trainer.instruct_p2p_video.InstructP2PVideoTrainerTemporal) and
    constructs it with ``log_args``, the supplied model components, and the
    config's ``params``.
    """
    trainer_cls = get_obj_from_str(diffusion_configs['target'])
    return trainer_cls(
        log_args=log_args,
        **models,
        **diffusion_configs['params'],
    )


def get_logger(args):
    """Create a Weights & Biases logger whose project is the experiment name."""
    return WandbLogger(project=args["expt_name"])

def get_callbacks(args, wandb_logger):
    """Instantiate every callback listed in the config.

    Callbacks flagged with ``require_wandb`` are constructed manually so the
    wandb logger can be injected; all others go through
    ``instantiate_from_config``.
    """
    callbacks = []
    for cb_conf in args['callbacks']:
        if cb_conf.get('require_wandb', False):
            # This callback needs the wandb logger passed in explicitly.
            cb_cls = get_obj_from_str(cb_conf.target)
            callbacks.append(cb_cls(wandb_logger=wandb_logger, **cb_conf.params))
        else:
            callbacks.append(instantiate_from_config(cb_conf))
    return callbacks

def get_dataset(args):
    """Build the train/val datasets and their DataLoaders from the config.

    Returns ``(train_loader, val_loader, train_set, val_set)``. Worker counts
    scale with the number of configured devices; the validation loader is
    deliberately not shuffled.
    """
    from torch.utils.data import DataLoader

    data_args = args['data']
    num_devices = len(args['trainer_args']['devices'])

    train_set = instantiate_from_config(data_args['train'])
    val_set = instantiate_from_config(data_args['val'])

    train_loader = DataLoader(
        train_set,
        batch_size=data_args['batch_size'],
        shuffle=True,
        num_workers=4 * num_devices,
        pin_memory=True,
    )
    # Validation keeps dataset order (no shuffle).
    val_loader = DataLoader(
        val_set,
        batch_size=data_args['val_batch_size'],
        num_workers=num_devices,
        pin_memory=True,
    )
    return train_loader, val_loader, train_set, val_set

def unit_test_create_model(config_path):
    """Build the full DDPM model from a config file and move it to the device."""
    conf = OmegaConf.load(config_path)
    components = get_ic_models(conf)
    ddpm = get_DDPM(conf['diffusion'], log_args=conf, **components)
    target_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return ddpm.to(target_device)

def unit_test_create_dataset(config_path, split='train'):
    """Fetch one batch from the configured dataset, with tensors on device.

    ``split`` selects the train or val loader; tensor values in the batch
    are moved to CUDA when available, other values are left untouched.
    """
    conf = OmegaConf.load(config_path)
    train_loader, val_loader, _train_set, _val_set = get_dataset(conf)
    loader = train_loader if split == 'train' else val_loader
    batch = next(iter(loader))

    target_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            batch[key] = value.to(target_device)
    return batch

def unit_test_training_step(config_path):
    """Smoke-test one training step: build the model, fetch a batch, step once."""
    model = unit_test_create_model(config_path)
    train_batch = unit_test_create_dataset(config_path)
    return model.training_step(train_batch, 0)

def unit_test_val_step(config_path):
    """Smoke-test one validation step on a batch from the val split."""
    model = unit_test_create_model(config_path)
    val_batch = unit_test_create_dataset(config_path, split='val')
    return model.validation_step(val_batch, 0)

# Default negative prompt for Stable-Diffusion-style sampling: a long keyword
# list of quality/anatomy failure modes, some using the (term:weight)
# attention-weighting syntax. Kept as a single literal; do not reformat.
NEGATIVE_PROMPTS = "(((deformed))), blurry, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar, multiple breasts, (mutated hands and fingers:1.5), (long body :1.3), (mutation, poorly drawn :1.2), black-white, bad anatomy, liquid body, liquidtongue, disfigured, malformed, mutated, anatomical nonsense, text font ui, error, malformed hands, long neck, blurred, lowers, low res, bad anatomy, bad proportions, bad shadow, uncoordinated body, unnatural body, fused breasts, bad breasts, huge breasts, poorly drawn breasts, extra breasts, liquid breasts, heavy breasts, missingbreasts, huge haunch, huge thighs, huge calf, bad hands, fused hand, missing hand, disappearing arms, disappearing thigh, disappearing calf, disappearing legs, fusedears, bad ears, poorly drawn ears, extra ears, liquid ears, heavy ears, missing ears, old photo, low res, black and white, black and white filter, colorless"