File size: 1,750 Bytes
bf8981a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
import re

import torch
import torch.nn as nn

from .latent_diffusion import LatentDiffusion


class Diffpro_SDF(nn.Module):
    """Thin ``nn.Module`` wrapper around a latent diffusion model.

    Sampling and loss computation are delegated to the wrapped ``ldm``;
    :meth:`load_trained` restores weights from a checkpoint file.
    """

    # Matches the legacy sub-module name ``cond_enc`` only when it is not
    # already part of a current name (``autoreg_cond_enc`` or
    # ``external_cond_enc``), so migrating a key twice -- or migrating a key
    # that already uses the current naming -- leaves it unchanged.
    _LEGACY_COND_ENC = re.compile(r"(?<!autoreg_)(?<!external_)cond_enc")

    def __init__(
        self,
        ldm: LatentDiffusion,
    ):
        """Wrap ``ldm``.

        Args:
            ldm: Model that ``p_sample``, ``q_sample`` and ``loss`` calls are
                forwarded to.
        """
        super().__init__()
        # Only recorded here; this class never moves ``ldm`` or any input
        # tensors onto this device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.ldm = ldm

    @classmethod
    def _migrate_legacy_key(cls, key):
        """Rename legacy sub-module names in a state-dict key.

        ``cond_enc`` becomes ``autoreg_cond_enc`` and ``style_enc`` becomes
        ``external_cond_enc``. Keys already using the current names are
        returned unchanged.
        """
        key = cls._LEGACY_COND_ENC.sub("autoreg_cond_enc", key)
        return key.replace("style_enc", "external_cond_enc")

    @classmethod
    def load_trained(
        cls,
        ldm,
        chkpt_fpath,
    ):
        """Build a model around ``ldm`` and load weights from a checkpoint.

        The checkpoint must be a mapping whose ``"model"`` entry is a state
        dict for this class. If loading that state dict raises
        ``RuntimeError``, its keys are migrated from the legacy sub-module
        names (see :meth:`_migrate_legacy_key`) and loading is retried once.

        Args:
            ldm: Model to wrap; forwarded to the constructor.
            chkpt_fpath: Checkpoint location, passed straight to
                ``torch.load``.

        Returns:
            The new instance with the checkpoint weights loaded.

        Raises:
            KeyError: If the checkpoint has no ``"model"`` entry.
            RuntimeError: If the state dict does not match the model even
                after key migration.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = cls(ldm)
        # NOTE(security): torch.load unpickles the file. Only load
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(chkpt_fpath, map_location=device)
        state_dict = checkpoint["model"]
        try:
            model.load_state_dict(state_dict)
        except RuntimeError:
            # Fall back to the legacy key layout. If this also fails, the
            # error propagates with the first failure attached as context.
            migrated = {
                cls._migrate_legacy_key(key): value
                for key, value in state_dict.items()
            }
            model.load_state_dict(migrated)
        return model

    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
        """Return ``ldm.p_sample(xt, t)``."""
        return self.ldm.p_sample(xt, t)

    def q_sample(self, x0: torch.Tensor, t: torch.Tensor):
        """Return ``ldm.q_sample(x0, t)``."""
        return self.ldm.q_sample(x0, t)

    def get_loss_dict(self, batch, step):
        """Compute the loss for one batch.

        Args:
            batch: Passed unchanged to ``ldm.loss``; it is neither cast nor
                moved to a device here.
            step: Unused by this method; kept so the call signature stays
                the same for existing callers.

        Returns:
            A dict with the single key ``"loss"`` holding the value returned
            by ``ldm.loss(batch)``.
        """
        return {"loss": self.ldm.loss(batch)}