File size: 3,643 Bytes
8483373
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
import torchvision
import os
import gc
import tqdm
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from lora_w2w import LoRAw2w
from transformers import AutoTokenizer, PretrainedConfig
from PIL import Image
import warnings
warnings.filterwarnings("ignore")




######## Editing Utilities

def get_direction(df, label, pinverse, return_dim, device):
    """Solve for the direction associated with an attribute column.

    The values of column ``label`` are gathered for every entry of
    ``df.index`` (in index order), cast to bfloat16, and multiplied by
    ``pinverse`` to obtain the least-squares solution.

    Args:
        df: pandas DataFrame holding one row per index entry and a ``label``
            column of numeric values.
        label: name of the column to read.
        pinverse: 2-D tensor with one column per row of ``df``; presumably
            the pseudo-inverse of the design matrix -- confirm with caller.
            NOTE(review): it is multiplied with a bfloat16 vector, so it is
            assumed to be bfloat16 on ``device`` as well.
        return_dim: width of the returned row vector. When it is larger than
            the solved direction, the result is right-padded with zeros.
        device: torch device on which the labels and padding are placed.

    Returns:
        Tensor of shape ``(1, return_dim)``.

    Note:
        A ``return_dim`` smaller than the solved direction is not supported;
        ``torch.zeros`` rejects the resulting negative size.
    """
    ### get labels
    labels = [df.loc[folder][label] for folder in df.index]
    labels = torch.Tensor(labels).to(device).bfloat16()

    ### solve least squares
    direction = (pinverse @ labels).unsqueeze(0)

    # Use the actual solved width rather than assuming exactly 1000
    # components, so any pinverse height is handled correctly.
    base_dim = direction.shape[1]
    if return_dim == base_dim:
        return direction

    # NOTE(review): the padding uses torch's default dtype (as before), so the
    # concatenated result may be promoted away from bfloat16 -- confirm
    # whether callers rely on this before changing it.
    padding = torch.zeros((1, return_dim - base_dim)).to(device)
    return torch.cat((direction, padding), dim=1)
   
def debias(direction, label, df, pinverse, device):
    """Remove from ``direction`` its component along another attribute.

    A second direction ``d`` is solved for column ``label`` in the same way
    as the labels-times-``pinverse`` least-squares step, zero-padded to the
    width of ``direction`` if needed, and then the projection of
    ``direction`` onto ``d`` is subtracted.

    Args:
        direction: tensor of shape ``(1, D)`` to be debiased.
        label: name of the column in ``df`` whose direction is removed.
        df: pandas DataFrame holding one row per index entry and a ``label``
            column of numeric values.
        pinverse: 2-D tensor with one column per row of ``df``; presumably
            the pseudo-inverse of the design matrix -- confirm with caller.
        device: torch device on which the labels and padding are placed.

    Returns:
        Tensor with the same shape as ``direction``.

    Note:
        If the solved ``d`` has zero norm the division below yields
        non-finite values; this is not guarded against.
    """
    ### get labels
    labels = [df.loc[folder][label] for folder in df.index]
    labels = torch.Tensor(labels).to(device).bfloat16()

    ### solve least squares
    d = (pinverse @ labels).unsqueeze(0)

    ### align dimensionalities of the two vectors
    # Pad by the real width difference instead of assuming d has exactly
    # 1000 components.
    extra_dims = direction.shape[1] - d.shape[1]
    if extra_dims != 0:
        # NOTE(review): padding uses torch's default dtype (as before), so d
        # may be promoted away from bfloat16 here -- confirm with callers.
        d = torch.cat((d, torch.zeros((1, extra_dims)).to(device)), dim=1)

    # remove this component from the direction
    projection_coeff = (direction @ d.T) / (torch.norm(d) ** 2)
    return direction - projection_coeff * d


@torch.no_grad()
def edit_inference(network, edited_weights, unet, vae, text_encoder, tokenizer, prompt, negative_prompt, guidance_scale, noise_scheduler, ddim_steps, start_noise, seed, generator, device):
    """Sample one image, switching to edited weights partway through denoising.

    Sampling starts with the weights currently stored in ``network.proj``.
    From the first timestep ``t`` with ``t <= start_noise`` onward,
    ``network.proj`` is replaced by ``edited_weights``. The original
    weights are always restored before returning, even if sampling raises.

    Args:
        network: object exposing a ``proj`` parameter, a ``reset()`` method
            and context-manager support; it is entered around every UNet
            call.
        edited_weights: tensor installed as ``network.proj`` for the edited
            part of the trajectory.
        unet: denoising model, called on the scaled latents, timestep and
            text embeddings; its output's ``.sample`` is the noise estimate.
        vae: decoder used to turn the final latents into an image.
        text_encoder: maps token ids to embeddings (first output is used).
        tokenizer: tokenizer for ``prompt`` and ``negative_prompt``.
        prompt: text prompt for the conditional branch.
        negative_prompt: text prompt for the unconditional branch.
        guidance_scale: weight of the classifier-free guidance term.
        noise_scheduler: scheduler providing ``set_timesteps``,
            ``timesteps``, ``init_noise_sigma``, ``scale_model_input`` and
            ``step``.
        ddim_steps: number of sampling steps passed to the scheduler.
        start_noise: timestep threshold at or below which the edited
            weights are used.
        seed: seed applied to ``generator`` before drawing initial latents.
        generator: ``torch.Generator`` used for the initial noise.
        device: device for the initial latents and the token ids.

    Returns:
        Decoded image tensor rescaled from [-1, 1] to [0, 1] and clamped.
    """
    original_weights = network.proj.clone()

    try:
        generator = generator.manual_seed(seed)
        # Initial latents for a 512x512 image at 1/8 spatial resolution.
        latents = torch.randn(
            (1, unet.in_channels, 512 // 8, 512 // 8),
            generator=generator,
            device=device,
        ).bfloat16()

        text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

        max_length = text_input.input_ids.shape[-1]
        # Truncate the negative prompt as well, so that an over-long one
        # cannot produce a sequence length that differs from the prompt's.
        uncond_input = tokenizer(
            [negative_prompt], padding="max_length", max_length=max_length, truncation=True, return_tensors="pt"
        )
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]

        # Unconditional embeddings first, matching the chunk order below.
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        noise_scheduler.set_timesteps(ddim_steps)
        latents = latents * noise_scheduler.init_noise_sigma

        for t in tqdm.tqdm(noise_scheduler.timesteps):
            # Duplicate the latents so both guidance branches run in one pass.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)

            # Switch to the edited weights once the threshold is reached.
            if t <= start_noise:
                network.proj = torch.nn.Parameter(edited_weights)
                network.reset()

            with network:
                noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample

            # guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

        # NOTE(review): 0.18215 is presumably the VAE latent scaling factor
        # -- confirm against the VAE's configuration.
        latents = 1 / 0.18215 * latents
        image = vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
    finally:
        # reset weights back to original, even when sampling failed midway
        network.proj = torch.nn.Parameter(original_weights)
        network.reset()

    return image