import os
import torch
import numpy as np

from omegaconf import OmegaConf
from PIL import Image

from lidm.models.diffusion.ddim import DDIMSampler
from lidm.utils.misc_utils import instantiate_from_config, isimage, ismap
from lidm.utils.lidar_utils import range2pcd
from app_config import DEVICE


# DDIM sampling settings
CUSTOM_STEPS = 50  # number of denoising steps
ETA = 1.0          # DDIM eta (1.0 keeps full sampling stochasticity)

# model checkpoint and config paths
MODEL_PATH = './models/lidm/kitti/cam2lidar'
CFG_PATH = os.path.join(MODEL_PATH, 'config.yaml')
CKPT_PATH = os.path.join(MODEL_PATH, 'model.ckpt')

# model configuration
model_config = OmegaConf.load(CFG_PATH)


def custom_to_pcd(x, config, rgb=None):
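    """Rescale a range image from [-1, 1] to [0, 1] and convert it to a point cloud (xyz, rgb) via range2pcd."""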
    x = x.squeeze().detach().cpu().numpy()
    x = (np.clip(x, -1., 1.) + 1.) / 2.
    if rgb is not None:
        rgb = rgb.squeeze().detach().cpu().numpy()
        rgb = (np.clip(rgb, -1., 1.) + 1.) / 2.
        rgb = rgb.transpose(1, 2, 0)
    xyz, rgb, _ = range2pcd(x, color=rgb, **config['data']['params']['dataset'])

    return xyz, rgb


def custom_to_pil(x):
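    """Rescale a [-1, 1] tensor to 8-bit values and return it as a PIL image (CHW arrays are transposed to HWC)."""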
    x = x.detach().cpu().squeeze().numpy()
    x = (np.clip(x, -1., 1.) + 1.) / 2.
    x = (255 * x).astype(np.uint8)

    if x.ndim == 3:
        x = x.transpose(1, 2, 0)
    x = Image.fromarray(x)

    return x


def logs2pil(logs, keys=["sample"]):
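    """Convert each 3D/4D tensor in a sampling log dict to a PIL image; entries that fail become None."""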
    imgs = dict()
    for k in logs:
        try:
            if len(logs[k].shape) == 4:
                img = custom_to_pil(logs[k][0, ...])
            elif len(logs[k].shape) == 3:
                img = custom_to_pil(logs[k])
            else:
                print(f"Unknown format for key {k}. ")
                img = None
        except Exception:
            img = None
        imgs[k] = img
    return imgs


def load_model_from_config(config, sd, device):
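    """Instantiate the model from its config, load the state dict (non-strict), and put it in eval mode on device."""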
    model = instantiate_from_config(config)
    model.load_state_dict(sd, strict=False)
    model.to(device)
    model.eval()
    return model


def load_model():
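    """Load the camera-to-LiDAR checkpoint from CKPT_PATH and build the model on DEVICE."""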
    pl_sd = torch.load(CKPT_PATH, map_location="cpu")
    model = load_model_from_config(model_config.model, pl_sd["state_dict"], DEVICE)
    return model


@torch.no_grad()
def convsample_ddim(model, cond, steps, shape, eta=1.0, verbose=False):
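    """Run DDIM sampling; shape is (batch_size, *sample_shape). Returns the samples and intermediate states."""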
    ddim = DDIMSampler(model)
    bs = shape[0]
    shape = shape[1:]
    samples, intermediates = ddim.sample(steps, conditioning=cond, batch_size=bs, shape=shape, eta=eta, verbose=verbose, disable_tqdm=True)
    return samples, intermediates


@torch.no_grad()
def make_convolutional_sample(model, batch, batch_size, custom_steps=None, eta=1.0):
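    """Encode the camera condition, run DDIM sampling under the EMA weights, and decode the latents to range images."""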
    xc = batch['camera']
    c = model.get_learned_conditioning(xc.to(model.device))

    with model.ema_scope("Plotting"):
        samples, z_denoise_row = model.sample_log(cond=c, batch_size=batch_size, ddim=True,
                                                  ddim_steps=custom_steps, eta=eta)
    x_samples = model.decode_first_stage(samples)

    return x_samples


def sample(model, cond):
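    """Sample one range image from a camera condition and convert it to a float32 point cloud."""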
    batch = {'camera': cond}
    img = make_convolutional_sample(model, batch, batch_size=1, custom_steps=CUSTOM_STEPS, eta=ETA)  # TODO add arguments for batch_size, custom_steps and eta
    img = img[0, 0]
    pcd = custom_to_pcd(img, model_config)[0].astype(np.float32)
    return img, pcd
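

# Minimal usage sketch (an assumption, not part of the original app wiring): load the
# model once and sample from a camera conditioning tensor. The camera preprocessing and
# the exact input resolution depend on the dataset config and are only stubbed here.
if __name__ == "__main__":
    model = load_model()
    # hypothetical conditioning input: one camera image in [-1, 1]; replace with real preprocessing
    cond = torch.zeros(1, 3, 320, 1024)
    range_img, pcd = sample(model, cond)
    print(range_img.shape, pcd.shape)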