File size: 2,679 Bytes
8483373
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
import torchvision
import tqdm
import torchvision.transforms as transforms
from PIL import Image
import warnings
warnings.filterwarnings("ignore")



### run inversion  (optimize PC coefficients) given single image
def invert(network, unet, vae, text_encoder, tokenizer, prompt, noise_scheduler, epochs, image_path, mask_path, device, weight_decay=1e-10, lr=1e-1):
    """Optimize ``network``'s parameters against the image(s) under ``image_path``.

    Each step encodes an image to latents with ``vae``, adds scheduler noise at a
    random timestep, runs ``unet`` (inside ``with network:``) conditioned on the
    encoded ``prompt``, and minimizes the masked MSE between the UNet output and
    the sampled noise with Adam.

    Args:
        network: Object whose ``parameters()`` are optimized; it is also entered
            as a context manager around the UNet forward/backward pass.
        unet: Called as ``unet(noisy_latents, timesteps, text_embeddings)``; the
            ``.sample`` attribute of its result is used as the prediction.
            It is switched to train mode and left that way.
        vae: Provides ``encode(x).latent_dist.sample()``.
        text_encoder: Called on the token ids; element ``[0]`` of its output is
            used as the conditioning.
        tokenizer: Called on ``prompt``; must expose ``model_max_length``.
        prompt: Text prompt, encoded once and reused for every step.
        noise_scheduler: Provides ``config.num_train_timesteps`` and ``add_noise``.
        epochs: Number of passes over the dataset.
        image_path: Root directory handed to ``torchvision.datasets.ImageFolder``.
        mask_path: Path to a mask image; if falsy, an all-ones mask is used.
        device: Device the inputs and mask are moved to.
        weight_decay: Adam weight decay.
        lr: Adam learning rate.

    Returns:
        The same ``network`` object after optimization.
    """
    ### load mask
    if mask_path:
        resize_mask = transforms.Resize((64, 64), interpolation=transforms.InterpolationMode.BILINEAR)
        # Resizing yields a new in-memory image, so the file handle can be
        # released as soon as the resize is done.
        with Image.open(mask_path) as mask_image:
            mask = resize_mask(mask_image)
        # NOTE(review): pil_to_tensor does not rescale values, so an 8-bit mask
        # presumably enters the loss as 0..255 rather than 0..1, and a
        # multi-channel mask keeps all its channels -- confirm this is intended.
        mask = torchvision.transforms.functional.pil_to_tensor(mask).unsqueeze(0).to(device).bfloat16()
    else:
        mask = torch.ones((1, 1, 64, 64)).to(device).bfloat16()

    ### single image dataset
    image_transforms = transforms.Compose([
        transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.RandomCrop(512),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    train_dataset = torchvision.datasets.ImageFolder(root=image_path, transform=image_transforms)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)

    ### optimizer
    optim = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay)

    ### prompt conditioning
    # The prompt never changes, so tokenize and encode it once instead of on
    # every step. no_grad keeps the reused tensor free of an autograd graph
    # (a graph could not be backpropagated through on the second step); only
    # network.parameters() are handed to the optimizer.
    # NOTE(review): assumes the text encoder is deterministic (no active
    # dropout) and is not itself being trained -- confirm against the caller.
    with torch.no_grad():
        text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    ### training loop
    unet.train()
    for _epoch in tqdm.tqdm(range(epochs)):
        for batch, _label in train_dataloader:
            ### prepare inputs
            batch = batch.to(device).bfloat16()
            latents = vae.encode(batch).latent_dist.sample()
            # Hard-coded latent scale; presumably the VAE's scaling factor -- TODO confirm.
            latents = latents * 0.18215
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]

            # torch.randint already yields int64, so no extra cast is needed.
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            ### loss + sgd step
            with network:
                model_pred = unet(noisy_latents, timesteps, text_embeddings).sample
                # Loss is computed in float32; the mask restricts which latent
                # positions contribute.
                loss = torch.nn.functional.mse_loss(mask * model_pred.float(), mask * noise.float(), reduction="mean")
                optim.zero_grad()
                loss.backward()
                optim.step()

    ### return optimized network
    return network