import os
# os.system("pip uninstall -y gradio")
# os.system('pip install gradio==3.43.1')
import sys
import gc
import warnings

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import gradio as gr
import tqdm
from PIL import Image
from huggingface_hub import snapshot_download
from gradio_imageslider import ImageSlider

sys.path.append(os.path.abspath(os.path.join("", "..")))
warnings.filterwarnings("ignore")

from utils import load_models, save_model_w2w, save_model_for_diffusers
from editing import get_direction, debias
from sampling import sample_weights
from lora_w2w import LoRAw2w

global device
global generator
global unet
global vae
global text_encoder
global tokenizer
global noise_scheduler
global network

device = "cuda:0"
generator = torch.Generator(device=device)

# download the released weights2weights assets (PCA basis, stats, identity labels)
models_path = snapshot_download(repo_id="Snapchat/w2w")

mean = torch.load(f"{models_path}/files/mean.pt").bfloat16().to(device)
std = torch.load(f"{models_path}/files/std.pt").bfloat16().to(device)
v = torch.load(f"{models_path}/files/V.pt").bfloat16().to(device)
proj = torch.load(f"{models_path}/files/proj_1000pc.pt").bfloat16().to(device)
df = torch.load(f"{models_path}/files/identity_df.pt")
weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt").bfloat16().to(device)

unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)


def sample_model():
    """Replace the global UNet with a fresh copy and sample a new identity network."""
    global unet
    del unet
    global network
    unet, _, _, _, _ = load_models(device)
    network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor=1.00)


@torch.no_grad()
def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
    global device
    global generator
    global unet
    global vae
    global text_encoder
    global tokenizer
    global noise_scheduler

    generator = generator.manual_seed(seed)
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    text_input = tokenizer(prompt, padding="max_length",
                           max_length=tokenizer.model_max_length,
                           truncation=True, return_tensors="pt")
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer([negative_prompt], padding="max_length",
                             max_length=max_length, return_tensors="pt")
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    for i, t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
        with network:
            noise_pred = unet(latent_model_input, t,
                              encoder_hidden_states=text_embeddings,
                              timestep_cond=None).sample

        # classifier-free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    # decode latents back to pixel space (0.18215 is the SD v1.x VAE scaling factor)
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))
    return image
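# Illustrative usage of the two functions above outside the Gradio UI
# (a sketch, not part of the original demo wiring; the file name is assumed):
#
#     sample_model()                          # draw a fresh identity in w2w space
#     img = inference("sks person", "low quality, blurry",
#                     guidance_scale=3.0, ddim_steps=50, seed=5)
#     img.save("sampled_identity.png")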
@torch.no_grad()
def edit_inference(input_image, prompt, negative_prompt, guidance_scale, ddim_steps, seed,
                   start_noise, a1, a2, a3, a4):
    global device
    global generator
    global unet
    global vae
    global text_encoder
    global tokenizer
    global noise_scheduler
    global young
    global pointy
    global wavy
    global large

    original_weights = network.proj.clone()

    # pad the edit directions to the same number of PCs as the sampled weights
    pcs_original = original_weights.shape[1]
    pcs_edits = young.shape[1]
    padding = torch.zeros((1, pcs_original - pcs_edits)).to(device)
    young_pad = torch.cat((young, padding), 1)
    pointy_pad = torch.cat((pointy, padding), 1)
    wavy_pad = torch.cat((wavy, padding), 1)
    large_pad = torch.cat((large, padding), 1)

    edited_weights = (original_weights
                      + a1 * 1e6 * young_pad
                      + a2 * 1e6 * pointy_pad
                      + a3 * 1e6 * wavy_pad
                      + a4 * 8e5 * large_pad)

    generator = generator.manual_seed(seed)
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    text_input = tokenizer(prompt, padding="max_length",
                           max_length=tokenizer.model_max_length,
                           truncation=True, return_tensors="pt")
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer([negative_prompt], padding="max_length",
                             max_length=max_length, return_tensors="pt")
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    for i, t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)

        # inject the edited weights only once the timestep drops below start_noise
        if t <= start_noise:
            network.proj = torch.nn.Parameter(edited_weights)
            network.reset()

        with network:
            noise_pred = unet(latent_model_input, t,
                              encoder_hidden_states=text_embeddings,
                              timestep_cond=None).sample

        # classifier-free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))

    # reset weights back to the original, unedited identity
    network.proj = torch.nn.Parameter(original_weights)
    network.reset()

    return (image, input_image["background"])


def sample_then_run():
    sample_model()
    prompt = "sks person"
    negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
    seed = 5
    cfg = 3.0
    steps = 50
    image = inference(prompt, negative_prompt, cfg, steps, seed)
    torch.save(network.proj, "model.pt")
    return image, "model.pt"


global young
global pointy
global wavy
global large
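# The four UI sliders map onto linear directions in the PCA weight space.
# Each direction below is derived from CelebA-style attribute labels (e.g.
# "Young", "Pointy_Nose") and then repeatedly debiased, i.e. projected away
# from correlated attributes, so that moving one slider does not also drift
# gender, hairstyle, and so on. The 1e6 / 8e5 factors in edit_inference()
# rescale the [-1, 1] slider range into meaningful magnitudes in weight space.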
young = get_direction(df, "Young", pinverse, 1000, device)
young = debias(young, "Male", df, pinverse, device)
young = debias(young, "Pointy_Nose", df, pinverse, device)
young = debias(young, "Wavy_Hair", df, pinverse, device)
young = debias(young, "Chubby", df, pinverse, device)

pointy = get_direction(df, "Pointy_Nose", pinverse, 1000, device)
pointy = debias(pointy, "Young", df, pinverse, device)
pointy = debias(pointy, "Male", df, pinverse, device)
pointy = debias(pointy, "Wavy_Hair", df, pinverse, device)
pointy = debias(pointy, "Chubby", df, pinverse, device)
pointy = debias(pointy, "Heavy_Makeup", df, pinverse, device)

wavy = get_direction(df, "Wavy_Hair", pinverse, 1000, device)
wavy = debias(wavy, "Young", df, pinverse, device)
wavy = debias(wavy, "Male", df, pinverse, device)
wavy = debias(wavy, "Pointy_Nose", df, pinverse, device)
wavy = debias(wavy, "Chubby", df, pinverse, device)
wavy = debias(wavy, "Heavy_Makeup", df, pinverse, device)

large = get_direction(df, "Bushy_Eyebrows", pinverse, 1000, device)
large = debias(large, "Male", df, pinverse, device)
large = debias(large, "Young", df, pinverse, device)
large = debias(large, "Pointy_Nose", df, pinverse, device)
large = debias(large, "Wavy_Hair", df, pinverse, device)
large = debias(large, "Mustache", df, pinverse, device)
large = debias(large, "No_Beard", df, pinverse, device)
large = debias(large, "Sideburns", df, pinverse, device)
large = debias(large, "Big_Nose", df, pinverse, device)
large = debias(large, "Big_Lips", df, pinverse, device)
large = debias(large, "Black_Hair", df, pinverse, device)
large = debias(large, "Brown_Hair", df, pinverse, device)
large = debias(large, "Pale_Skin", df, pinverse, device)
large = debias(large, "Heavy_Makeup", df, pinverse, device)


class CustomImageDataset(Dataset):
    def __init__(self, images, transform=None):
        self.images = images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform:
            image = self.transform(image)
        return image


def invert(image, mask, pcs=10000, epochs=400, weight_decay=1e-10, lr=1e-1):
    global unet
    del unet
    global network
    unet, _, _, _, _ = load_models(device)

    proj = torch.zeros(1, pcs).bfloat16().to(device)
    network = LoRAw2w(
        proj, mean, std, v[:, :pcs],
        unet,
        rank=1,
        multiplier=1.0,
        alpha=27.0,
        train_method="xattn-strict",
    ).to(device, torch.bfloat16)

    ### load mask
    mask = transforms.Resize((64, 64), interpolation=transforms.InterpolationMode.BILINEAR)(mask)
    mask = torchvision.transforms.functional.pil_to_tensor(mask)
    mask = mask.unsqueeze(0).to(device).bfloat16()[:, 0, :, :].unsqueeze(1)

    ### if no mask was actually drawn, fall back to an all-ones mask
    if torch.sum(mask) == 0:
        mask = torch.ones((1, 1, 64, 64)).to(device).bfloat16()

    ### single-image dataset
    image_transforms = transforms.Compose([
        transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.RandomCrop(512),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    train_dataset = CustomImageDataset(image, transform=image_transforms)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)

    ### optimizer
    optim = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay)

    ### training loop
    unet.train()
    for epoch in tqdm.tqdm(range(epochs)):
        for batch in train_dataloader:
            ### prepare inputs
            batch = batch.to(device).bfloat16()
            latents = vae.encode(batch).latent_dist.sample()
            latents = latents * 0.18215
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]

            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                                      (bsz,), device=latents.device)
            timesteps = timesteps.long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            text_input = tokenizer("sks person", padding="max_length",
                                   max_length=tokenizer.model_max_length,
                                   truncation=True, return_tensors="pt")
            text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

            ### loss + sgd step
            with network:
                model_pred = unet(noisy_latents, timesteps, text_embeddings).sample
            loss = torch.nn.functional.mse_loss(mask * model_pred.float(),
                                                mask * noise.float(),
                                                reduction="mean")
            optim.zero_grad()
            loss.backward()
            optim.step()

    ### return optimized network
    return network
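# Note on invert(): this is a standard denoising-diffusion training loop on a
# single image, except that only the low-dimensional projection coefficients
# in network.proj are optimized (the UNet weights themselves are not in the
# optimizer), and the noise-prediction MSE is weighted by the user-drawn face
# mask so that background pixels do not influence the recovered identity.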
def run_inversion(input_image, pcs, epochs, weight_decay, lr):
    global network
    print(len(input_image["layers"]))
    init_image = input_image["background"].convert("RGB").resize((512, 512))
    mask = input_image["layers"][0].convert("RGB").resize((512, 512))
    network = invert([init_image], mask, pcs, epochs, weight_decay, lr)

    # sample an image
    prompt = "sks person"
    negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
    seed = 5
    cfg = 3.0
    steps = 50
    image = inference(prompt, negative_prompt, cfg, steps, seed)
    torch.save(network.proj, "model.pt")
    return (image, init_image), "model.pt"


def file_upload(file):
    global unet
    del unet
    global network
    global device

    proj = torch.load(file.name).to(device)

    # pad to 10000 principal components to keep everything consistent
    pcs = proj.shape[1]
    padding = torch.zeros((1, 10000 - pcs)).to(device)
    proj = torch.cat((proj, padding), 1)

    unet, _, _, _, _ = load_models(device)
    network = LoRAw2w(
        proj, mean, std, v[:, :10000],
        unet,
        rank=1,
        multiplier=1.0,
        alpha=27.0,
        train_method="xattn-strict",
    ).to(device, torch.bfloat16)

    prompt = "sks person"
    negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
    seed = 5
    cfg = 3.0
    steps = 50
    image = inference(prompt, negative_prompt, cfg, steps, seed)
    return image
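# A "model.pt" written by sample_then_run() or run_inversion() contains only
# the (1, n_pcs) projection tensor, not full UNet weights; file_upload() pads
# it to 10000 principal components and rebuilds a LoRAw2w network around it.
# Illustrative manual reload (a sketch; assumes the globals above exist):
#
#     proj = torch.load("model.pt").to(device)    # shape (1, n_pcs)
#     # then pad and wrap in LoRAw2w exactly as file_upload() does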
intro = """
project page | paper
"""
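# --- Gradio UI -------------------------------------------------------------
# Left column: image editor for uploading a photo and drawing the face mask,
# plus the "Sample New Model" / "Invert" buttons. Right column: before/after
# ImageSlider and prompt box, followed by the four attribute-edit sliders.
# The "Advanced Options" accordion exposes inversion and sampling settings.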
""" with gr.Blocks(css="style.css") as demo: gr.HTML(intro) gr.Markdown(""" Click sample (to sample an identity) *or* upload an image & click `invert` to get started ✨ > 💡 When inverting, draw a mask over the face for improved results. > To use a model previously downloaded from this demo see `Uplaoding a model` in the `Advanced options` """) with gr.Column(): with gr.Row(): with gr.Column(): # input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload image and draw to define mask", # height=512, width=512, brush_color='#00FFFF', mask_opacity=0.6) input_image = gr.ImageEditor(elem_id="image_upload", type='pil', label="Upload image and draw to define mask", height=512, width=512, brush=gr.Brush(), layers=False) with gr.Row(): sample = gr.Button("Sample New Model") invert_button = gr.Button("Invert") with gr.Column(): image_slider = ImageSlider(position=1., type="pil", height=512, width=512) # gallery1 = gr.Image(label="Identity from Original Model",height=512, width=512, interactive=False) prompt1 = gr.Textbox(label="Prompt", info="Make sure to include 'sks person'" , placeholder="sks person", value="sks person") # Editing with gr.Column(): #gallery2 = gr.Image(label="Identity from Edited Model", interactive=False, visible=False ) with gr.Row(): a1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True) a2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True) with gr.Row(): a3 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True) a4 = gr.Slider(label="- Thick Eyebrows +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True) # prompt2 = gr.Textbox(label="Prompt", # info="Make sure to include 'sks person'" , # placeholder="sks person", # value="sks person", visible=False) # seed2 = gr.Number(value=5, label="Seed", precision=0, interactive=True, visible=False) # submit2 = gr.Button("Generate", visible=False) with gr.Accordion("Advanced Options", open=False): with gr.Tab("Inversion"): with gr.Row(): lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True) pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True) with gr.Row(): epochs = gr.Slider(label="Epochs", value=400, step=1, minimum=1, maximum=2000, interactive=True) weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True) with gr.Tab("Sampling"): with gr.Row(): cfg1= gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True) steps1 = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True) seed1 = gr.Number(value=5, label="Seed", precision=0, interactive=True) with gr.Row(): negative_prompt1 = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon") injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True) # with gr.Tab("Editing"): # with gr.Column(): # cfg2 = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True) # steps2 = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True) # injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True) # negative_prompt2 = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, 
nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon") with gr.Tab("Uploading a model"): gr.Markdown("""