Georgiy Grigorev commited on
Commit
054082d
·
1 Parent(s): c85ee6c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -0
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from torch.optim import AdamW
4
+ from diffusers import StableDiffusionPipeline
5
+ from torch import autocast, inference_mode
6
+ import torch
7
+ import numpy as np
8
+
9
+ from scheduling_ddim import DDIMScheduler
10
+
11
+
12
+ device = 'cuda'
13
+ # don't forget to add your token or comment if already logged in
14
+ pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
15
+ scheduler=DDIMScheduler(beta_end=0.012,
16
+ beta_schedule="scaled_linear",
17
+ beta_start=0.00085),
18
+ use_auth_token="").to(device)
19
+ _ = pipe.vae.requires_grad_(False)
20
+ _ = pipe.text_encoder.requires_grad_(False)
21
+ _ = pipe.unet.requires_grad_(False)
22
+
23
+ def preprocess(image):
24
+ w, h = image.size
25
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
26
+ image = image.resize((w, h), resample=Image.LANCZOS)
27
+ image = np.array(image).astype(np.float32) / 255.0
28
+ image = image[None].transpose(0, 3, 1, 2)
29
+ image = torch.from_numpy(image)
30
+ return 2.0 * image - 1.0
31
+
32
+ def im2latent(pipe, im, generator):
33
+ init_image = preprocess(im).to(pipe.device)
34
+ init_latent_dist = pipe.vae.encode(init_image).latent_dist
35
+ init_latents = init_latent_dist.sample(generator=generator)
36
+
37
+ return init_latents * 0.18215
38
+
39
+
40
+ def image_mod(init_image, source_prompt, prompt, scale, steps, seed):
41
+ # fix seed
42
+ g = torch.Generator(device=pipe.device).manual_seed(84)
43
+
44
+ image_latents = im2latent(pipe, init_image, g)
45
+ pipe.scheduler.set_timesteps(steps)
46
+ # use text describing an image
47
+ # source_prompt = "a photo of a woman"
48
+ context = pipe._encode_prompt(source_prompt, pipe.device, 1, False, "")
49
+
50
+ decoded_latents = image_latents.clone()
51
+ with autocast(device), inference_mode():
52
+ # we are pivoting timesteps as we are moving in opposite direction
53
+ timesteps = pipe.scheduler.timesteps.flip(0)
54
+ # this would be our targets for pivoting
55
+ init_trajectory = torch.empty(len(timesteps), *decoded_latents.size()[1:], device=decoded_latents.device, dtype=decoded_latents.dtype)
56
+ for i, t in enumerate(tqdm(timesteps)):
57
+ init_trajectory[i:i+1] = decoded_latents
58
+ noise_pred = pipe.unet(decoded_latents, t, encoder_hidden_states=context).sample
59
+ decoded_latents = pipe.scheduler.reverse_step(noise_pred, t, decoded_latents).next_sample
60
+
61
+ # we would need to flip trajectory values for pivoting in right direction
62
+ init_trajectory = init_trajectory.cpu().flip(0)
63
+
64
+ latents = decoded_latents.clone()
65
+ context_uncond = pipe._encode_prompt("", pipe.device, 1, False, "")
66
+ # we will be optimizing uncond text embedding
67
+ context_uncond.requires_grad_(True)
68
+
69
+ # use same text
70
+ # prompt = "a photo of a woman"
71
+ context_cond = pipe._encode_prompt(prompt, pipe.device, 1, False, "")
72
+
73
+ # default lr works
74
+ opt = AdamW([context_uncond])
75
+
76
+ # concat latents for classifier-free guidance
77
+ latents = torch.cat([latents, latents])
78
+ latents.requires_grad_(True)
79
+ context = torch.cat((context_uncond, context_cond))
80
+
81
+ with autocast(device):
82
+ for i, t in enumerate(tqdm(pipe.scheduler.timesteps)):
83
+ latents = pipe.scheduler.scale_model_input(latents, t)
84
+ uncond, cond = pipe.unet(latents, t, encoder_hidden_states=context).sample.chunk(2)
85
+ with torch.enable_grad():
86
+ latents = pipe.scheduler.step(uncond + scale * (cond - uncond), t, latents, generator=g).prev_sample
87
+
88
+ opt.zero_grad()
89
+ # optimize uncond text emb
90
+ pivot_value = init_trajectory[[i]].to(pipe.device)
91
+ (latents - pivot_value).mean().backward()
92
+ opt.step()
93
+ latents = latents.detach()
94
+
95
+ images = pipe.decode_latents(latents)
96
+ im = pipe.numpy_to_pil(images)[0]
97
+ return im
98
+
99
+
100
+ demo = gr.Interface(
101
+ image_mod,
102
+ inputs=[gr.Image(type="pil"), gr.Textbox("a photo of a person"), gr.Textbox("a photo of a person"), gr.Slider(0, 10, 0.5, 0.1), gr.Slider(0, 100, 51, 1), gr.Number(42)],
103
+ outputs="image",
104
+ flagging_options=["blurry", "incorrect", "other"], examples=[
105
+ os.path.join(os.path.dirname(__file__), "images/00001.jpg"),
106
+ ])
107
+
108
+ if __name__ == "__main__":
109
+ demo.launch()