turn-the-cam-anonymous's picture
adding CLIP taming
1ed7deb
import argparse, os, sys, glob, math, time
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from main import instantiate_from_config, DataModuleFromConfig
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from tqdm import trange
def save_image(x, path):
    """Write one image tensor to ``path`` as an 8-bit image.

    ``x.shape`` must unpack into exactly three sizes and the first one must
    be 3. The leading axis is moved last, values are mapped with
    ``(v + 1.0) * 127.5``, clipped to ``[0, 255]`` and cast to ``uint8``
    before being handed to PIL.
    (Presumably ``x`` holds values in ``[-1, 1]`` -- confirm with callers.)
    """
    channels, _height, _width = x.shape
    assert channels == 3

    pixels = x.detach().cpu().numpy()
    pixels = pixels.transpose(1, 2, 0)  # leading axis goes last for PIL
    pixels = (pixels + 1.0) * 127.5
    pixels = pixels.clip(0, 255).astype(np.uint8)

    Image.fromarray(pixels).save(path)
def _window_offset(pos, extent, window=16):
    """Return where ``pos`` sits inside the ``window``-wide sliding crop.

    Near the start of the axis the crop is pinned to index 0, so the offset
    is ``pos`` itself; near the end the crop is pinned to ``extent`` so the
    offset counts back from ``window``; otherwise the crop is centred and the
    offset is ``window // 2``.
    """
    half = window // 2
    if pos <= half:
        return pos
    if extent - pos < half:
        return window - (extent - pos)
    return half


@torch.no_grad()
def run_conditional(model, dsets, outdir, top_k, temperature, batch_size=1):
    """Sample images conditioned on dataset examples and write them to disk.

    For every full batch of the selected dataset this writes, under
    ``outdir``: the input image (``originals``), the first-stage decode of
    its quantised code (``reconstructions``) and an image decoded from
    indices sampled position by position from ``model.transformer``
    (``samples``). The three sub-directories must already exist.

    Args:
        model: object providing ``get_input``, ``encode_to_z``,
            ``encode_to_c``, ``first_stage_model.decode``, ``transformer``,
            ``top_k_logits``, ``decode_to_img``, ``cond_stage_key`` and
            ``device``.
        dsets: object whose ``datasets`` attribute is a mapping of split name
            to dataset. With more than one split the alphabetically first
            one is used.
        outdir: directory containing ``originals``, ``reconstructions`` and
            ``samples``.
        top_k: if not ``None``, logits are filtered with
            ``model.top_k_logits`` before sampling.
        temperature: divisor applied to the logits before the softmax.
        batch_size: number of examples per step. A trailing partial batch is
            not processed (the range stops at ``len(dset) - batch_size``).
    """
    # Side length of the square index window given to the transformer; the
    # last ``window * window`` logits are the ones kept per step.
    window = 16

    if len(dsets.datasets) > 1:
        split = sorted(dsets.datasets.keys())[0]
        dset = dsets.datasets[split]
    else:
        dset = next(iter(dsets.datasets.values()))
    print("Dataset: ", dset.__class__.__name__)

    for start_idx in trange(0, len(dset) - batch_size + 1, batch_size):
        indices = list(range(start_idx, start_idx + batch_size))
        example = default_collate([dset[i] for i in indices])

        x = model.get_input("image", example).to(model.device)
        for i in range(x.shape[0]):
            save_image(x[i], os.path.join(outdir, "originals",
                                          "{:06}.png".format(indices[i])))

        cond_key = model.cond_stage_key
        c = model.get_input(cond_key, example).to(model.device)

        quant_z, z_indices = model.encode_to_z(x)
        quant_c, c_indices = model.encode_to_c(c)

        # cshape is indexed at [0], [2] and [3] below, i.e. it is used as
        # (batch, <unused here>, rows, cols).
        cshape = quant_z.shape

        xrec = model.first_stage_model.decode(quant_z)
        for i in range(xrec.shape[0]):
            save_image(xrec[i], os.path.join(outdir, "reconstructions",
                                             "{:06}.png".format(indices[i])))

        if cond_key == "segmentation":
            # get image from segmentation mask
            # NOTE(review): the converted ``c`` is not read again in this
            # function; kept so the call sequence on the model is unchanged.
            num_classes = c.shape[1]
            c = torch.argmax(c, dim=1, keepdim=True)
            c = torch.nn.functional.one_hot(c, num_classes=num_classes)
            c = c.squeeze(1).permute(0, 3, 1, 2).float()
            c = model.cond_stage_model.to_rgb(c)

        idx = z_indices

        half_sample = False
        if half_sample:
            start = idx.shape[1] // 2
        else:
            start = 0

        # Blank out every index from ``start`` on; they are re-filled by the
        # sampling loop below.
        idx[:, start:] = 0
        idx = idx.reshape(cshape[0], cshape[2], cshape[3])
        start_i = start // cshape[3]
        start_j = start % cshape[3]

        cidx = c_indices
        cidx = cidx.reshape(quant_c.shape[0], quant_c.shape[2], quant_c.shape[3])

        sample = True

        for i in range(start_i, cshape[2]):
            local_i = _window_offset(i, cshape[2], window)
            for j in range(start_j, cshape[3]):
                local_j = _window_offset(j, cshape[3], window)

                # Crop a window x window neighbourhood that contains (i, j).
                i_start = i - local_i
                i_end = i_start + window
                j_start = j - local_j
                j_end = j_start + window

                patch = idx[:, i_start:i_end, j_start:j_end]
                patch = patch.reshape(patch.shape[0], -1)
                cpatch = cidx[:, i_start:i_end, j_start:j_end]
                cpatch = cpatch.reshape(cpatch.shape[0], -1)
                # Conditioning indices first, image indices after them.
                patch = torch.cat((cpatch, patch), dim=1)

                logits, _ = model.transformer(patch[:, :-1])
                logits = logits[:, -(window * window):, :]
                logits = logits.reshape(cshape[0], window, window, -1)
                logits = logits[:, local_i, local_j, :]

                logits = logits / temperature

                if top_k is not None:
                    logits = model.top_k_logits(logits, top_k)
                # apply softmax to convert to probabilities
                probs = torch.nn.functional.softmax(logits, dim=-1)
                # sample from the distribution or take the most likely
                if sample:
                    ix = torch.multinomial(probs, num_samples=1)
                else:
                    _, ix = torch.topk(probs, k=1, dim=-1)
                # ``ix`` has a trailing axis of size 1 while ``idx[:, i, j]``
                # has only the batch axis; drop the extra axis so the
                # assignment also works for batch_size > 1.
                idx[:, i, j] = ix.squeeze(-1)

        xsample = model.decode_to_img(idx[:, :cshape[2], :cshape[3]], cshape)
        for i in range(xsample.shape[0]):
            save_image(xsample[i], os.path.join(outdir, "samples",
                                                "{:06}.png".format(indices[i])))
