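"""Gradio demo for weights2weights (w2w).

Samples a new identity's personalized diffusion-model weights from the learned
w2w weight manifold and generates images of that identity with Stable Diffusion.
"""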
import gradio as gr
import sys
import os
import tqdm

# Make the parent directory importable so `utils` and `sampling` resolve.
sys.path.append(os.path.abspath(os.path.join("", "..")))
import torch
import gc
import warnings
warnings.filterwarnings("ignore")
from PIL import Image
from utils import load_models, save_model_w2w, save_model_for_diffusers
from sampling import sample_weights
from huggingface_hub import snapshot_download

device = "cuda:0"  # the demo assumes a CUDA-capable GPU
generator = torch.Generator(device=device)

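# Fetch the pretrained w2w assets from the Hugging Face Hub.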
models_path = snapshot_download(repo_id="Snapchat/w2w")

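# Statistics of the w2w weight space: the mean and standard deviation of the
# flattened model weights, the PCA basis V, and a projection onto the first
# 1000 principal components.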
mean = torch.load(f"{models_path}/mean.pt").bfloat16().to(device)
std = torch.load(f"{models_path}/std.pt").bfloat16().to(device)
v = torch.load(f"{models_path}/V.pt").bfloat16().to(device)
proj = torch.load(f"{models_path}/proj_1000pc.pt").bfloat16().to(device)
df = torch.load(f"{models_path}/identity_df.pt")
weight_dimensions = torch.load(f"{models_path}/weight_dimensions.pt")

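# Load the base Stable Diffusion components.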
unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)

network = None  # set by sample_model(); wraps the sampled identity weights

def sample_model():
    """Reload a fresh base UNet and sample a new identity from the w2w space."""
    global unet
    global network
    # Drop the previous UNet and reclaim GPU memory before loading a new one.
    del unet
    gc.collect()
    torch.cuda.empty_cache()
    unet, _, _, _, _ = load_models(device)
    network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor=1.00)

def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
    global generator  # only `generator` is rebound; the models are read-only here
    generator = generator.manual_seed(seed)

    # Start from pure noise in the 64x64 latent space of a 512x512 image.
    latents = torch.randn(
        (1, unet.config.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    # Encode the prompt and the negative prompt; the two batches are concatenated
    # so classifier-free guidance needs only one UNet forward pass per step.
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    for t in tqdm.tqdm(noise_scheduler.timesteps):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
        # The sampled w2w identity weights are applied to the UNet while
        # inside the `network` context manager.
        with network:
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample
        # Classifier-free guidance.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
    
    # Undo the SD v1 VAE scaling factor (0.18215), decode, and map to [0, 1].
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))

    # gr.Gallery expects a list of images.
    return [image]

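# Sample an initial identity so "Submit" works before the user has clicked
# "Sample New Model" (otherwise `network` would still be None).
sample_model()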
with gr.Blocks() as demo:
    gr.Markdown("# <em>weights2weights</em> Demo")
    with gr.Row():
        with gr.Column():
            files = gr.Files(
                        label="Upload a photo of your face to invert, or sample a new model",
                        file_types=["image"]
                    )
            uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=125)

            sample = gr.Button("Sample New Model")

            with gr.Column(visible=False) as clear_button:
                remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")
            prompt = gr.Textbox(label="Prompt",
                       info="Make sure to include 'sks person'" ,
                       placeholder="sks person", 
                       value="sks person")
            negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon")
            seed = gr.Number(value=5, precision=0, label="Seed", interactive=True)
            cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
            steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=1, maximum=100, interactive=True)


            submit = gr.Button("Submit")

        with gr.Column():
            gallery = gr.Gallery(label="Generated Images")

        sample.click(fn=sample_model)
        
        submit.click(fn=inference,
                    inputs=[prompt, negative_prompt, cfg, steps, seed],
                    outputs=gallery)
            
demo.launch(share=True)