# Spaces: Running on A10G
import argparse, os, sys, glob, datetime, yaml | |
import torch | |
import time | |
import numpy as np | |
from tqdm import trange | |
from omegaconf import OmegaConf | |
from PIL import Image | |
from ldm.models.diffusion.ddim import DDIMSampler | |
from ldm.util import instantiate_from_config | |
def rescale(x):
    """Map values from the [-1, 1] range to [0, 1] using ``(x + 1) / 2``."""
    return (x + 1.) / 2.
def custom_to_pil(x):
    """Convert a single image tensor with values in [-1, 1] to an RGB PIL image.

    Args:
        x: a 3-D tensor laid out channels-first (C, H, W); it is permuted to
            (H, W, C) before conversion. Values are clamped to [-1, 1] and
            mapped linearly to [0, 255].

    Returns:
        A ``PIL.Image.Image`` in "RGB" mode. Single-channel inputs become a
        grayscale image, which is then converted to RGB.
    """
    x = x.detach().cpu()
    x = torch.clamp(x, -1., 1.)
    x = (x + 1.) / 2.
    x = x.permute(1, 2, 0).numpy()
    # PIL cannot build an image from an (H, W, 1) array, so drop the
    # singleton channel and let fromarray produce a grayscale image instead.
    if x.shape[-1] == 1:
        x = x[..., 0]
    # astype truncates toward zero, as custom_to_np's uint8 cast does.
    x = (255 * x).astype(np.uint8)
    x = Image.fromarray(x)
    if x.mode != "RGB":
        x = x.convert("RGB")
    return x
def custom_to_np(x):
    """Convert a batch of [-1, 1] images to channels-last uint8, ADM style.

    Mirrors the post-processing in OpenAI guided-diffusion's sampling script:
    https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py

    Args:
        x: 4-D image batch; assumed to be (N, C, H, W) since axis 1 is moved
            last (TODO confirm against callers).

    Returns:
        A contiguous CPU ``torch.uint8`` tensor with axes permuted
        (0, 2, 3, 1). Despite the name, this is a torch tensor, not a NumPy
        array. Values are computed as ``(x + 1) * 127.5``, clamped to
        [0, 255], and truncated.
    """
    scaled = (x.detach().cpu() + 1) * 127.5
    as_bytes = scaled.clamp(0, 255).to(torch.uint8)
    return as_bytes.permute(0, 2, 3, 1).contiguous()
def logs2pil(logs, keys=None):
    """Convert every entry of a log dict to a PIL image, best effort.

    Args:
        logs: mapping from log name to value. Values with a 4-D ``shape``
            contribute their first element; values with a 3-D ``shape`` are
            converted directly via ``custom_to_pil``.
        keys: accepted for backward compatibility but unused; every key in
            ``logs`` is converted.

    Returns:
        dict with the same keys as ``logs``. An entry is ``None`` when its value
        is not 3-D/4-D, has no ``shape`` (e.g. floats such as timings), or
        fails to convert.
    """
    imgs = {}
    for k, value in logs.items():
        try:
            ndim = len(value.shape)
            if ndim == 4:
                img = custom_to_pil(value[0, ...])
            elif ndim == 3:
                img = custom_to_pil(value)
            else:
                print(f"Unknown format for key {k}. ")
                img = None
        except Exception:
            # Deliberate best effort: any non-convertible entry maps to None.
            # Catch Exception rather than everything so that KeyboardInterrupt
            # and SystemExit still propagate.
            img = None
        imgs[k] = img
    return imgs
def convsample(model, shape, return_intermediates=True,
               verbose=True,
               make_prog_row=False):
    """Sample with the model's own (non-DDIM) sampling methods.

    ``None`` is passed as the first argument of the model method (presumably
    the conditioning, i.e. unconditional sampling; confirm against the model).

    Args:
        model: object providing ``p_sample_loop`` and ``progressive_denoising``.
        shape: full sample shape, forwarded unchanged.
        return_intermediates: forwarded to ``p_sample_loop``. It is not used
            when ``make_prog_row`` is True.
        verbose: forwarded to whichever model method is called.
        make_prog_row: call ``progressive_denoising`` instead of ``p_sample_loop``.

    Returns:
        Whatever the selected model method returns.
    """
    if make_prog_row:
        # Honour the caller's ``verbose``. Its default of True keeps the
        # default output unchanged.
        return model.progressive_denoising(None, shape, verbose=verbose)
    return model.p_sample_loop(None, shape,
                               return_intermediates=return_intermediates,
                               verbose=verbose)
