import numpy as np
import torch

# Public conversion helpers. The smoke-test entry points below are left out on
# purpose so that `from <module> import *` (and test collectors scanning
# star-imported names) do not pick them up.
__all__ = ["torch_samps_to_imgs", "imgs_to_torch"]


def torch_samps_to_imgs(imgs, uncenter=True):
    """Convert a batch of image tensors to displayable uint8 numpy images.

    Args:
        imgs: float tensor indexed as (N, C, H, W) -- the layout produced by
            ``imgs_to_torch``; it is permuted to channels-last here.
        uncenter: if True, ``imgs`` is treated as lying in [-1, 1] and first
            mapped to [0, 1]; if False it is assumed to already be in [0, 1].

    Returns:
        CPU ``np.ndarray`` of dtype ``uint8`` and shape (N, H, W, C).
    """
    if uncenter:
        imgs = (imgs + 1) / 2  # [-1, 1] -> [0, 1]
    # Round to the nearest level instead of truncating: float error can leave
    # an exact pixel value such as 128 at 127.99998, which truncation maps to
    # 127 and breaks the imgs_to_torch -> torch_samps_to_imgs round trip.
    imgs = (imgs * 255).round().clamp(0, 255)
    imgs = imgs.to(torch.uint8)
    imgs = imgs.permute(0, 2, 3, 1)
    return imgs.cpu().numpy()


def imgs_to_torch(imgs, multiple=32):
    """Convert a batch of uint8 RGB images to a float32 tensor in [-1, 1].

    Each spatial side is bilinearly resized down to the nearest multiple of
    ``multiple`` (32 by default -- presumably what the downstream model
    expects; confirm before changing the default).

    Args:
        imgs: ``np.ndarray`` of dtype ``uint8`` and shape (N, H, W, 3).
        multiple: each output side becomes ``side - side % multiple``.

    Returns:
        float32 tensor of shape (N, 3, H', W') with values in [-1, 1].

    Raises:
        AssertionError: if ``imgs`` is not uint8 or not shaped (N, H, W, 3)
            (note: these checks are skipped under ``python -O``).
        ValueError: if ``multiple`` < 1, or a side is shorter than ``multiple``.
    """
    assert imgs.dtype == np.uint8, f"expect uint8 images, got {imgs.dtype}"
    assert len(imgs.shape) == 4 and imgs.shape[-1] == 3, "expect (N, H, W, C)"
    if multiple < 1:
        raise ValueError(f"multiple must be >= 1, got {multiple}")

    _, H, W, _ = imgs.shape
    new_H, new_W = (side - side % multiple for side in (H, W))
    if new_H == 0 or new_W == 0:
        # Without this check interpolate() fails later with a less clear error.
        raise ValueError(
            f"image is {H}x{W}; both sides must be at least {multiple} px"
        )

    # NHWC -> NCHW, copied into a contiguous buffer so the tensor has
    # standard strides.
    imgs = np.ascontiguousarray(imgs.transpose(0, 3, 1, 2))
    imgs = (imgs / 255).astype(np.float32)
    imgs = imgs * 2 - 1  # [0, 1] -> [-1, 1]
    imgs = torch.as_tensor(imgs)
    # A same-size bilinear resize is an identity, so only resize when needed.
    if (new_H, new_W) != (H, W):
        imgs = torch.nn.functional.interpolate(
            imgs, (new_H, new_W), mode="bilinear", align_corners=False
        )
    return imgs


def test_encode_decode():
    """Manual smoke test: encode then decode ``~/clean.png`` and draw both.

    Requires a local image at ``~/clean.png`` plus the project modules
    ``run_img_sampling`` and ``vis`` (and whatever model ``SD().run()`` loads).
    """
    import os

    import imageio
    from run_img_sampling import ScoreAdapter, SD
    from vis import _draw

    fname = os.path.expanduser("~/clean.png")
    orig = imageio.imread(fname)
    raw = imgs_to_torch(orig[np.newaxis, ...])

    model: ScoreAdapter = SD().run()
    raw = raw.to(model.device)
    # Inference-only round trip: no autograd graph is needed.
    with torch.no_grad():
        zs = model.encode(raw)
        img = model.decode(zs)
    img = torch_samps_to_imgs(img)
    _draw(
        [orig, img.squeeze(0)],
    )


def test():
    """Run this module's manual smoke tests."""
    test_encode_decode()


if __name__ == "__main__":
    test()