import torch
from pytorch_lightning import seed_everything
from tqdm import tqdm

from lib.smplfusion import share, router, attentionpatch, transformerpatch
from lib.smplfusion.patches.attentionpatch import painta
from lib.utils import tokenize, scores
from lib.utils.iimage import IImage

verbose = False

def init_painta(token_idx):
    # Initialize painta:
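    # route the attention and transformer-block forwards through the PAIntA
    # patches. painta_res appears to limit the patched behavior to the 16x16
    # and 32x32 attention resolutions, and token_idx selects which prompt
    # tokens the patch uses.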
    router.attention_forward = attentionpatch.painta.forward
    router.basic_transformer_forward = transformerpatch.painta.forward
    painta.painta_on = True
    painta.painta_res = [16, 32]
    painta.token_idx = token_idx

def init_guidance():
    # Set up the model for guidance only:
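    # forward_and_save is the default attention forward that additionally
    # stores cross-attention similarity maps, consumed every step through
    # share._crossattn_similarity_res16.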
    router.attention_forward = attentionpatch.default.forward_and_save
    router.basic_transformer_forward = transformerpatch.default.forward

def run(ddim, method, prompt, image, mask, seed, eta, prefix,
        negative_prompt, positive_prompt, dt, guidance_scale):
    # Text condition
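    # The encoder packs the negative and the (prefixed) positive prompt into
    # a single two-row context, so one batched UNet call below produces both
    # the unconditional and the conditional noise estimates.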
    prompt = prefix.format(prompt)
    context = ddim.encoder.encode([negative_prompt, prompt + positive_prompt])
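    # token_idx collects the token positions of the user prompt inside the
    # tokenized string: everything after the prefix words (offset by 1 for
    # the <start_of_text> token, assuming one token per prefix word) up to
    # <end_of_text>, plus the <end_of_text> position of the full prompt.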
    token_idx = list(range(1 + prefix.split(' ').index('{}'), tokenize(prompt).index('<end_of_text>')))
    token_idx += [tokenize(prompt + positive_prompt).index('<end_of_text>')]

    # Initialize painta
    if 'painta' in method: init_painta(token_idx)
    else: init_guidance()

    # Image condition
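    # get_inpainting_condition presumably returns the extra input channels of
    # an SD-inpainting UNet (the downscaled mask and the VAE-encoded masked
    # image), which get concatenated to zt before every UNet call.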
    unet_condition = ddim.get_inpainting_condition(image, mask)
    share.set_mask(mask)
    dtype = unet_condition.dtype

    # Starting latent
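    # Pure Gaussian noise at the largest timestep (999); the spatial size is
    # taken from the inpainting condition tensor.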
    seed_everything(seed)
    zt = torch.randn((1, 4) + unet_condition.shape[2:]).cuda().to(dtype)

    # Set up the unet for guidance
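    # Gradient tracking is enabled on the UNet so the attention maps saved
    # during the forward pass stay connected to zt in the autograd graph.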
    ddim.unet.requires_grad_(True)

    pbar = tqdm(range(999, 0, -dt)) if verbose else range(999, 0, -dt)
    for timestep in share.DDIMIterator(pbar):
        # PAIntA is only applied during the early (high-noise) steps; after
        # timestep 500, fall back to plain guidance.
        if 'painta' in method and share.timestep <= 500: init_guidance()

        zt = zt.detach()
        zt.requires_grad = True

        # Reset storage
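        # forward_and_save (or the PAIntA patch) appends the resolution-16
        # cross-attention similarities here during the UNet call below.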
        share._crossattn_similarity_res16 = []

        # Run the model
        _zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)
        with torch.autocast('cuda'):
            eps_uncond, eps = ddim.unet(
                torch.cat([_zt, _zt]).to(dtype),
                timesteps=torch.tensor([timestep, timestep]).cuda(),
                context=context,
            ).detach().chunk(2)  # detached: gradients flow via the saved attention maps, not via eps

        # Classifier-free guidance
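        # eps = eps_uncond + s * (eps_cond - eps_uncond) with s = guidance_scale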
        eps = eps_uncond + guidance_scale * (eps - eps_uncond)
        # Predict the clean latent z0 by inverting the diffusion forward
        # process zt = sqrt(alpha_t) * z0 + sqrt(1 - alpha_t) * eps:
        z0 = (zt - share.schedule.sqrt_one_minus_alphas[timestep] * eps) / share.schedule.sqrt_alphas[timestep]

        # Gradient computation
        score = scores.bce(share._crossattn_similarity_res16, share.mask16, token_idx=token_idx)
        score.backward()
        grad = zt.grad.detach()
        ddim.unet.zero_grad()  # clear the accumulated UNet gradients right away

        # DDIM step
        with torch.no_grad():
            sigma = share.schedule.sigma(share.timestep, dt)

            # Standardization: zero mean, unit variance, like the noise it replaces
            grad -= grad.mean()
            grad /= grad.std()

            zt = (
                share.schedule.sqrt_alphas[share.timestep - dt] * z0
                + torch.sqrt(1 - share.schedule.alphas[share.timestep - dt] - sigma ** 2) * eps
                + eta * sigma * grad
            )

    # Decode the final clean-latent prediction back to image space
    with torch.no_grad():
        output_image = IImage(ddim.vae.decode(z0 / ddim.config.scale_factor))
    return output_image
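

# A minimal usage sketch, not a definitive entry point. It assumes a `ddim`
# wrapper constructed elsewhere in the repo (exposing .encoder, .unet, .vae,
# .config and .get_inpainting_condition) and IImage inputs; `build_ddim`, the
# IImage loader/saver calls, the file paths and the prefix string are all
# illustrative assumptions, not confirmed repo APIs.
if __name__ == '__main__':
    ddim = build_ddim()  # hypothetical factory for the inpainting DDIM wrapper
    image = IImage.open('examples/photo.png')  # assumed IImage loader
    mask = IImage.open('examples/mask.png')
    result = run(
        ddim,
        method='painta+rasg',    # any method string containing 'painta' enables PAIntA
        prompt='a red sports car',
        image=image, mask=mask,
        seed=1, eta=0.25,
        prefix='photo of a {}',  # must contain '{}' as a separate word
        negative_prompt='', positive_prompt='',
        dt=50,                   # 999 -> 0 in steps of 50, i.e. 20 denoising steps
        guidance_scale=7.5,
    )
    result.save('result.png')  # assumed IImage API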