def get_parser():
    """Build the command-line parser for this script.

    Options: ``-r/--resume``, ``-b/--base``, ``-c/--config``,
    ``--ignore_base_data``, ``--outdir`` (required), ``--top_k`` and
    ``--temperature``.
    """
    # (flags, keyword options) in the order they are registered.
    argument_table = (
        (
            ("-r", "--resume"),
            {
                "type": str,
                "nargs": "?",
                "help": "load from logdir or checkpoint in logdir",
            },
        ),
        (
            ("-b", "--base"),
            {
                "nargs": "*",
                "metavar": "base_config.yaml",
                "help": "paths to base configs. Loaded from left-to-right. "
                        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
                "default": [],
            },
        ),
        (
            ("-c", "--config"),
            {
                "nargs": "?",
                "metavar": "single_config.yaml",
                "help": "path to single config. If specified, base configs will be ignored "
                        "(except for the last one if left unspecified).",
                "const": True,
                "default": "",
            },
        ),
        (
            ("--ignore_base_data",),
            {
                "action": "store_true",
                "help": "Ignore data specification from base configs. Useful if you want "
                        "to specify a custom datasets on the command line.",
            },
        ),
        (
            ("--outdir",),
            {
                "required": True,
                "type": str,
                "help": "Where to write outputs to.",
            },
        ),
        (
            ("--top_k",),
            {
                "type": int,
                "default": 100,
                "help": "Sample from among top-k predictions.",
            },
        ),
        (
            ("--temperature",),
            {
                "type": float,
                "default": 1.0,
                "help": "Sampling temperature.",
            },
        ),
    )

    cli_parser = argparse.ArgumentParser()
    for flags, options in argument_table:
        cli_parser.add_argument(*flags, **options)
    return cli_parser
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    """Instantiate the model described by ``config`` and optionally load weights.

    Before instantiation every ``ckpt_path`` found in ``config.params`` and
    in the first-stage / cond-stage sub-configs is set to ``None``, and a
    ``downsample_cond_size`` entry is replaced by
    ``downsample_cond_factor = 0.5``. ``config`` is modified in place.

    Args:
        config: model config exposing ``params`` (attribute access, ``in``
            tests and item assignment are used on it).
        sd: state dict passed to ``model.load_state_dict(..., strict=False)``,
            or ``None`` to skip loading. Missing and unexpected keys are
            printed.
        gpu: if true, ``model.cuda()`` is called.
        eval_mode: if true, ``model.eval()`` is called.

    Returns:
        ``{"model": model}``.
    """
    if "ckpt_path" in config.params:
        print("Deleting the restore-ckpt path from the config...")
        config.params.ckpt_path = None
    if "downsample_cond_size" in config.params:
        print("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
        config.params.downsample_cond_size = -1
        config.params["downsample_cond_factor"] = 0.5

    # Clearing the sub-config checkpoint paths is best-effort: a config
    # without one of the stages must not abort loading. Each stage is handled
    # on its own so a missing first stage does not skip the cond stage, and
    # only ``Exception`` is caught so Ctrl-C / SystemExit still propagate.
    for stage_key, stage_label in (("first_stage_config", "first-stage"),
                                   ("cond_stage_config", "cond-stage")):
        try:
            stage_params = getattr(config.params, stage_key).params
            if "ckpt_path" in stage_params:
                stage_params.ckpt_path = None
                print(f"Deleting the {stage_label} restore-ckpt path from the config...")
        except Exception:
            continue

    model = instantiate_from_config(config)
    if sd is not None:
        # strict=False: report mismatching keys instead of failing on them.
        missing, unexpected = model.load_state_dict(sd, strict=False)
        print(f"Missing Keys in State Dict: {missing}")
        print(f"Unexpected Keys in State Dict: {unexpected}")
    if gpu:
        model.cuda()
    if eval_mode:
        model.eval()
    return {"model": model}
def get_data(config):
    """Instantiate ``config.data`` and run its preparation hooks.

    Calls ``prepare_data()`` and then ``setup()`` on the instantiated object
    and returns that object.
    """
    datamodule = instantiate_from_config(config.data)
    datamodule.prepare_data()
    datamodule.setup()
    return datamodule
def load_model_and_dset(config, ckpt, gpu, eval_mode):
    """Load the datasets and the model, optionally restoring a checkpoint.

    Args:
        config: full config; ``config.data`` and ``config.model`` are used.
        ckpt: checkpoint path, or a falsy value to build the model without
            loading weights.
        gpu: forwarded to ``load_model_from_config``.
        eval_mode: forwarded to ``load_model_from_config``.

    Returns:
        ``(dsets, model, global_step)``; ``global_step`` is ``None`` when no
        checkpoint was given.
    """
    # Data first, then the model.
    dsets = get_data(config)

    # Defaults for the "no checkpoint" case.
    checkpoint = {"state_dict": None}
    global_step = None
    if ckpt:
        # NOTE(review): torch.load unpickles the file -- only use trusted
        # checkpoints.
        checkpoint = torch.load(ckpt, map_location="cpu")
        global_step = checkpoint["global_step"]

    loaded = load_model_from_config(
        config.model,
        checkpoint["state_dict"],
        gpu=gpu,
        eval_mode=eval_mode,
    )
    return dsets, loaded["model"], global_step
if __name__ == "__main__":
    # Script entry point: resolve the checkpoint and configs from the command
    # line, load data and model, then write originals, reconstructions and
    # samples below ``--outdir``.
    sys.path.append(os.getcwd())

    parser = get_parser()
    # Unrecognised arguments are kept and merged into the config as a dotlist.
    opt, unknown = parser.parse_known_args()

    ckpt = None
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            # A checkpoint file was given: derive the log directory from it.
            paths = opt.resume.split("/")
            try:
                # Keep everything up to and including the entry that follows
                # the last "logs" path component.
                idx = len(paths) - paths[::-1].index("logs") + 1
            except ValueError:
                idx = -2  # take a guess: path/to/logdir/checkpoints/model.ckpt
            logdir = "/".join(paths[:idx])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
        print(f"logdir:{logdir}")
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
        opt.base = base_configs + opt.base

    if opt.config:
        if isinstance(opt.config, str):
            # An explicit single config replaces all base configs.
            opt.base = [opt.config]
        else:
            # Bare ``-c``: keep only the last base config.
            if not opt.base:
                raise ValueError(
                    "-c/--config was given without a path but no base config is available"
                )
            opt.base = [opt.base[-1]]

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    if opt.ignore_base_data:
        for config in configs:
            if hasattr(config, "data"):
                del config["data"]
    config = OmegaConf.merge(*configs, cli)

    print(ckpt)
    gpu = True
    eval_mode = True
    show_config = False
    if show_config:
        print(OmegaConf.to_container(config))

    dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
    print(f"Global step: {global_step}")

    # ``global_step`` is None when no checkpoint was loaded, and None cannot
    # be formatted with a width spec, so fall back to a plain tag.
    if global_step is not None:
        step_tag = "{:06}".format(global_step)
    else:
        step_tag = "none"
    outdir = os.path.join(opt.outdir, "{}_{}_{}".format(step_tag,
                                                        opt.top_k,
                                                        opt.temperature))
    os.makedirs(outdir, exist_ok=True)
    print("Writing samples to ", outdir)
    for k in ["originals", "reconstructions", "samples"]:
        os.makedirs(os.path.join(outdir, k), exist_ok=True)
    run_conditional(model, dsets, outdir, opt.top_k, opt.temperature)