Spaces:
Running
on
A10G
Running
on
A10G
File size: 3,470 Bytes
bfd34e9 f1cc496 bfd34e9 f1cc496 bfd34e9 736e88e f1cc496 736e88e bfd34e9 736e88e bfd34e9 da1e12f bfd34e9 da1e12f bfd34e9 da1e12f bfd34e9 f1cc496 bfd34e9 736e88e f1cc496 bfd34e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import torch
from src.utils.iimage import IImage
from pytorch_lightning import seed_everything
from tqdm import tqdm
from src.smplfusion import share, router, attentionpatch, transformerpatch
from src.smplfusion.patches.attentionpatch import painta
from src.utils import tokenize, scores
# Set to True to wrap the sampling timestep range in a tqdm progress bar.
verbose = False
def init_painta(token_idx):
    """Switch the router to the PAIntA forwards and configure the PAIntA patch.

    Args:
        token_idx: token indices stored on the ``painta`` patch module.
    """
    # Resolve both patched forwards, then install them on the router.
    attention_forward = attentionpatch.painta.forward
    transformer_forward = transformerpatch.painta.forward
    router.attention_forward = attention_forward
    router.basic_transformer_forward = transformer_forward

    # Patch settings; painta_res is presumably the list of attention
    # resolutions PAIntA acts on -- TODO confirm against the patch module.
    settings = {
        'painta_on': True,
        'painta_res': [16, 32],
        'token_idx': token_idx,
    }
    for attr_name, attr_value in settings.items():
        setattr(painta, attr_name, attr_value)
def init_guidance():
    """Switch the router to the default forwards used for guidance only.

    The attention forward is the ``forward_and_save`` variant of the default
    attention patch; the transformer forward is the plain default one.
    """
    router.attention_forward, router.basic_transformer_forward = (
        attentionpatch.default.forward_and_save,
        transformerpatch.default.forward,
    )
