import torch
import torchvision
import os
import gc
import tqdm
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from lora_w2w import LoRAw2w
from transformers import AutoTokenizer, PretrainedConfig
from PIL import Image
import warnings
warnings.filterwarnings("ignore")
######## Editing Utilities | |
def get_direction(df, label, pinverse, return_dim, device):
    ### collect the per-identity labels for this attribute
    labels = []
    for folder in list(df.index):
        labels.append(df.loc[folder][label])
    labels = torch.Tensor(labels).to(device).bfloat16()
    ### solve least squares via the precomputed pseudo-inverse
    direction = (pinverse @ labels).unsqueeze(0)
    if return_dim == 1000:
        return direction
    ### pad with zeros (kept in bfloat16 to match) up to the full weight dimensionality
    direction = torch.cat((direction, torch.zeros((1, return_dim - 1000)).to(device).bfloat16()), dim=1)
    return direction
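# Usage sketch for get_direction (everything below is illustrative, not defined in
# this file): `pinverse @ labels` is the least-squares solution of
# proj @ direction ≈ labels, i.e. the direction in w2w space whose projection onto
# each identity best predicts that identity's attribute label.
#
#   import pandas as pd
#   proj = torch.load("identity_coeffs.pt").to(device)     # assumed (N, 1000), one row per identity
#   pinverse = torch.linalg.pinv(proj.float()).bfloat16()  # (1000, N) pseudo-inverse
#   df = pd.read_csv("attributes.csv", index_col=0)        # assumed labels, rows aligned with proj
#   smile = get_direction(df, "Smiling", pinverse, return_dim=1000, device=device)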
def debias(direction, label, df, pinverse, device):
    ### collect the per-identity labels for this attribute
    labels = []
    for folder in list(df.index):
        labels.append(df.loc[folder][label])
    labels = torch.Tensor(labels).to(device).bfloat16()
    ### solve least squares for this attribute's direction d
    d = (pinverse @ labels).unsqueeze(0)
    ### align dimensionalities of the two vectors by zero-padding d
    if direction.shape[1] != 1000:
        d = torch.cat((d, torch.zeros((1, direction.shape[1] - 1000)).to(device).bfloat16()), dim=1)
    ### remove this component from the direction (vector rejection)
    direction = direction - ((direction @ d.T) / (torch.norm(d) ** 2)) * d
    return direction
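# debias removes the component of `direction` along another attribute's direction d
# (vector rejection, direction <- direction - (<direction, d> / ||d||^2) * d), so the
# edit no longer drifts along that correlated attribute; calls can be chained.
# Sketch, reusing the assumed setup above (column names are illustrative):
#
#   age = get_direction(df, "Young", pinverse, return_dim=1000, device=device)
#   age = debias(age, "Male", df, pinverse, device)        # keep the age edit gender-neutral
#   age = debias(age, "Eyeglasses", df, pinverse, device)  # ...and glasses-neutral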
def edit_inference(network, edited_weights, unet, vae, text_encoder, tokenizer, prompt, negative_prompt, guidance_scale, noise_scheduler, ddim_steps, start_noise, seed, generator, device):
    ### stash the original identity weights so they can be restored afterwards
    original_weights = network.proj.clone()
    generator = generator.manual_seed(seed)
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    ### encode the prompt and the negative prompt for classifier-free guidance
    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    for i, t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)

        ### swap in the edited weights once the timestep drops to start_noise;
        ### earlier (noisier) steps keep the original weights to preserve layout
        if t <= start_noise:
            network.proj = torch.nn.Parameter(edited_weights)
            network.reset()

        with network:
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample

        ### classifier-free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    ### decode latents to an image in [0, 1] (0.18215 is the SD VAE scaling factor)
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)

    ### reset weights back to original
    network.proj = torch.nn.Parameter(original_weights)
    network.reset()
    return image
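# End-to-end sketch (hypothetical: the diffusers models, scheduler, and the LoRAw2w
# `network` are assumed to be loaded elsewhere in this Space, and the edit strength,
# prompts, and start_noise value below are arbitrary illustrations):
#
#   age = get_direction(df, "Young", pinverse, return_dim=network.proj.shape[1], device=device)
#   age = debias(age, "Male", df, pinverse, device)
#   edited = network.proj + 0.8 * age                      # move the identity along the edit
#   image = edit_inference(network, edited, unet, vae, text_encoder, tokenizer,
#                          prompt="a photo of a person", negative_prompt="low quality, blurry",
#                          guidance_scale=3.0, noise_scheduler=noise_scheduler,
#                          ddim_steps=50, start_noise=800, seed=0,
#                          generator=torch.Generator(device=device), device=device)
#   transforms.ToPILImage()(image[0].float().cpu()).save("edit.png")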