def convsample_ddim(model, steps, shape, eta=1.0):
    """Sample one batch with a ``DDIMSampler`` wrapped around ``model``.

    Args:
        model: model passed to ``DDIMSampler``.
        steps: number of DDIM sampling steps, passed as the first argument of
            ``DDIMSampler.sample``.
        shape: full shape including the batch dimension. ``shape[0]`` is used
            as the batch size and ``shape[1:]`` as the per-sample shape.
        eta: DDIM eta forwarded to the sampler.

    Returns:
        ``(samples, intermediates)`` as produced by ``DDIMSampler.sample``
        (called with ``verbose=False``).
    """
    sampler = DDIMSampler(model)
    batch_size, per_sample_shape = shape[0], shape[1:]
    result = sampler.sample(
        steps,
        batch_size=batch_size,
        shape=per_sample_shape,
        eta=eta,
        verbose=False,
    )
    samples, intermediates = result
    return samples, intermediates
def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,):
    """Sample one batch, decode it, and report timing.

    The batch shape is ``[batch_size, in_channels, image_size, image_size]``,
    read from ``model.model.diffusion_model``. Sampling runs inside
    ``model.ema_scope("Plotting")``. With ``vanilla=True`` it uses the model's
    progressive-denoising loop; otherwise it uses DDIM with ``custom_steps``
    steps and ``eta``.

    Returns:
        dict with ``"sample"`` (result of ``model.decode_first_stage``),
        ``"time"`` (seconds spent sampling, excluding decoding) and
        ``"throughput"`` (samples per second over that interval).
    """
    unet = model.model.diffusion_model
    side = unet.image_size
    shape = [batch_size, unet.in_channels, side, side]

    with model.ema_scope("Plotting"):
        start = time.time()
        if vanilla:
            sample, _ = convsample(model, shape, make_prog_row=True)
        else:
            sample, _ = convsample_ddim(model, steps=custom_steps, shape=shape, eta=eta)
        stop = time.time()

    elapsed = stop - start
    log = {
        "sample": model.decode_first_stage(sample),
        "time": elapsed,
        "throughput": sample.shape[0] / elapsed,
    }
    print(f'Throughput for this batch: {log["throughput"]}')
    return log
def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None):
    """Draw ``n_samples`` unconditional samples, saving PNGs and one ``.npz`` archive.

    Args:
        model: diffusion model. Only models with ``cond_stage_model is None``
            are supported.
        logdir: directory that receives ``sample_XXXXXX.png`` files. Numbering
            continues after any PNGs already present.
        batch_size: samples drawn per batch.
        vanilla: use the model's own sampling loop instead of DDIM.
        custom_steps: number of DDIM steps (used only when not ``vanilla``).
        eta: DDIM eta.
        n_samples: number of samples to draw. The last batch may overshoot;
            the surplus is trimmed from the ``.npz``.
        nplog: directory for the ``.npz`` of all samples. Defaults to ``logdir``.

    Raises:
        NotImplementedError: if the model is conditional.
    """
    if vanilla:
        print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.')
    else:
        print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}')

    tstart = time.time()
    # Start numbering after any PNGs already in logdir (0 for a fresh
    # directory) so existing files are not overwritten.
    n_saved = len(glob.glob(os.path.join(logdir, '*.png')))
    if model.cond_stage_model is not None:
        raise NotImplementedError('Currently only sampling for unconditional models supported.')
    if nplog is None:
        nplog = logdir

    all_images = []
    print(f"Running unconditional sampling for {n_samples} samples")
    # Ceil division: a final partial batch is drawn when n_samples is not a
    # multiple of batch_size (floor division silently dropped it, and drew
    # nothing at all when n_samples < batch_size).
    n_batches = (n_samples + batch_size - 1) // batch_size
    for _ in trange(n_batches, desc="Sampling Batches (unconditional)"):
        logs = make_convolutional_sample(model, batch_size=batch_size,
                                         vanilla=vanilla, custom_steps=custom_steps,
                                         eta=eta)
        n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample")
        all_images.append(custom_to_np(logs["sample"]))
        if n_saved >= n_samples:
            print(f'Finish after generating {n_saved} samples')
            break

    # np.concatenate raises on an empty list, e.g. when n_samples <= 0.
    if all_images:
        all_img = np.concatenate(all_images, axis=0)[:n_samples]
        shape_str = "x".join(str(dim) for dim in all_img.shape)
        nppath = os.path.join(nplog, f"{shape_str}-samples.npz")
        np.savez(nppath, all_img)

    print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.")
