ID-Pose / src /sampling.py
tokenid
upload
917fe92
import math
import numpy as np
from contextlib import nullcontext
from PIL import Image
from einops import rearrange
import torch
from torch import autocast
from .ldm.models.diffusion.ddim import DDIMSampler
@torch.no_grad()
def sample_model(input_im, model, sampler, precision, h, w, ddim_steps, n_samples, scale, \
ddim_eta, x, y, z):
precision_scope = autocast if precision=='autocast' else nullcontext
with precision_scope('cuda'):
with model.ema_scope():
c = model.get_learned_conditioning(input_im).tile(n_samples,1,1)
T = torch.tensor([x, math.sin(y), math.cos(y), z])
T = T[None, None, :].repeat(n_samples, 1, 1).to(c.device)
c = torch.cat([c, T], dim=-1)
c = model.cc_projection(c)
cond = {}
cond['c_crossattn'] = [c]
c_concat = model.encode_first_stage((input_im.to(c.device))).mode().detach()
cond['c_concat'] = [model.encode_first_stage((input_im.to(c.device))).mode().detach()\
.repeat(n_samples, 1, 1, 1)]
if scale != 1.0:
uc = {}
uc['c_concat'] = [torch.zeros(n_samples, 4, h // 8, w // 8).to(c.device)]
uc['c_crossattn'] = [torch.zeros_like(c).to(c.device)]
else:
uc = None
shape = [4, h // 8, w // 8]
samples_ddim, _ = sampler.sample(S=ddim_steps,
conditioning=cond,
batch_size=n_samples,
shape=shape,
verbose=False,
unconditional_guidance_scale=scale,
unconditional_conditioning=uc,
eta=ddim_eta,
x_T=None)
# samples_ddim = torch.nn.functional.interpolate(samples_ddim, 64, mode='nearest', antialias=False)
x_samples_ddim = model.decode_first_stage(samples_ddim)
return torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).cpu()
def sample_images(
model,
input_im,
x=0.,
y=0.,
z=0.,
scale=3.0,
n_samples=4,
ddim_steps=50,
ddim_eta=1.0,
precision='fp32',
h=256,
w=256,
):
sampler = DDIMSampler(model)
x_samples_ddim = sample_model(input_im, model, sampler, precision, h, w,\
ddim_steps, n_samples, scale, ddim_eta, x, y, z)
output_ims = []
for x_sample in x_samples_ddim:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
output_ims.append(x_sample.astype(np.uint8))
return output_ims