File size: 1,928 Bytes
c05d22e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch

from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.util import instantiate_from_config


class T2IAdapterCannyBase(LatentDiffusion):
    """LatentDiffusion variant that trains only an adapter network.

    An adapter module (built from ``adapter_config``) encodes an extra
    conditioning input read from ``batch[extra_cond_key]``. Its output is
    passed to the diffusion forward pass as ``features_adapter``. Only the
    adapter's parameters are handed to the optimizer, and only state-dict
    entries whose key contains ``'adapter'`` are kept in saved checkpoints.

    NOTE(review): the "Canny" in the class name suggests the extra condition
    is a Canny edge map; nothing in this class enforces that — confirm
    against the data pipeline / config.
    """

    def __init__(self, adapter_config, extra_cond_key, noise_schedule, *args, **kwargs):
        """Build the base LatentDiffusion model plus the adapter.

        Args:
            adapter_config: config passed to ``instantiate_from_config`` to
                build ``self.adapter``.
            extra_cond_key: batch key holding the adapter's conditioning input.
            noise_schedule: forwarded to ``get_time_with_schedule`` when
                sampling training timesteps in ``shared_step``.
            *args, **kwargs: forwarded unchanged to ``LatentDiffusion``.
        """
        super().__init__(*args, **kwargs)
        self.adapter = instantiate_from_config(adapter_config)
        self.extra_cond_key = extra_cond_key
        self.noise_schedule = noise_schedule

    def shared_step(self, batch, **kwargs):
        """Compute the training loss for one batch.

        Steps:
            1. Optional condition dropout for the keys in ``self.ucg_training``.
            2. Rescale of ``batch['jpg']``.
            3. Inputs fetched via ``get_input``.
            4. Adapter features computed from the extra condition.
            5. Timesteps sampled with ``self.noise_schedule``.
            6. Model forward pass.

        Note: ``batch`` is modified in place. Dropped entries become ``""``
        and ``batch['jpg']`` is overwritten with its rescaled value.

        Returns:
            ``(loss, loss_dict)`` exactly as returned by ``self(...)``.

        Raises:
            NotImplementedError: if a dropout is drawn for a ``ucg_training``
                key whose batch entry is not a ``list``.
        """
        # Condition dropout: each element of batch[k] is independently
        # replaced by "" with probability p. The RNG is consulted once per
        # element, in order, so the draw sequence is preserved. The type
        # check only fires when a drop is actually drawn.
        for k, p in self.ucg_training.items():
            entries = batch[k]
            for i in range(len(entries)):
                if self.ucg_prng.choice(2, p=[1 - p, p]):
                    if not isinstance(entries, list):
                        raise NotImplementedError("only text ucg is currently supported")
                    entries[i] = ""

        # NOTE(review): the image key is hard-coded to 'jpg' rather than
        # self.first_stage_key; presumably they match — confirm in config.
        # Assumes pixel values in [0, 1]; `* 2 - 1` maps them to [-1, 1].
        batch['jpg'] = batch['jpg'] * 2 - 1
        x, c = self.get_input(batch, self.first_stage_key)

        # super(LatentDiffusion, self) skips LatentDiffusion.get_input (which
        # returns an (x, c) pair) and uses the parent class's get_input to
        # fetch the extra condition tensor directly.
        extra_cond = super(LatentDiffusion, self).get_input(batch, self.extra_cond_key).to(self.device)
        features_adapter = self.adapter(extra_cond)

        # Presumably samples one timestep per batch element according to
        # self.noise_schedule — get_time_with_schedule is defined elsewhere.
        t = self.get_time_with_schedule(self.noise_schedule, x.size(0))
        loss, loss_dict = self(x, c, t=t, features_adapter=features_adapter)
        return loss, loss_dict

    def configure_optimizers(self):
        """Return an AdamW optimizer over the adapter's parameters only.

        No other parameters of the model are passed to the optimizer, so
        they are not updated by it. Uses ``self.learning_rate`` as the LR.
        """
        params = list(self.adapter.parameters())
        return torch.optim.AdamW(params, lr=self.learning_rate)

    def on_save_checkpoint(self, checkpoint):
        """Drop every state-dict entry whose key does not contain 'adapter'.

        Mutates ``checkpoint['state_dict']`` in place, so the saved file
        holds only the adapter's entries.
        """
        state_dict = checkpoint['state_dict']
        # Collect the doomed keys first: deleting from a dict while
        # iterating over it raises RuntimeError.
        for key in [k for k in state_dict if 'adapter' not in k]:
            del state_dict[key]

    def on_load_checkpoint(self, checkpoint):
        """Fill non-adapter weights from the current model before loading.

        For every key of this model's state dict that does not contain
        'adapter', ``checkpoint['state_dict'][key]`` is set to the model's
        current value. This overwrites any value the checkpoint already had
        for that key. Together with ``on_save_checkpoint``, this lets an
        adapter-only checkpoint supply a complete state dict.
        """
        # Build the model's state dict once. Calling self.state_dict() per
        # key re-walks the entire module tree on every iteration, which is
        # quadratic in the number of keys.
        own_state = self.state_dict()
        ckpt_state = checkpoint['state_dict']
        for name, value in own_state.items():
            if 'adapter' not in name:
                ckpt_state[name] = value