import torch
import os
import torch.nn as nn

from .latent_diffusion import LatentDiffusion


class Diffpro_SDF(nn.Module):
    def __init__(
        self,
        ldm: LatentDiffusion,
    ):
        """
        cond_type: {chord, texture}
        cond_mode: {cond, mix, uncond}
            mix: use a special condition for unconditional learning with probability of 0.2
        use_enc: whether to use pretrained chord encoder to generate encoded condition
        """
        super(Diffpro_SDF, self).__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.ldm = ldm

    @classmethod
    def load_trained(
        cls,
        ldm,
        chkpt_fpath,
    ):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = cls(ldm)
        trained_learner = torch.load(chkpt_fpath, map_location=device)
        try:
            model.load_state_dict(trained_learner["model"])
        except RuntimeError:
            # Older checkpoints use the previous submodule names; remap them
            # to the renamed encoders before loading.
            model_dict = trained_learner["model"]
            model_dict = {
                k.replace('cond_enc', 'autoreg_cond_enc'): v for k, v in model_dict.items()
            }
            model_dict = {
                k.replace('style_enc', 'external_cond_enc'): v for k, v in model_dict.items()
            }
            model.load_state_dict(model_dict)
        return model

    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
        # One reverse (denoising) step of the underlying latent diffusion model.
        return self.ldm.p_sample(xt, t)

    def q_sample(self, x0: torch.Tensor, t: torch.Tensor):
        # Forward (noising) process: sample x_t from a clean latent x0 at step t.
        return self.ldm.q_sample(x0, t)
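
    # Illustrative round trip through the two samplers above (a sketch, not part
    # of the original module; the tensor shape and 1000-step horizon are
    # assumptions about the latent space and noise schedule):
    #   x0 = torch.randn(1, 4, 32, 32)
    #   t = torch.randint(0, 1000, (1,))
    #   xt = model.q_sample(x0, t)       # add noise at step t
    #   x_prev = model.p_sample(xt, t)   # one denoising step back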

    def get_loss_dict(self, batch, step):
        """
        z_y is the quantity the diffusion model needs to learn.
        """
        # x = batch.float().to(self.device)
        x = batch
        loss = self.ldm.loss(x)
        return {"loss": loss}
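

# Usage sketch (illustrative only; not part of the original module).
# `build_ldm()` and the checkpoint path are hypothetical placeholders, and the
# batch shape is an assumption about the latent representation:
#
#   ldm = build_ldm()                                    # construct a LatentDiffusion
#   model = Diffpro_SDF.load_trained(ldm, "weights.pt")  # or Diffpro_SDF(ldm) to train from scratch
#   batch = torch.randn(8, 4, 32, 32).to(model.device)
#   losses = model.get_loss_dict(batch, step=0)
#   losses["loss"].backward()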