def run(
    ddim,
    method,
    prompt,
    image,
    mask,
    seed=0,
    eta=0.1,
    negative_prompt='',
    positive_prompt='',
    num_steps=50,
    guidance_scale=7.5
):
    """Inpaint ``image`` inside ``mask`` with DDIM sampling and score guidance.

    At every timestep the UNet is run with classifier-free guidance, a BCE
    score between the saved cross-attention similarities and the mask is
    back-propagated to the latent, and the standardised gradient is injected
    as the stochastic term of the DDIM update.

    Args:
        ddim: model wrapper; this function uses ``ddim.encoder.encode``,
            ``ddim.get_inpainting_condition``, ``ddim.unet``,
            ``ddim.vae.decode`` and ``ddim.config.scale_factor``.
        method: method name. If it contains ``'painta'`` the PAIntA patches
            are installed first and replaced by the plain guidance forwards
            once ``share.timestep <= 500``.
        prompt: text prompt; its tokens select the guided token indices.
        image: input image object exposing ``padx``.
        mask: mask object exposing ``dilate``, ``alpha`` and ``padx``.
        seed: value passed to ``seed_everything`` before the initial latent
            is sampled.
        eta: multiplier on the schedule's sigma in the DDIM step; it weights
            the gradient term.
        negative_prompt: text encoded as the unconditional context.
        positive_prompt: if non-empty, appended to ``prompt`` after ``', '``.
        num_steps: number of sampling steps; the timestep stride is
            ``1000 // num_steps``. Must be between 1 and 1000.
        guidance_scale: classifier-free guidance weight.

    Returns:
        An ``IImage`` wrapping the VAE decoding of the last predicted ``z0``.

    Raises:
        ValueError: if ``num_steps`` is outside ``[1, 1000]``.

    Side effects:
        Replaces the forwards on ``router``, may set the ``painta`` patch
        settings, sets the mask on ``share``, reseeds the RNGs, and enables
        ``requires_grad`` on the UNet parameters (it is not restored).
    """
    # Fail fast: the stride below must be a positive integer, otherwise the
    # timestep range is invalid (stride 0) or empty (negative stride), and
    # num_steps == 0 would raise an unrelated ZeroDivisionError.
    if not 1 <= num_steps <= 1000:
        raise ValueError(
            f'num_steps must be between 1 and 1000, got {num_steps}'
        )

    image = image.padx(64)
    mask = mask.dilate(1).alpha().padx(64)
    full_prompt = prompt
    if positive_prompt != '':
        full_prompt = f'{prompt}, {positive_prompt}'
    dt = 1000 // num_steps

    # Text condition: [negative, full] so the UNet output chunks below come
    # out as (unconditional, conditional).
    context = ddim.encoder.encode([negative_prompt, full_prompt])
    # Guided tokens: every token of `prompt` from position 1 up to (not
    # including) its end-of-text marker, plus the end-of-text position of
    # the full prompt.
    token_idx = list(range(1, tokenize(prompt).index('<end_of_text>')))
    token_idx += [tokenize(full_prompt).index('<end_of_text>')]

    # Initialize painta
    if 'painta' in method:
        init_painta(token_idx)
    else:
        init_guidance()

    # Image condition
    unet_condition = ddim.get_inpainting_condition(image, mask)
    share.set_mask(mask)
    dtype = unet_condition.dtype

    # Starting latent. Sampled on CPU and then moved, so the draw depends
    # only on the CPU generator seeded just above.
    seed_everything(seed)
    zt = torch.randn((1, 4) + unet_condition.shape[2:]).cuda().to(dtype)

    # Setup unet for guidance
    ddim.unet.requires_grad_(True)

    pbar = tqdm(range(999, 0, -dt)) if verbose else range(999, 0, -dt)
    # NOTE(review): share.DDIMIterator presumably keeps share.timestep in
    # sync with the yielded timestep -- confirm; both are read below.
    for timestep in share.DDIMIterator(pbar):
        if 'painta' in method and share.timestep <= 500:
            init_guidance()

        # Fresh leaf tensor so zt.grad holds only this step's gradient.
        zt = zt.detach()
        zt.requires_grad = True

        # Reset storage
        # (presumably refilled by the forward_and_save attention forward
        # during the UNet call -- TODO confirm)
        share._crossattn_similarity_res16 = []

        # Run the model on a duplicated batch: one copy per context entry.
        _zt = torch.cat([zt, unet_condition], 1)
        with torch.autocast('cuda'):
            eps_uncond, eps = ddim.unet(
                torch.cat([_zt, _zt]).to(dtype),
                timesteps=torch.tensor([timestep, timestep]).cuda(),
                context=context
            ).detach().chunk(2)

        # Unconditional guidance
        eps = (eps_uncond + guidance_scale * (eps - eps_uncond))
        z0 = (zt - share.schedule.sqrt_one_minus_alphas[timestep] * eps) / share.schedule.sqrt_alphas[timestep]

        # Gradient Computation: eps is detached above, so the gradient
        # reaches zt only through the saved attention similarities.
        score = scores.bce(share._crossattn_similarity_res16, share.mask16, token_idx=token_idx)
        score.backward()
        grad = zt.grad.detach()
        ddim.unet.zero_grad()

        # DDIM Step
        with torch.no_grad():
            sigma = share.schedule.sigma(share.timestep, dt)
            # Scale the gradient to unit std. A zero std would be a division
            # by zero that fills the latent with NaN, so normalisation is
            # skipped in that case and the gradient is used as is.
            grad_std = grad.std()
            if grad_std > 0:
                grad /= grad_std
            # NOTE(review): on the last iteration share.timestep - dt can be
            # zero or negative; the zt computed then is never used, because
            # the output below is decoded from z0.
            zt = share.schedule.sqrt_alphas[share.timestep - dt] * z0 + \
                torch.sqrt(1 - share.schedule.alphas[share.timestep - dt] - (eta * sigma) ** 2) * eps + \
                (eta * sigma) * grad

    with torch.no_grad():
        output_image = IImage(ddim.vae.decode(z0 / ddim.config.scale_factor))
    return output_image