weights2weights / editing.py
multimodalart's picture
Upload 200 files
8483373 verified
raw
history blame
3.64 kB
import torch
import torchvision
import os
import gc
import tqdm
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from lora_w2w import LoRAw2w
from transformers import AutoTokenizer, PretrainedConfig
from PIL import Image
import warnings
warnings.filterwarnings("ignore")
######## Editing Utilities
def get_direction(df, label, pinverse, return_dim, device):
### get labels
labels = []
for folder in list(df.index):
labels.append(df.loc[folder][label])
labels = torch.Tensor(labels).to(device).bfloat16()
### solve least squares
direction = (pinverse@labels).unsqueeze(0)
if return_dim == 1000:
return direction
else:
direction = torch.cat((direction, torch.zeros((1, return_dim-1000)).to(device)), dim=1)
return direction
def debias(direction, label, df, pinverse, device):
### get labels
labels = []
for folder in list(df.index):
labels.append(df.loc[folder][label])
labels = torch.Tensor(labels).to(device).bfloat16()
### solve least squares
d = (pinverse@labels).unsqueeze(0)
###align dimensionalities of the two vectors
if direction.shape[1] == 1000:
pass
else:
d = torch.cat((d, torch.zeros((1, direction.shape[1]-1000)).to(device)), dim=1)
#remove this component from the direction
direction = direction - ((direction@d.T)/(torch.norm(d)**2))*d
return direction
@torch.no_grad
def edit_inference(network, edited_weights, unet, vae, text_encoder, tokenizer, prompt, negative_prompt, guidance_scale, noise_scheduler, ddim_steps, start_noise, seed, generator, device):
original_weights = network.proj.clone()
generator = generator.manual_seed(seed)
latents = torch.randn(
(1, unet.in_channels, 512 // 8, 512 // 8),
generator = generator,
device = device
).bfloat16()
text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer(
[negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
)
uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
noise_scheduler.set_timesteps(ddim_steps)
latents = latents * noise_scheduler.init_noise_sigma
for i,t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
latent_model_input = torch.cat([latents] * 2)
latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
if t>start_noise:
pass
elif t<=start_noise:
network.proj = torch.nn.Parameter(edited_weights)
network.reset()
with network:
noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
#guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
latents = 1 / 0.18215 * latents
image = vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
#reset weights back to original
network.proj = torch.nn.Parameter(original_weights)
network.reset()
return image