|
import argparse, os, sys, glob |
|
import torch |
|
import time |
|
import numpy as np |
|
from omegaconf import OmegaConf |
|
from PIL import Image |
|
from tqdm import tqdm, trange |
|
from einops import repeat |
|
|
|
from main import instantiate_from_config |
|
from taming.modules.transformer.mingpt import sample_with_past |
|
|
|
|
|
def rescale(x):
    """Affinely map values from [-1, 1] onto [0, 1].

    Only ``+`` and ``/`` are used, so this works element-wise on Python
    floats, NumPy arrays and torch tensors alike.
    """
    return (x + 1.) / 2.
|
|
|
|
|
def chw_to_pillow(x):
    """Convert a CHW image tensor into an 8-bit PIL image.

    The tensor is detached, moved to the CPU, and reordered from CHW to HWC.
    Its values are then mapped from [-1, 1] to [0, 255] via ``rescale``.
    Anything outside that range is clipped before the cast to ``uint8``.
    """
    hwc = x.detach().cpu().numpy().transpose(1, 2, 0)
    pixels = (255 * rescale(hwc)).clip(0, 255)
    return Image.fromarray(pixels.astype(np.uint8))
|
|
|
|
|
@torch.no_grad()
def sample_classconditional(model, batch_size, class_label, steps=256, temperature=None, top_k=None, callback=None,
                            dim_z=256, h=16, w=16, verbose_time=False, top_p=None):
    """Draw ``batch_size`` samples of one class from a class-conditional model.

    Args:
        model: A class-conditional Net2NetTransformer (``model.be_unconditional``
            must be falsy). Must expose ``device``, ``transformer`` and
            ``decode_to_img``.
        batch_size: Number of samples to draw.
        class_label: Integer class index used as the conditioning token.
            Python ints and NumPy integer scalars are accepted; bools are not.
        steps: Forwarded to ``sample_with_past`` (presumably the number of
            tokens generated per sample).
        temperature, top_k, top_p, callback: Forwarded to ``sample_with_past``.
        dim_z, h, w: Latent shape for decoding; ``decode_to_img`` receives
            ``[batch_size, dim_z, h, w]``.
        verbose_time: If true, print how long sampling took.

    Returns:
        dict with ``"samples"`` (output of ``model.decode_to_img``) and
        ``"class_label"`` (the ``(batch_size, 1)`` conditioning indices).
    """
    # bool is a subclass of int, so exclude it explicitly (the original exact
    # type check rejected bools as well).
    assert isinstance(class_label, (int, np.integer)) and not isinstance(class_label, bool), \
        f'expecting type int but type is {type(class_label)}'
    assert not model.be_unconditional, 'Expecting a class-conditional Net2NetTransformer.'

    qzshape = [batch_size, dim_z, h, w]
    # One conditioning token per sample: shape (batch_size, 1).
    c_indices = repeat(torch.tensor([class_label]), '1 -> b 1', b=batch_size).to(model.device)

    # perf_counter is monotonic (unlike time.time), so it is the right clock
    # for measuring an elapsed duration.
    t_start = time.perf_counter()
    index_sample = sample_with_past(c_indices, model.transformer, steps=steps,
                                    sample_logits=True, top_k=top_k, callback=callback,
                                    temperature=temperature, top_p=top_p)
    if verbose_time:
        sampling_time = time.perf_counter() - t_start
        print(f"Full sampling takes about {sampling_time:.2f} seconds.")

    x_sample = model.decode_to_img(index_sample, qzshape)
    return {"samples": x_sample, "class_label": c_indices}
|
|
|
|
|
@torch.no_grad()
def sample_unconditional(model, batch_size, steps=256, temperature=None, top_k=None, top_p=None, callback=None,
                         dim_z=256, h=16, w=16, verbose_time=False):
    """Draw ``batch_size`` unconditional samples from ``model``.

    Every sequence is seeded with ``model.sos_token`` as its single
    conditioning token.

    Args:
        model: An unconditional model (``model.be_unconditional`` must be
            truthy) exposing ``sos_token``, ``device``, ``transformer`` and
            ``decode_to_img``.
        batch_size: Number of samples to draw.
        steps: Forwarded to ``sample_with_past`` (presumably the number of
            tokens generated per sample).
        temperature, top_k, top_p, callback: Forwarded to ``sample_with_past``.
        dim_z, h, w: Latent shape for decoding; ``decode_to_img`` receives
            ``[batch_size, dim_z, h, w]``.
        verbose_time: If true, print how long sampling took.

    Returns:
        dict with a single key ``"samples"`` holding the decoded images.
    """
    assert model.be_unconditional, 'Expecting an unconditional model.'

    qzshape = [batch_size, dim_z, h, w]
    # Start-of-sequence token repeated once per sample: shape (batch_size, 1).
    c_indices = repeat(torch.tensor([model.sos_token]), '1 -> b 1', b=batch_size).to(model.device)

    # perf_counter is monotonic (unlike time.time), so it is the right clock
    # for measuring an elapsed duration.
    t_start = time.perf_counter()
    index_sample = sample_with_past(c_indices, model.transformer, steps=steps,
                                    sample_logits=True, top_k=top_k, callback=callback,
                                    temperature=temperature, top_p=top_p)
    if verbose_time:
        sampling_time = time.perf_counter() - t_start
        print(f"Full sampling takes about {sampling_time:.2f} seconds.")

    x_sample = model.decode_to_img(index_sample, qzshape)
    return {"samples": x_sample}