def save_logs(logs, path, n_saved=0, key="sample", np_path=None):
    """Write the batch stored under ``logs[key]`` to disk.

    Args:
        logs: mapping of log name to value. Nothing is written if ``key`` is
            absent.
        path: directory for per-image PNGs named ``{key}_{index:06}.png``
            (used when ``np_path`` is None).
        n_saved: running sample counter, used for file names.
        key: which entry of ``logs`` to save.
        np_path: if given, save the whole batch as a single
            ``{n_saved}-{shape}-samples.npz`` in this directory instead of PNGs.

    Returns:
        The counter advanced by the number of samples written.
    """
    if key not in logs:
        return n_saved
    batch = logs[key]

    if np_path is not None:
        npbatch = custom_to_np(batch)
        shape_str = "x".join(map(str, npbatch.shape))
        np.savez(os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz"), npbatch)
        return n_saved + npbatch.shape[0]

    for image in batch:
        target = os.path.join(path, f"{key}_{n_saved:06}.png")
        custom_to_pil(image).save(target)
        n_saved += 1
    return n_saved
def get_parser():
    """Build the command-line parser for this sampling script.

    Options are ``-r/--resume``, ``-n/--n_samples``, ``-e/--eta``,
    ``-v/--vanilla_sample``, ``-l/--logdir``, ``-c/--custom_steps`` and
    ``--batch_size``. Every option except ``--vanilla_sample`` takes an
    optional value (``nargs="?"``).
    """
    parser = argparse.ArgumentParser()
    # (flags, add_argument keyword arguments), registered in this order.
    options = (
        (("-r", "--resume"), {
            "type": str, "nargs": "?",
            "help": "load from logdir or checkpoint in logdir",
        }),
        (("-n", "--n_samples"), {
            "type": int, "nargs": "?", "default": 50000,
            "help": "number of samples to draw",
        }),
        (("-e", "--eta"), {
            "type": float, "nargs": "?", "default": 1.0,
            "help": "eta for ddim sampling (0.0 yields deterministic sampling)",
        }),
        (("-v", "--vanilla_sample"), {
            "default": False, "action": "store_true",
            "help": "vanilla sampling (default option is DDIM sampling)?",
        }),
        (("-l", "--logdir"), {
            "type": str, "nargs": "?", "default": "none",
            "help": "extra logdir",
        }),
        (("-c", "--custom_steps"), {
            "type": int, "nargs": "?", "default": 50,
            "help": "number of steps for ddim and fastdpm sampling",
        }),
        (("--batch_size",), {
            "type": int, "nargs": "?", "default": 10,
            "help": "the bs",
        }),
    )
    for flags, kwargs in options:
        parser.add_argument(*flags, **kwargs)
    return parser
def load_model_from_config(config, sd, gpu=True):
    """Instantiate a model from ``config``, load ``sd`` into it, and set eval mode.

    Args:
        config: config node passed to ``instantiate_from_config``.
        sd: state dict loaded with ``strict=False``, or ``None`` to skip loading
            and keep the freshly instantiated weights.
        gpu: move the model to CUDA. Defaults to True, matching the previous
            unconditional ``.cuda()`` call.

    Returns:
        The model, in eval mode.
    """
    model = instantiate_from_config(config)
    if sd is not None:
        result = model.load_state_dict(sd, strict=False)
        # strict=False tolerates key mismatches. Report them so that a wrong
        # checkpoint does not silently produce an un-loaded model.
        missing = getattr(result, "missing_keys", None) or []
        unexpected = getattr(result, "unexpected_keys", None) or []
        if missing:
            print(f"Missing keys when loading state dict: {len(missing)}")
        if unexpected:
            print(f"Unexpected keys when loading state dict: {len(unexpected)}")
    if gpu:
        model.cuda()
    model.eval()
    return model
def load_model(config, ckpt, gpu, eval_mode):
    """Build the model described by ``config.model``, optionally from a checkpoint.

    Args:
        config: config with a ``model`` entry.
        ckpt: checkpoint path. If falsy, no file is read and ``None`` is passed
            on as the state dict.
        gpu, eval_mode: currently unused; kept for interface compatibility.

    Returns:
        ``(model, global_step)``. ``global_step`` is ``None`` when no checkpoint
        is given or the checkpoint does not record it.
    """
    if ckpt:
        print(f"Loading model from {ckpt}")
        # NOTE(security): torch.load unpickles arbitrary objects. Only load
        # checkpoints from trusted sources.
        pl_sd = torch.load(ckpt, map_location="cpu")
        global_step = pl_sd.get("global_step")
        # Lightning checkpoints nest the weights under "state_dict". A plain
        # ``torch.save(model.state_dict())`` file is the state dict itself.
        state_dict = pl_sd.get("state_dict", pl_sd)
    else:
        state_dict = None
        global_step = None
    model = load_model_from_config(config.model, state_dict)
    return model, global_step
if __name__ == "__main__":
    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    # NOTE(review): presumably so modules referenced by the config can be
    # imported from the working directory; confirm this is still needed.
    sys.path.append(os.getcwd())

    parser = get_parser()
    # Unrecognised ``key=value`` arguments are merged into the config below.
    opt, unknown = parser.parse_known_args()

    # --resume is declared with nargs="?", so it can be None here. Report a
    # usage error instead of a TypeError from os.path.exists(None).
    if not opt.resume:
        parser.error("-r/--resume is required (checkpoint file or log directory)")
    if not os.path.exists(opt.resume):
        raise ValueError("Cannot find {}".format(opt.resume))

    if os.path.isfile(opt.resume):
        # A checkpoint file: its containing directory is the log directory.
        ckpt = opt.resume
        logdir = os.path.dirname(opt.resume)
        print(f'Logdir is {logdir}')
    elif os.path.isdir(opt.resume):
        logdir = opt.resume.rstrip("/")
        ckpt = os.path.join(logdir, "model.ckpt")
    else:
        raise ValueError(f"{opt.resume} is not a directory")

    # Check the literal path rather than globbing it, so that directory names
    # containing glob metacharacters ("[", "*", "?") still work.
    config_path = os.path.join(logdir, "config.yaml")
    if not os.path.isfile(config_path):
        raise ValueError("Cannot find {}".format(config_path))
    opt.base = [config_path]
    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)

    gpu = True
    eval_mode = True

    if opt.logdir != "none":
        # Keep the run directory's name and re-root it under --logdir.
        locallog = os.path.basename(os.path.normpath(logdir))
        print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'")
        logdir = os.path.join(opt.logdir, locallog)

    print(config)

    model, global_step = load_model(config, ckpt, gpu, eval_mode)
    print(f"global step: {global_step}")
    print(75 * "=")
    print("logging to:")
    # Zero-padded formatting fails on None, so fall back to a placeholder.
    step_dir = f"{global_step:08}" if global_step is not None else "unknown_step"
    logdir = os.path.join(logdir, "samples", step_dir, now)
    imglogdir = os.path.join(logdir, "img")
    numpylogdir = os.path.join(logdir, "numpy")

    os.makedirs(imglogdir)
    os.makedirs(numpylogdir)
    print(logdir)
    print(75 * "=")

    # write config out
    sampling_file = os.path.join(logdir, "sampling_config.yaml")
    sampling_conf = vars(opt)

    with open(sampling_file, 'w', encoding="utf-8") as f:
        yaml.dump(sampling_conf, f, default_flow_style=False)
    print(sampling_conf)

    run(model, imglogdir, eta=opt.eta,
        vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps,
        batch_size=opt.batch_size, nplog=numpylogdir)

    print("done.")