Maitreyapatel committed on
Commit 964ba1e
1 Parent(s): 594c562
Files changed (2)
  1. app.py +87 -23
  2. test.py +0 -220
app.py CHANGED
@@ -18,7 +18,7 @@ import cv2
 import numpy as np
 import torch
 
-is_gpu_busy = False
+
 PRETRAINED_MODEL_NAME_OR_PATH = "./checkpoints/"
 
 
@@ -37,8 +37,13 @@ def get_image_grid(images: List[Image.Image]) -> Image:
 
 class AttributionModel:
     def __init__(self):
+        is_cuda = False
+        if torch.cuda.is_available():
+            is_cuda = True
+
         self.pipe = StableDiffusionPipeline.from_pretrained('stabilityai/stable-diffusion-2')#, safety_checker=None, torch_dtype=torch.float16)
-        self.pipe = self.pipe.to("cuda")
+        if is_cuda:
+            self.pipe = self.pipe.to("cuda")
         self.resize_transform = transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR)
         self.vae = AutoencoderKL.from_pretrained(
             'stabilityai/stable-diffusion-2', subfolder="vae"
@@ -55,9 +60,10 @@ class AttributionModel:
         self.mapping_network.load_state_dict(torch.load(os.path.join(PRETRAINED_MODEL_NAME_OR_PATH, 'mapping_network.pth')))
         self.decoding_network.load_state_dict(torch.load(os.path.join(PRETRAINED_MODEL_NAME_OR_PATH, 'decoding_network.pth')))
 
-        self.vae = self.vae.to("cuda")
-        self.mapping_network = self.mapping_network.to("cuda")
-        self.decoding_network = self.decoding_network.to("cuda")
+        if is_cuda:
+            self.vae = self.vae.to("cuda")
+            self.mapping_network = self.mapping_network.to("cuda")
+            self.decoding_network = self.decoding_network.to("cuda")
 
         self.test_norm = transforms.Compose(
             [
@@ -65,21 +71,26 @@ class AttributionModel:
             ]
         )
 
-    def infer(self, prompt, negative, guidance_scale):
-        images = []
+    def infer(self, prompt, negative, steps, guidance_scale):
         with torch.no_grad():
-            out_latents = self.pipe([prompt], output_type="latent", num_inference_steps=10, guidance_scale=guidance_scale).images
-            image = self.inference_with_attribution(out_latents)
-            print(image[0])
-            # image = self.pipe.numpy_to_pil(image)
-            # image[0].save("im1.jpg")
-            return [image[0]]*3 #, "caption") #get_image_grid(images)
+            out_latents = self.pipe([prompt], negative_prompt=[negative], output_type="latent", num_inference_steps=steps, guidance_scale=guidance_scale).images
+            image_attr = self.inference_with_attribution(out_latents)
+            image_attr_pil = self.pipe.numpy_to_pil(image_attr[0])
+
+            image_org = self.inference_without_attribution(out_latents)
+            image_org_pil = self.pipe.numpy_to_pil(image_org[0])
+
+            image_diff_pil = self.pipe.numpy_to_pil(image_attr[0] - image_org[0])
+
+            return image_org_pil[0], image_attr_pil[0], image_diff_pil[0]
 
     def inference_without_attribution(self, latents):
         latents = 1 / 0.18215 * latents
         with torch.no_grad():
             image = self.pipe.vae.decode(latents).sample
         image = image.clamp(-1,1)
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
     def get_phis(self, phi_dimension, batch_size ,eps = 1e-8):
@@ -111,8 +122,29 @@ class AttributionModel:
 
 
 attribution_model = AttributionModel()
+def get_images(prompt, negative, steps, guidence_scale):
+    x1, x2, x3 = attribution_model.infer(prompt, negative, steps, guidence_scale)
+    return [x1, x2, x3]
+
+
+image_examples = [
+    ["A pikachu fine dining with a view to the Eiffel Tower", "low quality", 50, 10],
+    ["A mecha robot in a favela in expressionist style", "low quality, 3d, photorealistic", 50, 10]
+]
 
 with gr.Blocks() as demo:
+    gr.Markdown(
+        """<h1 style="text-align: center;"><b>WOUAF:
+        Weight Modulation for User Attribution and Fingerprinting in Text-to-Image Diffusion Models</b> <br> <a href="https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion
+        -Models/">Project Page</a></h1>""")
+
+    gr.Markdown(
+        """<h3>Demo: Text-to-Image (Stable diffusion 2) with random user attribution</h3>
+        WOUAF can be applied to other applications such as In-painting, Image-editing, Image Super-Resolution etc.
+        <br>More details at: <a href="https://arxiv.org/abs/2306.04744">Paper</a>
+        """
+    )
+
     with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
         with gr.Column():
             text = gr.Textbox(
@@ -137,20 +169,52 @@ with gr.Blocks() as demo:
                 rounded=(True, False, False, True),
                 container=False,
             )
-    btn = gr.Button("Generate image").style(full_width=False)
-
-    with gr.Row():
-        img_output_simple = gr.Image()
-        img_output_attribute = gr.Image()
-        img_output_diff = gr.Image()
-
 
     with gr.Row():
+        steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
         guidance_scale = gr.Slider(
            label="Guidance Scale", minimum=0, maximum=10, value=9, step=0.1
         )
-    btn.click(attribution_model.infer, inputs=[text, negative, guidance_scale], outputs=[img_output_simple, img_output_attribute, img_output_diff], postprocess=False)
 
+    with gr.Row():
+        btn = gr.Button(value="Generate Image", full_width=False)
+
+    with gr.Row():
+        im_2 = gr.Image(type="pil", label="without attribution")
+        im_3 = gr.Image(type="pil", label="**with** attribution")
+        im_4 = gr.Image(type="pil", label="pixel-wise difference")
+
+
+    btn.click(get_images, inputs=[text, negative, steps, guidance_scale], outputs=[im_2, im_3, im_4])
+
+    gr.Examples(
+        examples=image_examples,
+        inputs=[text, negative, steps, guidance_scale],
+        outputs=[im_2, im_3, im_4],
+        fn=get_images,
+        cache_examples=True,
+    )
+
+    gr.HTML(
+        """
+        <div class="footer">
+            <p>Pre-trained model by <a href="https://huggingface.co/stabilityai" style="text-decoration: underline;" target="_blank">StabilityAI</a>
+            </p>
+            <p>
+            Fine-tuned by authors for research purpose.
+            </p>
+        </div>
+        """
+    )
+    with gr.Accordion(label="Ethics & Privacy", open=False):
+        gr.HTML(
+            """<div class="acknowledgments">
+            <p><h4>Privacy</h4>
+            We do not collect any images or key data. This demo is designed with sole purpose of fun and reducing misuse of AI.
+            <p><h4>Biases and content acknowledgment</h4>
+            This model will have the same biases as Stable Diffusion V2.1 </div>
+            """
+        )
 
-if __name__=="__main__":
-    demo.queue(concurrency_count=1, max_size=20).launch(max_threads=50)
+if __name__ == "__main__":
+    demo.launch()
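Note on the attribution path wired up above: the demo samples a random 32-bit user key with get_phis, maps it to modulation weights with mapping_network, and decodes the same latents once through the stock VAE and once through the WOUAF-modulated decoder. A minimal usage sketch of that flow, reusing the AttributionModel defined in app.py (the prompt and sampler settings are illustrative, and it assumes a CUDA device since inference_with_attribution moves the key to the GPU):

import torch

model = AttributionModel()              # as defined in app.py above
key = model.get_phis(32, 1)             # one random 32-bit Bernoulli fingerprint

with torch.no_grad():
    # latent output of the Stable Diffusion pipeline, exactly as in infer()
    latents = model.pipe(["a watercolor fox"], output_type="latent",
                         num_inference_steps=20, guidance_scale=9.0).images
    plain  = model.inference_without_attribution(latents)    # stock VAE decode
    marked = model.inference_with_attribution(latents, key)  # WOUAF-modulated decode

# both are (1, H, W, 3) arrays in [0, 1]; their difference carries the fingerprint
print(abs(marked - plain).mean())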
test.py DELETED
@@ -1,220 +0,0 @@
-import gradio as gr
-from PIL import Image
-
-import torch
-import re
-import os
-import requests
-
-from customization import customize_vae_decoder
-from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel, DDIMScheduler, EulerDiscreteScheduler
-from torchvision import transforms
-from attribution import MappingNetwork
-
-import math
-from typing import List
-from PIL import Image
-import cv2
-import numpy as np
-import torch
-
-
-PRETRAINED_MODEL_NAME_OR_PATH = "./checkpoints/"
-
-
-def get_image_grid(images: List[Image.Image]) -> Image:
-    num_images = len(images)
-    cols = 3#int(math.ceil(math.sqrt(num_images)))
-    rows = 1#int(math.ceil(num_images / cols))
-    width, height = images[0].size
-    grid_image = Image.new('RGB', (cols * width, rows * height))
-    for i, img in enumerate(images):
-        x = i % cols
-        y = i // cols
-        grid_image.paste(img, (x * width, y * height))
-    return grid_image
-
-
-class AttributionModel:
-    def __init__(self):
-        is_cuda = False
-        if torch.cuda.is_available():
-            is_cuda = True
-
-        self.pipe = StableDiffusionPipeline.from_pretrained('stabilityai/stable-diffusion-2')#, safety_checker=None, torch_dtype=torch.float16)
-        if is_cuda:
-            self.pipe = self.pipe.to("cuda")
-        self.resize_transform = transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR)
-        self.vae = AutoencoderKL.from_pretrained(
-            'stabilityai/stable-diffusion-2', subfolder="vae"
-        )
-        self.vae = customize_vae_decoder(self.vae, 128, "qkv", "all", False, 1.0)
-
-        self.mapping_network = MappingNetwork(32, 0, 128, None, num_layers=2, w_avg_beta=None, normalization = False).to("cuda")
-
-        from torchvision.models import resnet50, ResNet50_Weights
-        self.decoding_network = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
-        self.decoding_network.fc = torch.nn.Linear(2048,32)
-
-        self.vae.decoder.load_state_dict(torch.load(os.path.join(PRETRAINED_MODEL_NAME_OR_PATH, 'vae_decoder.pth')))
-        self.mapping_network.load_state_dict(torch.load(os.path.join(PRETRAINED_MODEL_NAME_OR_PATH, 'mapping_network.pth')))
-        self.decoding_network.load_state_dict(torch.load(os.path.join(PRETRAINED_MODEL_NAME_OR_PATH, 'decoding_network.pth')))
-
-        if is_cuda:
-            self.vae = self.vae.to("cuda")
-            self.mapping_network = self.mapping_network.to("cuda")
-            self.decoding_network = self.decoding_network.to("cuda")
-
-        self.test_norm = transforms.Compose(
-            [
-                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
-            ]
-        )
-
-    def infer(self, prompt, negative, steps, guidance_scale):
-        with torch.no_grad():
-            out_latents = self.pipe([prompt], negative_prompt=[negative], output_type="latent", num_inference_steps=steps, guidance_scale=guidance_scale).images
-            image_attr = self.inference_with_attribution(out_latents)
-            image_attr_pil = self.pipe.numpy_to_pil(image_attr[0])
-
-            image_org = self.inference_without_attribution(out_latents)
-            image_org_pil = self.pipe.numpy_to_pil(image_org[0])
-
-            image_diff_pil = self.pipe.numpy_to_pil(image_attr[0] - image_org[0])
-
-            return image_org_pil[0], image_attr_pil[0], image_diff_pil[0]
-
-    def inference_without_attribution(self, latents):
-        latents = 1 / 0.18215 * latents
-        with torch.no_grad():
-            image = self.pipe.vae.decode(latents).sample
-        image = image.clamp(-1,1)
-        image = (image / 2 + 0.5).clamp(0, 1)
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-        return image
-
-    def get_phis(self, phi_dimension, batch_size ,eps = 1e-8):
-        phi_length = phi_dimension
-        b = batch_size
-        phi = torch.empty(b,phi_length).uniform_(0,1)
-        return torch.bernoulli(phi) + eps
-
-
-    def inference_with_attribution(self, latents, key=None):
-        if key==None:
-            key = self.get_phis(32, 1)
-
-        latents = 1 / 0.18215 * latents
-        with torch.no_grad():
-            image = self.vae.decode(latents, self.mapping_network(key.cuda())).sample
-        image = image.clamp(-1,1)
-        image = (image / 2 + 0.5).clamp(0, 1)
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-        return image
-
-    def postprocess(self, image):
-        image = self.resize_transform(image)
-        return image
-
-    def detect_key(self, image):
-        reconstructed_keys = self.decoding_network(self.test_norm((image / 2 + 0.5).clamp(0, 1)))
-        return reconstructed_keys
-
-
-attribution_model = AttributionModel()
-def get_images(prompt, negative, steps, guidence_scale):
-    x1, x2, x3 = attribution_model.infer(prompt, negative, steps, guidence_scale)
-    return [x1, x2, x3]
-
-
-image_examples = [
-    ["A pikachu fine dining with a view to the Eiffel Tower", "low quality", 50, 10],
-    ["A mecha robot in a favela in expressionist style", "low quality, 3d, photorealistic", 50, 10]
-]
-
-with gr.Blocks() as demo:
-    gr.Markdown(
-        """<h1 style="text-align: center;"><b>WOUAF:
-        Weight Modulation for User Attribution and Fingerprinting in Text-to-Image Diffusion Models</b> <br> <a href="https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion
-        -Models/">Project Page</a></h1>""")
-
-    gr.Markdown(
-        """<h3>Demo: Text-to-Image (Stable diffusion 2) with random user attribution</h3>
-        WOUAF can be applied to other applications such as In-painting, Image-editing, Image Super-Resolution etc.
-        <br>More details at: <a href="https://arxiv.org/abs/2306.04744">Paper</a>
-        """
-    )
-
-    with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
-        with gr.Column():
-            text = gr.Textbox(
-                label="Enter your prompt",
-                show_label=False,
-                max_lines=1,
-                placeholder="Enter your prompt",
-                elem_id="prompt-text-input",
-            ).style(
-                border=(True, False, True, True),
-                rounded=(True, False, False, True),
-                container=False,
-            )
-            negative = gr.Textbox(
-                label="Enter your negative prompt",
-                show_label=False,
-                max_lines=1,
-                placeholder="Enter a negative prompt",
-                elem_id="negative-prompt-text-input",
-            ).style(
-                border=(True, False, True, True),
-                rounded=(True, False, False, True),
-                container=False,
-            )
-
-    with gr.Row():
-        steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
-        guidance_scale = gr.Slider(
-            label="Guidance Scale", minimum=0, maximum=10, value=9, step=0.1
-        )
-
-    with gr.Row():
-        btn = gr.Button(value="Generate Image", full_width=False)
-
-    with gr.Row():
-        im_2 = gr.Image(type="pil", label="without attribution")
-        im_3 = gr.Image(type="pil", label="**with** attribution")
-        im_4 = gr.Image(type="pil", label="pixel-wise difference")
-
-
-    btn.click(get_images, inputs=[text, negative, steps, guidance_scale], outputs=[im_2, im_3, im_4])
-
-    gr.Examples(
-        examples=image_examples,
-        inputs=[text, negative, steps, guidance_scale],
-        outputs=[im_2, im_3, im_4],
-        fn=get_images,
-        cache_examples=True,
-    )
-
-    gr.HTML(
-        """
-        <div class="footer">
-            <p>Pre-trained model by <a href="https://huggingface.co/stabilityai" style="text-decoration: underline;" target="_blank">StabilityAI</a>
-            </p>
-            <p>
-            Fine-tuned by authors for research purpose.
-            </p>
-        </div>
-        """
-    )
-    with gr.Accordion(label="Ethics & Privacy", open=False):
-        gr.HTML(
-            """<div class="acknowledgments">
-            <p><h4>Privacy</h4>
-            We do not collect any images or key data. This demo is designed with sole purpose of fun and reducing misuse of AI.
-            <p><h4>Biases and content acknowledgment</h4>
-            This model will have the same biases as Stable Diffusion V2.1 </div>
-            """
-        )
-
-if __name__ == "__main__":
-    demo.launch()
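The decoding side of the scheme appears in the deleted test.py above: detect_key runs a fingerprinted image through the ResNet-50 decoding network to recover the 32 key bits. A hedged sketch of that round trip, again assuming a CUDA device; thresholding the decoder output at zero is an assumption about how the bits are read out, not something the commit itself does:

import torch

model = AttributionModel()
key = model.get_phis(32, 1)                                   # ground-truth fingerprint

with torch.no_grad():
    latents = model.pipe(["a watercolor fox"], output_type="latent",
                         num_inference_steps=20, guidance_scale=9.0).images
    marked = model.inference_with_attribution(latents, key)  # numpy, (1, H, W, 3), in [0, 1]

    # detect_key rescales with (image / 2 + 0.5), so it expects a tensor in [-1, 1]
    img = torch.from_numpy(marked).permute(0, 3, 1, 2).cuda() * 2 - 1
    logits = model.detect_key(img)

recovered = (logits > 0).float()                              # assumed bit read-out
print((recovered == key.round().cuda()).float().mean())      # fraction of bits recovered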