|
|
|
|
|
@torch.no_grad()
def run(logdir, model, batch_size, temperature, top_k, unconditional=True, num_samples=50000,
        given_classes=None, top_p=None):
    """Generate samples batch by batch and save them as PNGs under ``logdir``.

    Args:
        logdir: Output directory (presumably created by the caller). In
            class-conditional mode, one sub-directory per label is used.
        model: The transformer model passed to the sampling helpers.
        batch_size: Maximum number of samples drawn per call.
        temperature, top_k, top_p: Sampling parameters forwarded to the helpers.
        unconditional: If true, draw ``num_samples`` unconditional samples.
            Otherwise draw ``num_samples`` samples for each entry of
            ``given_classes``.
        num_samples: Number of samples (per class in conditional mode).
        given_classes: Sequence of integer class labels; required when
            ``unconditional`` is false.

    Files are numbered ``batch_index * batch_size + i`` (zero-padded to six
    digits) via ``save_from_logs``.
    """
    # Full batches plus a single smaller remainder batch (only if there is a
    # remainder), so no empty batch is ever produced.
    n_full, remainder = divmod(num_samples, batch_size)
    batches = [batch_size] * n_full + ([remainder] if remainder else [])

    if not unconditional:
        assert given_classes is not None
        print("Running in pure class-conditional sampling mode. I will produce "
              f"{num_samples} samples for each of the {len(given_classes)} classes, "
              f"i.e. {num_samples*len(given_classes)} in total.")
        for class_label in tqdm(given_classes, desc="Classes"):
            # enumerate() hides the length from tqdm; pass it so the bar shows progress/ETA.
            for n, bs in tqdm(enumerate(batches), total=len(batches), desc="Sampling Class"):
                logs = sample_classconditional(model, batch_size=bs, class_label=class_label,
                                               temperature=temperature, top_k=top_k, top_p=top_p)
                save_from_logs(logs, logdir, base_count=n * batch_size, cond_key=logs["class_label"])
    else:
        print(f"Running in unconditional sampling mode, producing {num_samples} samples.")
        for n, bs in tqdm(enumerate(batches), total=len(batches), desc="Sampling"):
            logs = sample_unconditional(model, batch_size=bs, temperature=temperature, top_k=top_k, top_p=top_p)
            save_from_logs(logs, logdir, base_count=n * batch_size)
|
|
|
|
|
def save_from_logs(logs, logdir, base_count, key="samples", cond_key=None):
    """Write every image in ``logs[key]`` to disk as ``{count:06}.png``.

    Args:
        logs: Mapping holding a batch of CHW image tensors under ``key``.
        logdir: Target directory.
        base_count: Index of the first image; image ``i`` is saved as
            ``base_count + i``.
        key: Entry of ``logs`` to read images from.
        cond_key: Optional per-image labels. When given, image ``i`` is stored
            in ``logdir/<label>/`` (created on demand). Tensor labels are
            converted with ``.item()``.
    """
    created_dirs = set()
    for i, img in enumerate(logs[key]):
        pil_img = chw_to_pillow(img)
        filename = f"{base_count + i:06}.png"
        if cond_key is None:
            pil_img.save(os.path.join(logdir, filename))
            continue

        condlabel = cond_key[i]
        # isinstance also accepts Tensor subclasses, which an exact type check would miss.
        if isinstance(condlabel, torch.Tensor):
            condlabel = condlabel.item()
        class_dir = os.path.join(logdir, str(condlabel))
        # Create each label directory once instead of once per image.
        if class_dir not in created_dirs:
            os.makedirs(class_dir, exist_ok=True)
            created_dirs.add(class_dir)
        pil_img.save(os.path.join(class_dir, filename))
