Migrate to Diffusers

#1 by radames - opened
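This PR replaces the cldm / DDIMSampler code path in app.py with the diffusers ControlNet pipeline, adds a webcam "live conditioning" input, bundles cached examples, and pulls in diffusers and accelerate through requirements.txt. For orientation, a minimal sketch of the new generation flow — the repo ids, dtype, and step/guidance settings mirror the diff below, while the conditioning-image path, prompts, and output filename are illustrative placeholders rather than app code:

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

# fp16 ControlNet weights trained on MediaPipe face annotations
controlnet = ControlNetModel.from_pretrained(
    "CrucibleAI/ControlNetMediaPipeFace", torch_dtype=torch.float16, variant="fp16")
# SD 2.1 base as the backbone, with the ControlNet plugged in
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# Conditioning image: a 512x512 MediaPipe face annotation, as produced by
# generate_annotation() in this repo; "annotation.png" is a placeholder path.
conditioning = load_image("annotation.png")
images = pipe(
    prompt="Highly detailed photograph of two clowns, best quality, extremely detailed",
    negative_prompt="lowres, bad anatomy, worst quality, low quality",
    image=conditioning,
    num_inference_steps=20,
    guidance_scale=9.0,
).images
images[0].save("output.png")

In the Space itself the conditioning image is produced on the fly, either from an uploaded photo via generate_annotation or as a base64 frame from the webcam canvas component.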
.gitignore CHANGED
@@ -1 +1,4 @@
 .idea
+__pycache__/
+venv/
+gradio_cached_examples/
app.py CHANGED
@@ -1,118 +1,200 @@
-import os
 import random
-from typing import Mapping
 
 import gradio as gr
-import numpy
 import torch
-from huggingface_hub import hf_hub_download
+from diffusers.utils import load_image
 from PIL import Image
+import numpy as np
+import base64
+from io import BytesIO
 
-from cldm.model import create_model, load_state_dict
-from cldm.ddim_hacked import DDIMSampler
 from mediapipe_face_common import generate_annotation
 
+from diffusers import (
+    ControlNetModel,
+    StableDiffusionControlNetPipeline,
+)
+
+
 # Download the SD 1.5 model from HF
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model_path = hf_hub_download(repo_id="CrucibleAI/ControlNetMediaPipeFace", filename="models/controlnet_sd21_laion_face_v2_full.ckpt", repo_type="model", revision="568dc2c9980572262d48cff1ef2a7e4a03fadeb6")
-config_path = hf_hub_download(repo_id="CrucibleAI/ControlNetMediaPipeFace", filename="models/cldm_v21.yaml", repo_type="model", revision="568dc2c9980572262d48cff1ef2a7e4a03fadeb6")
-model = create_model(config_path).cpu()
-model.load_state_dict(load_state_dict(model_path, location=device))
+controlnet = ControlNetModel.from_pretrained(
+    "CrucibleAI/ControlNetMediaPipeFace", torch_dtype=torch.float16, variant="fp16")
+model = StableDiffusionControlNetPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
+)
 model = model.to(device)
-ddim_sampler = DDIMSampler(model) # ControlNet _only_ works with DDIM.
-
-
-def process(input_image: Image.Image, prompt, a_prompt, n_prompt, max_faces: int, min_confidence: float, num_samples, ddim_steps, guess_mode, strength, scale, seed: int, eta):
-    with torch.no_grad():
-        # Scale to 512x512.
-        img_size = input_image.size
-        scale_factor = 512/min(img_size)
-        input_image = input_image.resize((1+int(img_size[0]*scale_factor), 1+int(img_size[1]*scale_factor)))
-        img_size = input_image.size
-        left_padding = (img_size[0] - 512)//2
-        top_padding = (img_size[1] - 512)//2
-        input_image = input_image.crop((left_padding, top_padding, left_padding+512, top_padding+512))
-
-        # Generate annotation
-        input_image = numpy.asarray(input_image)
-        empty = generate_annotation(input_image, max_faces, min_confidence)
-        visualization = Image.fromarray(empty) # Save to help debug.
-
-        # Prep for network:
-        empty = numpy.moveaxis(empty, 2, 0) # h, w, c -> c, h, w
-        control = torch.from_numpy(empty.copy()).float().to(device) / 255.0
-        control = torch.stack([control for _ in range(num_samples)], dim=0)
-        # control = einops.rearrange(control, 'b h w c -> b c h w').clone()
-
-        # Sanity check the dimensions.
-        B, C, H, W = control.shape
-        assert C == 3
-        assert B == num_samples
-
-        if seed != -1:
-            random.seed(seed)
-            os.environ['PYTHONHASHSEED'] = str(seed)
-            numpy.random.seed(seed)
-            torch.manual_seed(seed)
-            torch.cuda.manual_seed(seed)
-            torch.backends.cudnn.deterministic = True
-
-        # model.low_vram_shift(is_diffusing=False)
-
-        cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
-        un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
-        shape = (4, H // 8, W // 8)
-
-        # model.low_vram_shift(is_diffusing=True)
-
-        model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
-        samples, intermediates = ddim_sampler.sample(
-            ddim_steps,
-            num_samples,
-            shape,
-            cond,
-            verbose=False,
-            eta=eta,
-            unconditional_guidance_scale=scale,
-            unconditional_conditioning=un_cond
-        )
-
-        # model.low_vram_shift(is_diffusing=False)
-
-        x_samples = model.decode_first_stage(samples)
-        # x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(numpy.uint8)
-        x_samples = numpy.moveaxis((x_samples * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(numpy.uint8), 1, -1) # b, c, h, w -> b, h, w, c
-        results = [visualization] + [x_samples[i] for i in range(num_samples)]
-
-        return results
+model.enable_model_cpu_offload()
+
+
+canvas_html = "<face-canvas id='canvas-root' data-mode='crucibleAI' style='display:flex;max-width: 500px;margin: 0 auto;'></face-canvas>"
+load_js = """
+async () => {
+    const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/face-canvas.js"
+    fetch(url)
+        .then(res => res.text())
+        .then(text => {
+            const script = document.createElement('script');
+            script.type = "module"
+            script.src = URL.createObjectURL(new Blob([text], { type: 'application/javascript' }));
+            document.head.appendChild(script);
+        });
+}
+"""
+get_js_image = """
+async (input_image, prompt, a_prompt, n_prompt, max_faces, min_confidence, num_samples, ddim_steps, guess_mode, strength, scale, seed, eta, image_file_live_opt, live_conditioning) => {
+    const canvasEl = document.getElementById("canvas-root");
+    const imageData = canvasEl ? canvasEl._data : null;
+    return [input_image, prompt, a_prompt, n_prompt, max_faces, min_confidence, num_samples, ddim_steps, guess_mode, strength, scale, seed, eta, image_file_live_opt, imageData];
+}
+"""
+
+
+def pad_image(input_image):
+    pad_w, pad_h = np.max(((2, 2), np.ceil(
+        np.array(input_image.size) / 64).astype(int)), axis=0) * 64 - input_image.size
+    im_padded = Image.fromarray(
+        np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
+    w, h = im_padded.size
+    if w == h:
+        return im_padded
+    elif w > h:
+        new_image = Image.new(im_padded.mode, (w, w), (0, 0, 0))
+        new_image.paste(im_padded, (0, (w - h) // 2))
+        return new_image
+    else:
+        new_image = Image.new(im_padded.mode, (h, h), (0, 0, 0))
+        new_image.paste(im_padded, ((h - w) // 2, 0))
+        return new_image
+
+
+def process(input_image: Image.Image, prompt, a_prompt, n_prompt, max_faces: int, min_confidence: float, num_samples, ddim_steps, guess_mode, strength, scale, seed: int, eta, image_file_live_opt="file", live_conditioning=None):
+    if input_image is None and 'image' not in live_conditioning:
+        raise gr.Error("Please provide an image")
+    try:
+        if image_file_live_opt == 'file':
+            input_image = input_image.convert('RGB')
+            empty = generate_annotation(
+                np.array(input_image), max_faces, min_confidence)
+            visualization = Image.fromarray(empty) # Save to help debug.
+            visualization = pad_image(visualization).resize((512, 512))
+        elif image_file_live_opt == 'webcam':
+            base64_img = live_conditioning['image']
+            image_data = base64.b64decode(base64_img.split(',')[1])
+            visualization = Image.open(BytesIO(image_data)).convert(
+                'RGB').resize((512, 512))
+        if seed == -1:
+            seed = random.randint(0, 2147483647)
+        generator = torch.Generator(device).manual_seed(seed)
+
+        output = model(prompt=prompt + ' ' + a_prompt,
+                       negative_prompt=n_prompt,
+                       image=visualization,
+                       generator=generator,
+                       num_images_per_prompt=num_samples,
+                       num_inference_steps=ddim_steps,
+                       controlnet_conditioning_scale=strength,
+                       guidance_scale=scale,
+                       eta=eta,
+                       )
+        results = [visualization] + output.images
+
+        return results
+    except Exception as e:
+        raise gr.Error(str(e))
+
+
+# switch between file upload and webcam
+
+
+def toggle(choice):
+    if choice == "file":
+        return gr.update(visible=True, value=None), gr.update(visible=False, value=None)
+    elif choice == "webcam":
+        return gr.update(visible=False, value=None), gr.update(visible=True, value=canvas_html)
 
 
 block = gr.Blocks().queue()
 with block:
+    # hidden JSON component to store live conditioning
+    live_conditioning = gr.JSON(value={}, visible=False)
     with gr.Row():
         gr.Markdown("## Control Stable Diffusion with a Facial Pose")
     with gr.Row():
         with gr.Column():
-            input_image = gr.Image(source='upload', type="pil")
+            image_file_live_opt = gr.Radio(["file", "webcam"], value="file",
+                                           label="How would you like to upload your image?")
+            input_image = gr.Image(source="upload", visible=True, type="pil")
+            canvas = gr.HTML(None, elem_id="canvas_html", visible=False)
+
+            image_file_live_opt.change(fn=toggle,
+                                       inputs=[image_file_live_opt],
+                                       outputs=[input_image, canvas],
+                                       queue=False)
+
             prompt = gr.Textbox(label="Prompt")
             run_button = gr.Button(label="Run")
             with gr.Accordion("Advanced options", open=False):
-                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
-                max_faces = gr.Slider(label="Max Faces", minimum=1, maximum=10, value=5, step=1)
-                min_confidence = gr.Slider(label="Min Confidence", minimum=0.01, maximum=1.0, value=0.5, step=0.01)
-                strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+                num_samples = gr.Slider(
+                    label="Images", minimum=1, maximum=4, value=1, step=1)
+                max_faces = gr.Slider(
+                    label="Max Faces", minimum=1, maximum=10, value=5, step=1)
+                min_confidence = gr.Slider(
+                    label="Min Confidence", minimum=0.01, maximum=1.0, value=0.5, step=0.01)
+                strength = gr.Slider(
+                    label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                 guess_mode = gr.Checkbox(label='Guess Mode', value=False)
-                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
-                scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
-                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+                ddim_steps = gr.Slider(
+                    label="Steps", minimum=1, maximum=100, value=20, step=1)
+                scale = gr.Slider(label="Guidance Scale",
+                                  minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+                seed = gr.Slider(label="Seed", minimum=-1,
+                                 maximum=2147483647, step=1, randomize=True)
                 eta = gr.Number(label="eta (DDIM)", value=0.0)
-                a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+                a_prompt = gr.Textbox(
+                    label="Added Prompt", value='best quality, extremely detailed')
                 n_prompt = gr.Textbox(label="Negative Prompt",
                                       value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
         with gr.Column():
-            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
-    ips = [input_image, prompt, a_prompt, n_prompt, max_faces, min_confidence, num_samples, ddim_steps, guess_mode, strength, scale, seed, eta]
-    run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
-
+            result_gallery = gr.Gallery(
+                label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+    ips = [input_image, prompt, a_prompt, n_prompt, max_faces, min_confidence,
+           num_samples, ddim_steps, guess_mode, strength, scale, seed, eta]
+    run_button.click(fn=process, inputs=ips + [image_file_live_opt, live_conditioning],
+                     outputs=[result_gallery],
+                     _js=get_js_image)
+
+    # load js for live conditioning
+    block.load(None, None, None, _js=load_js)
+    gr.Examples(fn=process,
+                examples=[
+                    ["./examples/two2.jpeg",
+                     "Highly detailed photograph of two clowns",
+                     "best quality, extremely detailed",
+                     "cartoon, disfigured, bad art, deformed, poorly drawn, extra limbs, weird colors, blurry, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
+                     10, 0.4, 3, 20, False, 1.0, 9.0, -1, 0.0],
+                    ["./examples/two.jpeg",
+                     "a photo of two silly men",
+                     "best quality, extremely detailed",
+                     "cartoon, disfigured, bad art, deformed, poorly drawn, extra limbs, weird colors, blurry, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
+                     10, 0.4, 3, 20, False, 1.0, 9.0, -1, 0.0],
+                    ["./examples/pedro-512.jpg",
+                     "Highly detailed photograph of young woman smiling, with palm trees in the background",
+                     "best quality, extremely detailed",
+                     "cartoon, disfigured, bad art, deformed, poorly drawn, extra limbs, weird colors, blurry, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
+                     10, 0.4, 3, 20, False, 1.0, 9.0, -1, 0.0],
+                    ["./examples/image1.jpg",
+                     "Highly detailed photograph of a scary clown",
+                     "best quality, extremely detailed",
+                     "cartoon, disfigured, bad art, deformed, poorly drawn, extra limbs, weird colors, blurry, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
+                     10, 0.4, 3, 20, False, 1.0, 9.0, -1, 0.0],
+                    ["./examples/image0.jpg",
+                     "Highly detailed photograph of Madonna",
+                     "best quality, extremely detailed",
+                     "cartoon, disfigured, bad art, deformed, poorly drawn, extra limbs, weird colors, blurry, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
+                     10, 0.4, 3, 20, False, 1.0, 9.0, -1, 0.0],
+                ],
+                inputs=ips,
+                outputs=[result_gallery],
+                cache_examples=True)
 
 block.launch(server_name='0.0.0.0')
examples/.gitattributes ADDED
@@ -0,0 +1,2 @@
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
examples/image0.jpg ADDED
examples/image1.jpg ADDED
examples/pedro-512.jpg ADDED
examples/two.jpeg ADDED
examples/two2.jpeg ADDED
requirements.txt CHANGED
@@ -11,4 +11,6 @@ timm
 transformers==4.26.1
 torch==1.13.1
 torchvision==0.14.1
-tqdm==4.64.1
+tqdm==4.64.1
+accelerate
+diffusers