|
|
|
|
|
def get_parser():
    """Build the command-line parser for the sampling script.

    Unrecognised ``key=value`` arguments are meant to be collected with
    ``parse_known_args`` and merged into the config by the caller.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        nargs="?",
        help="load from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-o",
        "--outdir",
        type=str,
        nargs="?",
        help="path where the samples will be logged to.",
        default=""
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
             "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=list(),
    )
    parser.add_argument(
        "-n",
        "--num_samples",
        type=int,
        nargs="?",
        help="num_samples to draw",
        default=50000
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        nargs="?",
        help="the batch size",
        default=25
    )
    parser.add_argument(
        "-k",
        "--top_k",
        type=int,
        nargs="?",
        help="top-k value to sample with",
        default=250,
    )
    parser.add_argument(
        "-t",
        "--temperature",
        type=float,
        nargs="?",
        help="temperature value to sample with",
        default=1.0
    )
    parser.add_argument(
        "-p",
        "--top_p",
        type=float,
        nargs="?",
        help="top-p value to sample with",
        default=1.0
    )
    parser.add_argument(
        "--classes",
        type=str,
        nargs="?",
        help="specify comma-separated classes to sample from. Uses 1000 classes per default.",
        default="imagenet"
    )
    return parser
|
|
|
|
|
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    """Instantiate a model from ``config`` and optionally prepare it for inference.

    Args:
        config: Config node passed to ``instantiate_from_config``.
        sd: State dict to load with ``load_state_dict``, or ``None`` to keep
            the freshly initialised weights.
        gpu: If true, move the model to CUDA with ``.cuda()``.
        eval_mode: If true, switch the model to evaluation mode.

    Returns:
        A one-entry dict ``{"model": model}``.
    """
    net = instantiate_from_config(config)
    if sd is not None:
        net.load_state_dict(sd)
    if gpu:
        net.cuda()
    if eval_mode:
        net.eval()
    return dict(model=net)
|
|
|
|
|
def load_model(config, ckpt, gpu, eval_mode):
    """Build the model described by ``config.model`` and load ``ckpt`` into it.

    Args:
        config: Merged config; only ``config.model`` is used.
        ckpt: Path to a checkpoint containing a ``"state_dict"`` entry, or a
            falsy value to skip loading weights.
        gpu: Forwarded to ``load_model_from_config`` (move model to CUDA).
        eval_mode: Forwarded to ``load_model_from_config`` (call ``eval()``).

    Returns:
        ``(model, global_step)``. ``global_step`` is ``None`` when no
        checkpoint is given or the checkpoint does not record one.
    """
    if ckpt:
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        pl_sd = torch.load(ckpt, map_location="cpu")
        # Not every checkpoint records a step; don't crash on a missing key.
        global_step = pl_sd.get("global_step")
        print(f"loaded model from global step {global_step}.")
        state_dict = pl_sd["state_dict"]
    else:
        global_step = None
        state_dict = None
    model = load_model_from_config(config.model, state_dict, gpu=gpu, eval_mode=eval_mode)["model"]
    return model, global_step
|
|
|
|
|
if __name__ == "__main__":
    sys.path.append(os.getcwd())

    parser = get_parser()
    opt, unknown = parser.parse_known_args()
    # Validate explicitly: a bare assert would be skipped under `python -O`.
    if not opt.resume:
        parser.error("-r/--resume is required (a log directory or a checkpoint file).")
    if not os.path.exists(opt.resume):
        raise ValueError("Cannot find {}".format(opt.resume))

    if os.path.isfile(opt.resume):
        # Checkpoint file. The log dir ends one component after the last
        # "logs" entry, e.g. logs/<run>/checkpoints/last.ckpt -> logs/<run>.
        # Without a "logs" component, the last two parts are dropped
        # (assumed to be checkpoints/<file>).
        paths = opt.resume.split("/")
        try:
            idx = len(paths) - paths[::-1].index("logs") + 1
        except ValueError:
            idx = -2
        logdir = "/".join(paths[:idx])
        ckpt = opt.resume
    else:
        assert os.path.isdir(opt.resume), opt.resume
        logdir = opt.resume.rstrip("/")
        ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")

    # Project configs saved with the run come first; explicit --base configs
    # and CLI dotlist overrides are merged on top (later wins).
    base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
    opt.base = base_configs + opt.base
    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)

    model, global_step = load_model(config, ckpt, gpu=True, eval_mode=True)

    if opt.outdir:
        print(f"Switching logdir from '{logdir}' to '{opt.outdir}'")
        logdir = opt.outdir

    if opt.classes == "imagenet":
        given_classes = list(range(1000))
    else:
        cls_str = opt.classes
        assert not cls_str.endswith(","), 'class string should not end with a ","'
        given_classes = [int(c) for c in cls_str.split(",")]

    # `-t` given without a value yields None (nargs="?"); `:.2f` would raise on it.
    temp_str = f"{opt.temperature:.2f}" if opt.temperature is not None else "None"
    logdir = os.path.join(logdir, "samples", f"top_k_{opt.top_k}_temp_{temp_str}_top_p_{opt.top_p}",
                          f"{global_step}")

    print(f"Logging to {logdir}")
    os.makedirs(logdir, exist_ok=True)

    run(logdir, model, opt.batch_size, opt.temperature, opt.top_k, unconditional=model.be_unconditional,
        given_classes=given_classes, num_samples=opt.num_samples, top_p=opt.top_p)

    print("done.")
|
|