nowsyn committed
Commit 107040a
1 Parent(s): 5eee39d

upload code
.gitignore ADDED
@@ -0,0 +1,3 @@
+__pycache__
+*/__pycache__
+**/__pycache__
LICENSE ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 OpenMMLab
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
annotator/hed/__init__.py ADDED
@@ -0,0 +1,114 @@
+# This is an improved version and model of HED edge detection with Apache License, Version 2.0.
+# Please use this implementation in your products.
+# This implementation may produce slightly different results from Saining Xie's official implementations,
+# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
+# Different from official models and other implementations, this is an RGB-input model (rather than BGR),
+# and in this way it works better with gradio's RGB protocol.
+
+import os
+import cv2
+import torch
+import numpy as np
+
+from einops import rearrange
+from annotator.util import annotator_ckpts_path, safe_step
+
+
+class DoubleConvBlock(torch.nn.Module):
+    def __init__(self, input_channel, output_channel, layer_number):
+        super().__init__()
+        self.convs = torch.nn.Sequential()
+        self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
+        for i in range(1, layer_number):
+            self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
+        self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)
+
+    def __call__(self, x, down_sampling=False):
+        h = x
+        if down_sampling:
+            h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
+        for conv in self.convs:
+            h = conv(h)
+            h = torch.nn.functional.relu(h)
+        return h, self.projection(h)
+
+
+class ControlNetHED_Apache2(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
+        self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
+        self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
+        self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
+        self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
+        self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)
+
+    def __call__(self, x):
+        h = x - self.norm
+        h, projection1 = self.block1(h)
+        h, projection2 = self.block2(h, down_sampling=True)
+        h, projection3 = self.block3(h, down_sampling=True)
+        h, projection4 = self.block4(h, down_sampling=True)
+        h, projection5 = self.block5(h, down_sampling=True)
+        return projection1, projection2, projection3, projection4, projection5
+
+
+class HEDdetector:
+    def __init__(self):
+        remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetHED.pth"
+        modelpath = os.path.join(annotator_ckpts_path, "ControlNetHED.pth")
+        if not os.path.exists(modelpath):
+            from basicsr.utils.download_util import load_file_from_url
+            load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
+        self.netNetwork = ControlNetHED_Apache2().float().cuda().eval()
+        self.netNetwork.load_state_dict(torch.load(modelpath))
+
+    def __call__(self, input_image, safe=False):
+        assert input_image.ndim == 3
+        H, W, C = input_image.shape
+        with torch.no_grad():
+            image_hed = torch.from_numpy(input_image.copy()).float().cuda()
+            image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
+            edges = self.netNetwork(image_hed)
+            edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
+            edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
+            edges = np.stack(edges, axis=2)
+            edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
+            if safe:
+                edge = safe_step(edge)
+            edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
+            return edge
+
+
+class SOFT_HEDdetector:
+    def __init__(self):
+        remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetHED.pth"
+        modelpath = os.path.join(annotator_ckpts_path, "ControlNetHED.pth")
+        if not os.path.exists(modelpath):
+            from basicsr.utils.download_util import load_file_from_url
+            load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
+        self.netNetwork = ControlNetHED_Apache2().float().cuda().eval()
+        self.netNetwork.load_state_dict(torch.load(modelpath))
+
+    def __call__(self, input_image, safe=False, threshold=200):
+        assert input_image.ndim == 3
+        H, W, C = input_image.shape
+        with torch.no_grad():
+            image_hed = torch.from_numpy(input_image.copy()).float().cuda()
+            image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
+            edges = self.netNetwork(image_hed)
+            edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
+            edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
+            edges = np.stack(edges, axis=2)
+            edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
+            if safe:
+                edge = safe_step(edge)
+            edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
+
+            content_image = edge
+            content_image[content_image > threshold] = 255
+            content_image[content_image < 255] = 0
+            kernel = np.ones((3, 3), np.uint8)
+
+            content_image = cv2.dilate(content_image, kernel, iterations=1)
+            return content_image
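For reference, here is a minimal usage sketch of the two detectors above, run outside the Gradio app. It assumes a CUDA device and that `ControlNetHED.pth` can be downloaded on first use; the image path is hypothetical, and `HWC3`/`resize_image` come from `annotator/util.py` below.

```python
# Hedged usage sketch for HEDdetector / SOFT_HEDdetector ("example.png" is hypothetical).
import cv2
from annotator.util import HWC3, resize_image
from annotator.hed import HEDdetector, SOFT_HEDdetector

# the model is RGB-input, so convert from OpenCV's BGR
img = HWC3(cv2.cvtColor(cv2.imread("example.png"), cv2.COLOR_BGR2RGB))
img = resize_image(img, 512)

hed = HEDdetector()
soft_hed = SOFT_HEDdetector()

edge = hed(img, safe=True)               # uint8 HxW soft edge map
contour = soft_hed(img, threshold=200)   # thresholded + dilated contour map

cv2.imwrite("edge.png", edge)
cv2.imwrite("contour.png", contour)
```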
annotator/lineart/LICENSE ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Caroline Chan
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
annotator/lineart/__init__.py ADDED
@@ -0,0 +1,138 @@
+# From https://github.com/carolineec/informative-drawings
+# MIT License
+
+import os
+import cv2
+import torch
+import numpy as np
+
+import torch.nn as nn
+from einops import rearrange
+from annotator.util import annotator_ckpts_path
+
+
+norm_layer = nn.InstanceNorm2d
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, in_features):
+        super(ResidualBlock, self).__init__()
+
+        conv_block = [nn.ReflectionPad2d(1),
+                      nn.Conv2d(in_features, in_features, 3),
+                      norm_layer(in_features),
+                      nn.ReLU(inplace=True),
+                      nn.ReflectionPad2d(1),
+                      nn.Conv2d(in_features, in_features, 3),
+                      norm_layer(in_features)]
+
+        self.conv_block = nn.Sequential(*conv_block)
+
+    def forward(self, x):
+        return x + self.conv_block(x)
+
+
+class Generator(nn.Module):
+    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
+        super(Generator, self).__init__()
+
+        # Initial convolution block
+        model0 = [nn.ReflectionPad2d(3),
+                  nn.Conv2d(input_nc, 64, 7),
+                  norm_layer(64),
+                  nn.ReLU(inplace=True)]
+        self.model0 = nn.Sequential(*model0)
+
+        # Downsampling
+        model1 = []
+        in_features = 64
+        out_features = in_features * 2
+        for _ in range(2):
+            model1 += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
+                       norm_layer(out_features),
+                       nn.ReLU(inplace=True)]
+            in_features = out_features
+            out_features = in_features * 2
+        self.model1 = nn.Sequential(*model1)
+
+        model2 = []
+        # Residual blocks
+        for _ in range(n_residual_blocks):
+            model2 += [ResidualBlock(in_features)]
+        self.model2 = nn.Sequential(*model2)
+
+        # Upsampling
+        model3 = []
+        out_features = in_features // 2
+        for _ in range(2):
+            model3 += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
+                       norm_layer(out_features),
+                       nn.ReLU(inplace=True)]
+            in_features = out_features
+            out_features = in_features // 2
+        self.model3 = nn.Sequential(*model3)
+
+        # Output layer
+        model4 = [nn.ReflectionPad2d(3),
+                  nn.Conv2d(64, output_nc, 7)]
+        if sigmoid:
+            model4 += [nn.Sigmoid()]
+
+        self.model4 = nn.Sequential(*model4)
+
+    def forward(self, x, cond=None):
+        out = self.model0(x)
+        out = self.model1(out)
+        out = self.model2(out)
+        out = self.model3(out)
+        out = self.model4(out)
+
+        return out
+
+
+class LineartDetector:
+    def __init__(self):
+        self.model = self.load_model('sk_model.pth')
+        self.model_coarse = self.load_model('sk_model2.pth')
+
+    def load_model(self, name):
+        remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/" + name
+        modelpath = os.path.join(annotator_ckpts_path, name)
+        if not os.path.exists(modelpath):
+            from basicsr.utils.download_util import load_file_from_url
+            load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
+        model = Generator(3, 1, 3)
+        model.load_state_dict(torch.load(modelpath, map_location=torch.device('cpu')))
+        model.eval()
+        model = model.cuda()
+        return model
+
+    def __call__(self, input_image, coarse=False):
+        model = self.model_coarse if coarse else self.model
+        assert input_image.ndim == 3
+        image = input_image
+        # images = input_images
+        # results = []
+        with torch.no_grad():
+            image = torch.from_numpy(image).float().cuda()
+            # batch_imgs = torch.stack([torch.from_numpy(image).float().cuda() / 255.0 for image in images], dim=0)
+            image = image / 255.0
+            image = rearrange(image, 'h w c -> 1 c h w')
+            line = model(image)[0][0]
+
+            line = line.cpu().numpy()
+            line = (line * 255.0).clip(0, 255).astype(np.uint8)
+
+        # with torch.no_grad():
+        #     # pass the batch of images through the model
+        #     outputs = model(batch_imgs)
+
+        #     for output in outputs:
+        #         line = output[0][0].cpu().numpy()
+        #         line = (line * 255.0).clip(0, 255).astype(np.uint8)
+        #         results.append(line)
+
+        # return results
+
+        return line
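A minimal sketch of driving `LineartDetector` directly, under the same assumptions as the HED example (CUDA available, checkpoints downloadable on first use, hypothetical file name):

```python
# Hedged usage sketch for LineartDetector ("photo.png" is hypothetical).
import cv2
from annotator.util import HWC3, resize_image
from annotator.lineart import LineartDetector

detector = LineartDetector()
img = HWC3(cv2.cvtColor(cv2.imread("photo.png"), cv2.COLOR_BGR2RGB))
img = resize_image(img, 512)

fine_lines = detector(img, coarse=False)    # uses sk_model.pth
coarse_lines = detector(img, coarse=True)   # uses sk_model2.pth
cv2.imwrite("lineart.png", fine_lines)
```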
annotator/util.py ADDED
@@ -0,0 +1,98 @@
+import random
+
+import numpy as np
+import cv2
+import os
+
+
+annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
+
+
+def HWC3(x):
+    assert x.dtype == np.uint8
+    if x.ndim == 2:
+        x = x[:, :, None]
+    assert x.ndim == 3
+    H, W, C = x.shape
+    assert C == 1 or C == 3 or C == 4
+    if C == 3:
+        return x
+    if C == 1:
+        return np.concatenate([x, x, x], axis=2)
+    if C == 4:
+        color = x[:, :, 0:3].astype(np.float32)
+        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+        y = color * alpha + 255.0 * (1.0 - alpha)
+        y = y.clip(0, 255).astype(np.uint8)
+        return y
+
+
+def resize_image(input_image, resolution):
+    H, W, C = input_image.shape
+    H = float(H)
+    W = float(W)
+    k = float(resolution) / min(H, W)
+    H *= k
+    W *= k
+    H = int(np.round(H / 64.0)) * 64
+    W = int(np.round(W / 64.0)) * 64
+    img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
+    return img
+
+
+def nms(x, t, s):
+    x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
+
+    f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
+    f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
+    f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
+    f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
+
+    y = np.zeros_like(x)
+
+    for f in [f1, f2, f3, f4]:
+        np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
+
+    z = np.zeros_like(y, dtype=np.uint8)
+    z[y > t] = 255
+    return z
+
+
+def make_noise_disk(H, W, C, F):
+    noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
+    noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
+    noise = noise[F: F + H, F: F + W]
+    noise -= np.min(noise)
+    noise /= np.max(noise)
+    if C == 1:
+        noise = noise[:, :, None]
+    return noise
+
+
+def min_max_norm(x):
+    x -= np.min(x)
+    x /= np.maximum(np.max(x), 1e-5)
+    return x
+
+
+def safe_step(x, step=2):
+    y = x.astype(np.float32) * float(step + 1)
+    y = y.astype(np.int32).astype(np.float32) / float(step)
+    return y
+
+
+def img2mask(img, H, W, low=10, high=90):
+    assert img.ndim == 3 or img.ndim == 2
+    assert img.dtype == np.uint8
+
+    if img.ndim == 3:
+        y = img[:, :, random.randrange(0, img.shape[2])]
+    else:
+        y = img
+
+    y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC)
+
+    if random.uniform(0, 1) < 0.5:
+        y = 255 - y
+
+    return y < np.percentile(y, random.randrange(low, high))
+
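A short worked example of `resize_image`'s rounding behavior may help: the short side is scaled toward the requested resolution and both sides are then snapped to multiples of 64, which is what the diffusion UNet expects.

```python
# Worked example of resize_image's 64-pixel snapping (values computed from the function above).
import numpy as np
from annotator.util import HWC3, resize_image

img = HWC3(np.zeros((600, 800), dtype=np.uint8))  # grayscale -> (600, 800, 3)

out = resize_image(img, 512)
# k = 512 / 600 ≈ 0.853, so 600 -> 512.0 and 800 -> 682.7,
# which round to the nearest multiples of 64: 512 and 704.
print(out.shape)  # (512, 704, 3)
```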
app.py CHANGED
@@ -1,7 +1,200 @@
+from types import MethodType
+
+import spaces
+import os
 import gradio as gr
-
-def greet(name):
-    return "Hello " + name + "!!"
-
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+import torch
+import cv2
+from annotator.util import resize_image
+from annotator.hed import SOFT_HEDdetector
+from annotator.lineart import LineartDetector
+from diffusers import UNet2DConditionModel, ControlNetModel
+from transformers import CLIPVisionModelWithProjection
+from huggingface_hub import snapshot_download
+from PIL import Image
+from ip_adapter import StyleShot, StyleContentStableDiffusionControlNetPipeline
+
+device = "cuda"
+
+contour_detector = SOFT_HEDdetector()
+lineart_detector = LineartDetector()
+
+base_model_path = "runwayml/stable-diffusion-v1-5"
+transformer_block_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
+styleshot_model_path = "Gaojunyao/StyleShot"
+styleshot_lineart_model_path = "Gaojunyao/StyleShot_lineart"
+
+if not os.path.isdir(base_model_path):
+    base_model_path = snapshot_download(base_model_path, local_dir=base_model_path)
+    print(f"Downloaded model to {base_model_path}")
+if not os.path.isdir(transformer_block_path):
+    transformer_block_path = snapshot_download(transformer_block_path, local_dir=transformer_block_path)
+    print(f"Downloaded model to {transformer_block_path}")
+if not os.path.isdir(styleshot_model_path):
+    styleshot_model_path = snapshot_download(styleshot_model_path, local_dir=styleshot_model_path)
+    print(f"Downloaded model to {styleshot_model_path}")
+if not os.path.isdir(styleshot_lineart_model_path):
+    styleshot_lineart_model_path = snapshot_download(styleshot_lineart_model_path, local_dir=styleshot_lineart_model_path)
+    print(f"Downloaded model to {styleshot_lineart_model_path}")
+
+
+# weights for ip-adapter and our content-fusion encoder
+contour_ip_ckpt = os.path.join(styleshot_model_path, "pretrained_weight/ip.bin")
+contour_style_aware_encoder_path = os.path.join(styleshot_model_path, "pretrained_weight/style_aware_encoder.bin")
+contour_transformer_block_path = transformer_block_path
+contour_unet = UNet2DConditionModel.from_pretrained(base_model_path, subfolder="unet")
+contour_content_fusion_encoder = ControlNetModel.from_unet(contour_unet)
+
+contour_pipe = StyleContentStableDiffusionControlNetPipeline.from_pretrained(base_model_path, controlnet=contour_content_fusion_encoder)
+contour_styleshot = StyleShot(device, contour_pipe, contour_ip_ckpt, contour_style_aware_encoder_path, contour_transformer_block_path)
+
+lineart_ip_ckpt = os.path.join(styleshot_lineart_model_path, "pretrained_weight/ip.bin")
+lineart_style_aware_encoder_path = os.path.join(styleshot_lineart_model_path, "pretrained_weight/style_aware_encoder.bin")
+lineart_transformer_block_path = transformer_block_path
+lineart_unet = UNet2DConditionModel.from_pretrained(base_model_path, subfolder="unet")
+lineart_content_fusion_encoder = ControlNetModel.from_unet(lineart_unet)
+
+lineart_pipe = StyleContentStableDiffusionControlNetPipeline.from_pretrained(base_model_path, controlnet=lineart_content_fusion_encoder)
+lineart_styleshot = StyleShot(device, lineart_pipe, lineart_ip_ckpt, lineart_style_aware_encoder_path, lineart_transformer_block_path)
+
+
+@spaces.GPU
+def process(style_image, content_image, prompt, num_samples, image_resolution, condition_scale, style_scale, ddim_steps, guidance_scale, seed, a_prompt, n_prompt, btn1, Contour_Threshold=200):
+    weight_dtype = torch.float32
+
+    style_shots = []
+    btns = []
+    contour_content_images = []
+    contour_results = []
+    lineart_content_images = []
+    lineart_results = []
+
+    type1 = 'Contour'
+    type2 = 'Lineart'
+
+    if btn1 == type1 or content_image is None:
+        style_shots = [contour_styleshot]
+        btns = [type1]
+    elif btn1 == type2:
+        style_shots = [lineart_styleshot]
+        btns = [type2]
+    elif btn1 == "Both":
+        style_shots = [contour_styleshot, lineart_styleshot]
+        btns = [type1, type2]
+
+    ori_style_image = style_image.copy()
+
+    if content_image is not None:
+        ori_content_image = content_image.copy()
+    else:
+        ori_content_image = None
+
+    for styleshot, btn in zip(style_shots, btns):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        prompts = [prompt + " " + a_prompt]
+
+        style_image = Image.fromarray(ori_style_image)
+
+        if ori_content_image is not None:
+            if btn == type1:
+                content_image = resize_image(ori_content_image, image_resolution)
+                content_image = contour_detector(content_image, threshold=Contour_Threshold)
+            elif btn == type2:
+                content_image = resize_image(ori_content_image, image_resolution)
+                content_image = lineart_detector(content_image, coarse=False)
+
+            content_image = Image.fromarray(content_image)
+        else:
+            content_image = cv2.resize(ori_style_image, (image_resolution, image_resolution))
+            content_image = Image.fromarray(content_image)
+            condition_scale = 0.0
+
+        g_images = styleshot.generate(style_image=style_image,
+                                      prompt=[[prompt]],
+                                      negative_prompt=n_prompt,
+                                      scale=style_scale,
+                                      num_samples=num_samples,
+                                      seed=seed,
+                                      num_inference_steps=ddim_steps,
+                                      guidance_scale=guidance_scale,
+                                      content_image=content_image,
+                                      controlnet_conditioning_scale=float(condition_scale))
+
+        if btn == type1:
+            contour_content_images = [content_image]
+            contour_results = g_images[0]
+        elif btn == type2:
+            lineart_content_images = [content_image]
+            lineart_results = g_images[0]
+    if ori_content_image is None:
+        contour_content_images = []
+        lineart_results = []
+        lineart_content_images = []
+
+    return [contour_results, contour_content_images, lineart_results, lineart_content_images]
+
+
+block = gr.Blocks().queue()
+with block:
+    with gr.Row():
+        gr.Markdown("## Styleshot Demo")
+    with gr.Row():
+        with gr.Column():
+            style_image = gr.Image(sources=['upload'], type="numpy", label='Style Image')
+        with gr.Column():
+            with gr.Box():
+                with gr.Column():
+                    content_image = gr.Image(sources=['upload'], type="numpy", label='Content Image (optional)')
+                    btn1 = gr.Radio(
+                        choices=["Contour", "Lineart", "Both"],
+                        interactive=True,
+                        label="Preprocessor",
+                        value="Both",
+                    )
+                    gr.Markdown("We recommend using 'Contour' for sparse control and 'Lineart' for detailed control. If you choose 'Both', we will provide results for two types of control. If you choose 'Contour', you can adjust the 'Contour Threshold' under the 'Advanced options' for the level of detail in control.")
+    with gr.Row():
+        prompt = gr.Textbox(label="Prompt")
+    with gr.Row():
+        run_button = gr.Button(value="Run")
+    with gr.Row():
+        with gr.Column():
+            with gr.Accordion("Advanced options", open=False):
+                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=4, step=1)
+                image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
+                condition_scale = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+
+                Contour_Threshold = gr.Slider(label="Contour Threshold", minimum=0, maximum=255, value=200, step=1)
+
+                style_scale = gr.Slider(label="Style Strength", minimum=0, maximum=2, value=1.0, step=0.01)
+
+                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
+                guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
+                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, value=42, step=1)
+
+                a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+                n_prompt = gr.Textbox(label="Negative Prompt",
+                                      value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+
+    with gr.Row():
+        with gr.Box():
+            gr.Markdown("### Results for Contour")
+            with gr.Row():
+                with gr.Column(scale=1):
+                    contour_gallery = gr.Gallery(label='Contour Output', show_label=True, elem_id="gallery", columns=[1], rows=[1], height='auto')
+                with gr.Column(scale=4):
+                    image_gallery = gr.Gallery(label='Result for Contour', show_label=True, elem_id="gallery", columns=[4], rows=[1], height='auto')
+    with gr.Row():
+        with gr.Box():
+            gr.Markdown("### Results for Lineart")
+            with gr.Row():
+                with gr.Column(scale=1):
+                    line_gallery = gr.Gallery(label='Lineart Output', show_label=True, elem_id="gallery", columns=[1], rows=[1], height='auto')
+                with gr.Column(scale=4):
+                    line_image_gallery = gr.Gallery(label='Result for Lineart', show_label=True, elem_id="gallery", columns=[4], rows=[1], height='auto')
+
+    ips = [style_image, content_image, prompt, num_samples, image_resolution, condition_scale, style_scale, ddim_steps, guidance_scale, seed, a_prompt, n_prompt, btn1, Contour_Threshold]
+    run_button.click(fn=process, inputs=ips, outputs=[image_gallery, contour_gallery, line_image_gallery, line_gallery])
+
+
+block.launch(server_name='0.0.0.0')
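For readers who want to skip the UI, a hedged sketch of the contour branch of `process()` run headlessly, reusing the globals defined above (`contour_styleshot`, `contour_detector`). The input file names are hypothetical, and the `generate()` call simply mirrors the one in `process()`; the exact structure of its return value is assumed to be a list of lists of PIL images, as implied by how `g_images[0]` is fed to a gallery.

```python
# Headless sketch of the contour path (assumptions noted in the lead-in).
import cv2
from PIL import Image
from annotator.util import resize_image

style = Image.open("style.png").convert("RGB")                      # hypothetical style reference
content = cv2.cvtColor(cv2.imread("content.png"), cv2.COLOR_BGR2RGB)  # hypothetical content image
content = contour_detector(resize_image(content, 512), threshold=200)
content = Image.fromarray(content)

images = contour_styleshot.generate(style_image=style,
                                    prompt=[["a cat"]],
                                    negative_prompt="lowres, bad anatomy",
                                    scale=1.0,
                                    num_samples=1,
                                    seed=42,
                                    num_inference_steps=50,
                                    guidance_scale=7.5,
                                    content_image=content,
                                    controlnet_conditioning_scale=1.0)
images[0][0].save("result.png")
```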
ip_adapter/__init__.py ADDED
@@ -0,0 +1,11 @@
+from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull, StyleShot, StyleContentStableDiffusionControlNetPipeline
+
+__all__ = [
+    "IPAdapter",
+    "IPAdapterPlus",
+    "IPAdapterPlusXL",
+    "IPAdapterXL",
+    "IPAdapterFull",
+    "StyleShot",
+    "StyleContentStableDiffusionControlNetPipeline",
+]
ip_adapter/attention_processor.py ADDED
@@ -0,0 +1,554 @@
+# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class AttnProcessor(nn.Module):
+    r"""
+    Default processor for performing attention-related computations.
+    """
+
+    def __init__(
+        self,
+        hidden_size=None,
+        cross_attention_dim=None,
+    ):
+        super().__init__()
+
+    def __call__(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+    ):
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class IPAttnProcessor(nn.Module):
+    r"""
+    Attention processor for IP-Adapter.
+    Args:
+        hidden_size (`int`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`):
+            The number of channels in the `encoder_hidden_states`.
+        scale (`float`, defaults to 1.0):
+            the weight scale of image prompt.
+        num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
+            The context length of the image features.
+    """
+
+    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.scale = scale
+        self.num_tokens = num_tokens
+
+        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+    def __call__(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+    ):
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            # split encoder_hidden_states into text tokens and image (ip) tokens
+            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+            encoder_hidden_states, ip_hidden_states = (
+                encoder_hidden_states[:, :end_pos, :],
+                encoder_hidden_states[:, end_pos:, :],
+            )
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # for ip-adapter
+        ip_key = self.to_k_ip(ip_hidden_states)
+        ip_value = self.to_v_ip(ip_hidden_states)
+
+        ip_key = attn.head_to_batch_dim(ip_key)
+        ip_value = attn.head_to_batch_dim(ip_value)
+
+        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
+
+        hidden_states = hidden_states + self.scale * ip_hidden_states
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class AttnProcessor2_0(torch.nn.Module):
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __init__(
+        self,
+        hidden_size=None,
+        cross_attention_dim=None,
+    ):
+        super().__init__()
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+    ):
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class IPAttnProcessor2_0(torch.nn.Module):
+    r"""
+    Attention processor for IP-Adapter with PyTorch 2.0.
+    Args:
+        hidden_size (`int`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`):
+            The number of channels in the `encoder_hidden_states`.
+        scale (`float`, defaults to 1.0):
+            the weight scale of image prompt.
+        num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
+            The context length of the image features.
+    """
+
+    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
+        super().__init__()
+
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.scale = scale
+        self.num_tokens = num_tokens
+
+        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+    def __call__(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+    ):
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            # split encoder_hidden_states into text tokens and image (ip) tokens
+            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+            encoder_hidden_states, ip_hidden_states = (
+                encoder_hidden_states[:, :end_pos, :],
+                encoder_hidden_states[:, end_pos:, :],
+            )
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # for ip-adapter
+        ip_key = self.to_k_ip(ip_hidden_states)
+        ip_value = self.to_v_ip(ip_hidden_states)
+
+        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        ip_hidden_states = F.scaled_dot_product_attention(
+            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+        )
+
+        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        ip_hidden_states = ip_hidden_states.to(query.dtype)
+
+        hidden_states = hidden_states + self.scale * ip_hidden_states
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+## for controlnet
+class CNAttnProcessor:
+    r"""
+    Default processor for performing attention-related computations.
+    """
+
+    def __init__(self, num_tokens=4):
+        self.num_tokens = num_tokens
+
+    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class CNAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __init__(self, num_tokens=4):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+        self.num_tokens = num_tokens
+
+    def __call__(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+    ):
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
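These processors only take effect once they are installed on a UNet's attention layers. The commit does not show that wiring, so the sketch below follows the common IP-Adapter recipe for mapping layer names to hidden sizes; that bookkeeping is an assumption, not something defined in this file.

```python
# Hedged sketch: installing IPAttnProcessor2_0 on a diffusers UNet.
from diffusers import UNet2DConditionModel
from ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

attn_procs = {}
for name in unet.attn_processors.keys():
    # attn1 is self-attention (no image tokens), attn2 is cross-attention
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    if cross_attention_dim is None:
        attn_procs[name] = AttnProcessor2_0()  # stock processor for self-attention
    else:
        attn_procs[name] = IPAttnProcessor2_0(hidden_size=hidden_size,
                                              cross_attention_dim=cross_attention_dim,
                                              scale=1.0, num_tokens=4)
unet.set_attn_processor(attn_procs)
```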
ip_adapter/custom_pipelines.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ from diffusers import StableDiffusionXLPipeline
5
+ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
6
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
7
+
8
+ from .utils import is_torch2_available
9
+
10
+ if is_torch2_available():
11
+ from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor
12
+ else:
13
+ from .attention_processor import IPAttnProcessor
14
+
15
+
16
+ class StableDiffusionXLCustomPipeline(StableDiffusionXLPipeline):
17
+ def set_scale(self, scale):
18
+ for attn_processor in self.unet.attn_processors.values():
19
+ if isinstance(attn_processor, IPAttnProcessor):
20
+ attn_processor.scale = scale
21
+
22
+ @torch.no_grad()
23
+ def __call__( # noqa: C901
24
+ self,
25
+ prompt: Optional[Union[str, List[str]]] = None,
26
+ prompt_2: Optional[Union[str, List[str]]] = None,
27
+ height: Optional[int] = None,
28
+ width: Optional[int] = None,
29
+ num_inference_steps: int = 50,
30
+ denoising_end: Optional[float] = None,
31
+ guidance_scale: float = 5.0,
32
+ negative_prompt: Optional[Union[str, List[str]]] = None,
33
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
34
+ num_images_per_prompt: Optional[int] = 1,
35
+ eta: float = 0.0,
36
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
37
+ latents: Optional[torch.FloatTensor] = None,
38
+ prompt_embeds: Optional[torch.FloatTensor] = None,
39
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
40
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
41
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
42
+ output_type: Optional[str] = "pil",
43
+ return_dict: bool = True,
44
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
45
+ callback_steps: int = 1,
46
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
47
+ guidance_rescale: float = 0.0,
48
+ original_size: Optional[Tuple[int, int]] = None,
49
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
50
+ target_size: Optional[Tuple[int, int]] = None,
51
+ negative_original_size: Optional[Tuple[int, int]] = None,
52
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
53
+ negative_target_size: Optional[Tuple[int, int]] = None,
54
+ control_guidance_start: float = 0.0,
55
+ control_guidance_end: float = 1.0,
56
+ ):
57
+ r"""
58
+ Function invoked when calling the pipeline for generation.
59
+
60
+ Args:
61
+ prompt (`str` or `List[str]`, *optional*):
62
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
63
+ instead.
64
+ prompt_2 (`str` or `List[str]`, *optional*):
65
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
66
+ used in both text-encoders
67
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
68
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
69
+ Anything below 512 pixels won't work well for
70
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
71
+ and checkpoints that are not specifically fine-tuned on low resolutions.
72
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
73
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
74
+ Anything below 512 pixels won't work well for
75
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
76
+ and checkpoints that are not specifically fine-tuned on low resolutions.
77
+ num_inference_steps (`int`, *optional*, defaults to 50):
78
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
79
+ expense of slower inference.
80
+ denoising_end (`float`, *optional*):
81
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
82
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
83
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
84
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
85
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
86
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
87
+ guidance_scale (`float`, *optional*, defaults to 5.0):
88
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
89
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
90
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
91
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
92
+ usually at the expense of lower image quality.
93
+ negative_prompt (`str` or `List[str]`, *optional*):
94
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
95
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
96
+ less than `1`).
97
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
98
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
99
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
100
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
101
+ The number of images to generate per prompt.
102
+ eta (`float`, *optional*, defaults to 0.0):
103
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
104
+ [`schedulers.DDIMScheduler`], will be ignored for others.
105
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
106
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
107
+ to make generation deterministic.
108
+ latents (`torch.FloatTensor`, *optional*):
109
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
110
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
111
+ tensor will ge generated by sampling using the supplied random `generator`.
112
+ prompt_embeds (`torch.FloatTensor`, *optional*):
113
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
114
+ provided, text embeddings will be generated from `prompt` input argument.
115
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
116
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
117
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
118
+ argument.
119
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
120
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
121
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
122
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
123
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
124
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
125
+ input argument.
126
+ output_type (`str`, *optional*, defaults to `"pil"`):
127
+ The output format of the generate image. Choose between
128
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
129
+ return_dict (`bool`, *optional*, defaults to `True`):
130
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
131
+ of a plain tuple.
132
+ callback (`Callable`, *optional*):
133
+ A function that will be called every `callback_steps` steps during inference. The function will be
134
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
135
+ callback_steps (`int`, *optional*, defaults to 1):
136
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
137
+ called at every step.
138
+ cross_attention_kwargs (`dict`, *optional*):
139
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
140
+ `self.processor` in
141
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
142
+ guidance_rescale (`float`, *optional*, defaults to 0.7):
143
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
144
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
145
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
146
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
147
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
148
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
149
+ `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
150
+ explained in section 2.2 of
151
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
152
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
153
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
154
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
155
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
156
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
157
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
158
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
159
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
160
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
161
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
162
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
163
+ micro-conditioning as explained in section 2.2 of
164
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
165
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
166
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
167
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
168
+ micro-conditioning as explained in section 2.2 of
169
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
170
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
171
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
172
+ To negatively condition the generation process based on a target image resolution. It should be the same
173
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
174
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
175
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
176
+ control_guidance_start (`float`, *optional*, defaults to 0.0):
177
+ The percentage of total steps at which the ControlNet starts applying.
178
+ control_guidance_end (`float`, *optional*, defaults to 1.0):
179
+ The percentage of total steps at which the ControlNet stops applying.
180
+
181
+ Examples:
182
+
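+ A minimal illustrative sketch (not part of the original file): this `__call__` is normally reached
+ through the IP-Adapter wrappers defined in `ip_adapter/ip_adapter.py`; `pipe` is assumed to be an
+ instance of this SDXL pipeline and the checkpoint paths are placeholders.
+
+ >>> ip_model = IPAdapterXL(pipe, "path/to/image_encoder", "path/to/ip-adapter_sdxl.bin", "cuda")
+ >>> images = ip_model.generate(pil_image=reference_image, prompt="best quality", num_samples=1, seed=42)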
183
+ Returns:
184
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
185
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
186
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
187
+ """
188
+ # 0. Default height and width to unet
189
+ height = height or self.default_sample_size * self.vae_scale_factor
190
+ width = width or self.default_sample_size * self.vae_scale_factor
191
+
192
+ original_size = original_size or (height, width)
193
+ target_size = target_size or (height, width)
194
+
195
+ # 1. Check inputs. Raise error if not correct
196
+ self.check_inputs(
197
+ prompt,
198
+ prompt_2,
199
+ height,
200
+ width,
201
+ callback_steps,
202
+ negative_prompt,
203
+ negative_prompt_2,
204
+ prompt_embeds,
205
+ negative_prompt_embeds,
206
+ pooled_prompt_embeds,
207
+ negative_pooled_prompt_embeds,
208
+ )
209
+
210
+ # 2. Define call parameters
211
+ if prompt is not None and isinstance(prompt, str):
212
+ batch_size = 1
213
+ elif prompt is not None and isinstance(prompt, list):
214
+ batch_size = len(prompt)
215
+ else:
216
+ batch_size = prompt_embeds.shape[0]
217
+
218
+ device = self._execution_device
219
+
220
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
221
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
222
+ # corresponds to doing no classifier free guidance.
223
+ do_classifier_free_guidance = guidance_scale > 1.0
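+ # with guidance enabled, step 8 below combines the two predictions as
+ # noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)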
224
+
225
+ # 3. Encode input prompt
226
+ text_encoder_lora_scale = (
227
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
228
+ )
229
+ (
230
+ prompt_embeds,
231
+ negative_prompt_embeds,
232
+ pooled_prompt_embeds,
233
+ negative_pooled_prompt_embeds,
234
+ ) = self.encode_prompt(
235
+ prompt=prompt,
236
+ prompt_2=prompt_2,
237
+ device=device,
238
+ num_images_per_prompt=num_images_per_prompt,
239
+ do_classifier_free_guidance=do_classifier_free_guidance,
240
+ negative_prompt=negative_prompt,
241
+ negative_prompt_2=negative_prompt_2,
242
+ prompt_embeds=prompt_embeds,
243
+ negative_prompt_embeds=negative_prompt_embeds,
244
+ pooled_prompt_embeds=pooled_prompt_embeds,
245
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
246
+ lora_scale=text_encoder_lora_scale,
247
+ )
248
+
249
+ # 4. Prepare timesteps
250
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
251
+
252
+ timesteps = self.scheduler.timesteps
253
+
254
+ # 5. Prepare latent variables
255
+ num_channels_latents = self.unet.config.in_channels
256
+ latents = self.prepare_latents(
257
+ batch_size * num_images_per_prompt,
258
+ num_channels_latents,
259
+ height,
260
+ width,
261
+ prompt_embeds.dtype,
262
+ device,
263
+ generator,
264
+ latents,
265
+ )
266
+
267
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
268
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
269
+
270
+ # 7. Prepare added time ids & embeddings
271
+ add_text_embeds = pooled_prompt_embeds
272
+ if self.text_encoder_2 is None:
273
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
274
+ else:
275
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
276
+
277
+ add_time_ids = self._get_add_time_ids(
278
+ original_size,
279
+ crops_coords_top_left,
280
+ target_size,
281
+ dtype=prompt_embeds.dtype,
282
+ text_encoder_projection_dim=text_encoder_projection_dim,
283
+ )
284
+ if negative_original_size is not None and negative_target_size is not None:
285
+ negative_add_time_ids = self._get_add_time_ids(
286
+ negative_original_size,
287
+ negative_crops_coords_top_left,
288
+ negative_target_size,
289
+ dtype=prompt_embeds.dtype,
290
+ text_encoder_projection_dim=text_encoder_projection_dim,
291
+ )
292
+ else:
293
+ negative_add_time_ids = add_time_ids
294
+
295
+ if do_classifier_free_guidance:
296
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
297
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
298
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
299
+
300
+ prompt_embeds = prompt_embeds.to(device)
301
+ add_text_embeds = add_text_embeds.to(device)
302
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
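+ # SDXL micro-conditioning: each time-id packs (original_size, crops_coords_top_left, target_size)
+ # into six integers that are embedded like timesteps and added to the time embedding (SDXL report, sec. 2.2)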
303
+
304
+ # 8. Denoising loop
305
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
306
+
307
+ # 7.1 Apply denoising_end
308
+ if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
309
+ discrete_timestep_cutoff = int(
310
+ round(
311
+ self.scheduler.config.num_train_timesteps
312
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
313
+ )
314
+ )
315
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
316
+ timesteps = timesteps[:num_inference_steps]
317
+
318
+ # get init conditioning scale
319
+ for attn_processor in self.unet.attn_processors.values():
320
+ if isinstance(attn_processor, IPAttnProcessor):
321
+ conditioning_scale = attn_processor.scale
322
+ break
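+ # cache the IP-Adapter scale configured before this call so it can be restored inside the
+ # [control_guidance_start, control_guidance_end] window and zeroed outside of it (see set_scale below)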
323
+
324
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
325
+ for i, t in enumerate(timesteps):
326
+ if (i / len(timesteps) < control_guidance_start) or ((i + 1) / len(timesteps) > control_guidance_end):
327
+ self.set_scale(0.0)
328
+ else:
329
+ self.set_scale(conditioning_scale)
330
+
331
+ # expand the latents if we are doing classifier free guidance
332
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
333
+
334
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
335
+
336
+ # predict the noise residual
337
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
338
+ noise_pred = self.unet(
339
+ latent_model_input,
340
+ t,
341
+ encoder_hidden_states=prompt_embeds,
342
+ cross_attention_kwargs=cross_attention_kwargs,
343
+ added_cond_kwargs=added_cond_kwargs,
344
+ return_dict=False,
345
+ )[0]
346
+
347
+ # perform guidance
348
+ if do_classifier_free_guidance:
349
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
350
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
351
+
352
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
353
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
354
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
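+ # rescale_noise_cfg pulls the standard deviation of the guided prediction toward that of
+ # noise_pred_text and blends the two by `guidance_rescale`, mitigating the overexposure seen
+ # with zero terminal SNR noise schedules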
355
+
356
+ # compute the previous noisy sample x_t -> x_t-1
357
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
358
+
359
+ # call the callback, if provided
360
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
361
+ progress_bar.update()
362
+ if callback is not None and i % callback_steps == 0:
363
+ callback(i, t, latents)
364
+
365
+ if output_type != "latent":
366
+ # make sure the VAE is in float32 mode, as it overflows in float16
367
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
368
+
369
+ if needs_upcasting:
370
+ self.upcast_vae()
371
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
372
+
373
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
374
+
375
+ # cast back to fp16 if needed
376
+ if needs_upcasting:
377
+ self.vae.to(dtype=torch.float16)
378
+ else:
379
+ image = latents
380
+
381
+ if output_type != "latent":
382
+ # apply watermark if available
383
+ if self.watermark is not None:
384
+ image = self.watermark.apply_watermark(image)
385
+
386
+ image = self.image_processor.postprocess(image, output_type=output_type)
387
+
388
+ # Offload all models
389
+ self.maybe_free_model_hooks()
390
+
391
+ if not return_dict:
392
+ return (image,)
393
+
394
+ return StableDiffusionXLPipelineOutput(images=image)
ip_adapter/ip_adapter.py ADDED
@@ -0,0 +1,1086 @@
1
+ import os
2
+ from typing import List
3
+
4
+
5
+ import torch
6
+ from typing import Optional, Union, Any, Dict, Tuple, List, Callable
7
+ from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
8
+ from diffusers.utils import (
9
+ USE_PEFT_BACKEND,
10
+ deprecate,
11
+ logging,
12
+ replace_example_docstring,
13
+ scale_lora_layers,
14
+ unscale_lora_layers,
15
+ )
16
+ from diffusers.pipelines.controlnet.pipeline_controlnet import retrieve_timesteps
17
+ from diffusers import StableDiffusionPipeline
18
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
19
+ from diffusers.pipelines.controlnet.pipeline_controlnet import StableDiffusionControlNetPipeline
20
+ from diffusers.models.controlnet import ControlNetModel
21
+ from diffusers.image_processor import PipelineImageInput
22
+ from diffusers.pipelines.controlnet import MultiControlNetModel
23
+ from PIL import Image
24
+ from safetensors import safe_open
25
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
26
+ from torchvision import transforms
27
+ from .style_encoder import Style_Aware_Encoder
28
+ from .tools import pre_processing
29
+
30
+ from .utils import is_torch2_available
31
+
32
+ if is_torch2_available():
33
+ from .attention_processor import (
34
+ AttnProcessor2_0 as AttnProcessor,
35
+ )
36
+ from .attention_processor import (
37
+ CNAttnProcessor2_0 as CNAttnProcessor,
38
+ )
39
+ from .attention_processor import (
40
+ IPAttnProcessor2_0 as IPAttnProcessor,
41
+ )
42
+ else:
43
+ from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
44
+ from .resampler import Resampler
45
+
46
+
47
+ class ImageProjModel(torch.nn.Module):
48
+ """Projection Model"""
49
+
50
+ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
51
+ super().__init__()
52
+
53
+ self.cross_attention_dim = cross_attention_dim
54
+ self.clip_extra_context_tokens = clip_extra_context_tokens
55
+ self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
56
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
57
+
58
+ def forward(self, image_embeds):
59
+ embeds = image_embeds
60
+ clip_extra_context_tokens = self.proj(embeds).reshape(
61
+ -1, self.clip_extra_context_tokens, self.cross_attention_dim
62
+ )
63
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
64
+ return clip_extra_context_tokens
65
+
66
+
67
+ class MLPProjModel(torch.nn.Module):
68
+ """SD model with image prompt"""
69
+ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
70
+ super().__init__()
71
+
72
+ self.proj = torch.nn.Sequential(
73
+ torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
74
+ torch.nn.GELU(),
75
+ torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
76
+ torch.nn.LayerNorm(cross_attention_dim)
77
+ )
78
+
79
+ def forward(self, image_embeds):
80
+ clip_extra_context_tokens = self.proj(image_embeds)
81
+ return clip_extra_context_tokens
82
+
83
+
84
+ class IPAdapter:
85
+ def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
86
+ self.device = device
87
+ self.image_encoder_path = image_encoder_path
88
+ self.ip_ckpt = ip_ckpt
89
+ self.num_tokens = num_tokens
90
+
91
+ self.pipe = sd_pipe.to(self.device)
92
+ self.set_ip_adapter()
93
+
94
+ # load image encoder
95
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
96
+ self.device, dtype=torch.float16
97
+ )
98
+ self.clip_image_processor = CLIPImageProcessor()
99
+ # image proj model
100
+ self.image_proj_model = self.init_proj()
101
+
102
+ self.load_ip_adapter()
103
+
104
+ def init_proj(self):
105
+ image_proj_model = ImageProjModel(
106
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
107
+ clip_embeddings_dim=self.image_encoder.config.projection_dim,
108
+ clip_extra_context_tokens=self.num_tokens,
109
+ ).to(self.device, dtype=torch.float16)
110
+ return image_proj_model
111
+
112
+ def set_ip_adapter(self):
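+ # Self-attention layers ("attn1") keep the stock processor; cross-attention layers ("attn2") are
+ # replaced by IPAttnProcessor, which additionally attends to the `num_tokens` image tokens that
+ # `generate()` appends to the text embeddings.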
113
+ unet = self.pipe.unet
114
+ attn_procs = {}
115
+ for name in unet.attn_processors.keys():
116
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
117
+ if name.startswith("mid_block"):
118
+ hidden_size = unet.config.block_out_channels[-1]
119
+ elif name.startswith("up_blocks"):
120
+ block_id = int(name[len("up_blocks.")])
121
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
122
+ elif name.startswith("down_blocks"):
123
+ block_id = int(name[len("down_blocks.")])
124
+ hidden_size = unet.config.block_out_channels[block_id]
125
+ if cross_attention_dim is None:
126
+ attn_procs[name] = AttnProcessor()
127
+ else:
128
+ attn_procs[name] = IPAttnProcessor(
129
+ hidden_size=hidden_size,
130
+ cross_attention_dim=cross_attention_dim,
131
+ scale=1.0,
132
+ num_tokens=self.num_tokens,
133
+ ).to(self.device, dtype=torch.float16)
134
+ unet.set_attn_processor(attn_procs)
135
+ if hasattr(self.pipe, "controlnet"):
136
+ if isinstance(self.pipe.controlnet, MultiControlNetModel):
137
+ for controlnet in self.pipe.controlnet.nets:
138
+ controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
139
+ else:
140
+ self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
141
+
142
+ def load_ip_adapter(self):
143
+ if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
144
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
145
+ with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
146
+ for key in f.keys():
147
+ if key.startswith("image_proj."):
148
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
149
+ elif key.startswith("ip_adapter."):
150
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
151
+ else:
152
+ state_dict = torch.load(self.ip_ckpt, map_location="cpu")
153
+ self.image_proj_model.load_state_dict(state_dict["image_proj"])
154
+ ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
155
+ ip_layers.load_state_dict(state_dict["ip_adapter"])
156
+
157
+ @torch.inference_mode()
158
+ def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
159
+ if pil_image is not None:
160
+ if isinstance(pil_image, Image.Image):
161
+ pil_image = [pil_image]
162
+ clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
163
+ clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
164
+ else:
165
+ clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
166
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
167
+ uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
168
+ return image_prompt_embeds, uncond_image_prompt_embeds
169
+
170
+ def set_scale(self, scale):
171
+ for attn_processor in self.pipe.unet.attn_processors.values():
172
+ if isinstance(attn_processor, IPAttnProcessor):
173
+ attn_processor.scale = scale
174
+
175
+ def generate(
176
+ self,
177
+ pil_image=None,
178
+ clip_image_embeds=None,
179
+ prompt=None,
180
+ negative_prompt=None,
181
+ scale=1.0,
182
+ num_samples=4,
183
+ seed=None,
184
+ guidance_scale=7.5,
185
+ num_inference_steps=30,
186
+ **kwargs,
187
+ ):
188
+ self.set_scale(scale)
189
+
190
+ if pil_image is not None:
191
+ num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
192
+ else:
193
+ num_prompts = clip_image_embeds.size(0)
194
+
195
+ if prompt is None:
196
+ prompt = "best quality, high quality"
197
+ if negative_prompt is None:
198
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
199
+
200
+ if not isinstance(prompt, List):
201
+ prompt = [prompt] * num_prompts
202
+ if not isinstance(negative_prompt, List):
203
+ negative_prompt = [negative_prompt] * num_prompts
204
+
205
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
206
+ pil_image=pil_image, clip_image_embeds=clip_image_embeds
207
+ )
208
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
209
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
210
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
211
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
212
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
213
+
214
+ with torch.inference_mode():
215
+ prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
216
+ prompt,
217
+ device=self.device,
218
+ num_images_per_prompt=num_samples,
219
+ do_classifier_free_guidance=True,
220
+ negative_prompt=negative_prompt,
221
+ )
222
+ prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
223
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
224
+
225
+ generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
226
+ images = self.pipe(
227
+ prompt_embeds=prompt_embeds,
228
+ negative_prompt_embeds=negative_prompt_embeds,
229
+ guidance_scale=guidance_scale,
230
+ num_inference_steps=num_inference_steps,
231
+ generator=generator,
232
+ **kwargs,
233
+ ).images
234
+
235
+ return images
236
+
237
+
238
+ class IPAdapterXL(IPAdapter):
239
+ """SDXL"""
240
+
241
+ def generate(
242
+ self,
243
+ pil_image,
244
+ prompt=None,
245
+ negative_prompt=None,
246
+ scale=1.0,
247
+ num_samples=4,
248
+ seed=None,
249
+ num_inference_steps=30,
250
+ **kwargs,
251
+ ):
252
+ self.set_scale(scale)
253
+
254
+ num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
255
+
256
+ if prompt is None:
257
+ prompt = "best quality, high quality"
258
+ if negative_prompt is None:
259
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
260
+
261
+ if not isinstance(prompt, List):
262
+ prompt = [prompt] * num_prompts
263
+ if not isinstance(negative_prompt, List):
264
+ negative_prompt = [negative_prompt] * num_prompts
265
+
266
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
267
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
268
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
269
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
270
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
271
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
272
+
273
+ with torch.inference_mode():
274
+ (
275
+ prompt_embeds,
276
+ negative_prompt_embeds,
277
+ pooled_prompt_embeds,
278
+ negative_pooled_prompt_embeds,
279
+ ) = self.pipe.encode_prompt(
280
+ prompt,
281
+ num_images_per_prompt=num_samples,
282
+ do_classifier_free_guidance=True,
283
+ negative_prompt=negative_prompt,
284
+ )
285
+ prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
286
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
287
+
288
+ generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
289
+ images = self.pipe(
290
+ prompt_embeds=prompt_embeds,
291
+ negative_prompt_embeds=negative_prompt_embeds,
292
+ pooled_prompt_embeds=pooled_prompt_embeds,
293
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
294
+ num_inference_steps=num_inference_steps,
295
+ generator=generator,
296
+ **kwargs,
297
+ ).images
298
+
299
+ return images
300
+
301
+
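+ # Usage sketch for IPAdapterXL (illustrative; the SDXL pipeline import and the checkpoint paths
+ # below are assumptions, not part of this repository):
+ #
+ #   from diffusers import StableDiffusionXLPipeline
+ #   pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
+ #   ip_model = IPAdapterXL(pipe, "path/to/image_encoder", "path/to/ip-adapter_sdxl.bin", "cuda")
+ #   images = ip_model.generate(pil_image=reference_image, prompt="best quality", num_samples=1, seed=42)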
302
+ class IPAdapterPlus(IPAdapter):
303
+ """IP-Adapter with fine-grained features"""
304
+
305
+ def init_proj(self):
306
+ image_proj_model = Resampler(
307
+ dim=self.pipe.unet.config.cross_attention_dim,
308
+ depth=4,
309
+ dim_head=64,
310
+ heads=12,
311
+ num_queries=self.num_tokens,
312
+ embedding_dim=self.image_encoder.config.hidden_size,
313
+ output_dim=self.pipe.unet.config.cross_attention_dim,
314
+ ff_mult=4,
315
+ ).to(self.device, dtype=torch.float16)
316
+ return image_proj_model
317
+
318
+ @torch.inference_mode()
319
+ def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
320
+ if isinstance(pil_image, Image.Image):
321
+ pil_image = [pil_image]
322
+ clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
323
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
324
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
325
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
326
+ uncond_clip_image_embeds = self.image_encoder(
327
+ torch.zeros_like(clip_image), output_hidden_states=True
328
+ ).hidden_states[-2]
329
+ uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
330
+ return image_prompt_embeds, uncond_image_prompt_embeds
331
+
332
+
333
+ class IPAdapterFull(IPAdapterPlus):
334
+ """IP-Adapter with full features"""
335
+
336
+ def init_proj(self):
337
+ image_proj_model = MLPProjModel(
338
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
339
+ clip_embeddings_dim=self.image_encoder.config.hidden_size,
340
+ ).to(self.device, dtype=torch.float16)
341
+ return image_proj_model
342
+
343
+
344
+ class IPAdapterPlusXL(IPAdapter):
345
+ """SDXL"""
346
+
347
+ def init_proj(self):
348
+ image_proj_model = Resampler(
349
+ dim=1280,
350
+ depth=4,
351
+ dim_head=64,
352
+ heads=20,
353
+ num_queries=self.num_tokens,
354
+ embedding_dim=self.image_encoder.config.hidden_size,
355
+ output_dim=self.pipe.unet.config.cross_attention_dim,
356
+ ff_mult=4,
357
+ ).to(self.device, dtype=torch.float16)
358
+ return image_proj_model
359
+
360
+ @torch.inference_mode()
361
+ def get_image_embeds(self, pil_image):
362
+ if isinstance(pil_image, Image.Image):
363
+ pil_image = [pil_image]
364
+ clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
365
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
366
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
367
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
368
+ uncond_clip_image_embeds = self.image_encoder(
369
+ torch.zeros_like(clip_image), output_hidden_states=True
370
+ ).hidden_states[-2]
371
+ uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
372
+ return image_prompt_embeds, uncond_image_prompt_embeds
373
+
374
+ def generate(
375
+ self,
376
+ pil_image,
377
+ prompt=None,
378
+ negative_prompt=None,
379
+ scale=1.0,
380
+ num_samples=4,
381
+ seed=None,
382
+ num_inference_steps=30,
383
+ **kwargs,
384
+ ):
385
+ self.set_scale(scale)
386
+
387
+ num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
388
+
389
+ if prompt is None:
390
+ prompt = "best quality, high quality"
391
+ if negative_prompt is None:
392
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
393
+
394
+ if not isinstance(prompt, List):
395
+ prompt = [prompt] * num_prompts
396
+ if not isinstance(negative_prompt, List):
397
+ negative_prompt = [negative_prompt] * num_prompts
398
+
399
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
400
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
401
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
402
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
403
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
404
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
405
+
406
+ with torch.inference_mode():
407
+ (
408
+ prompt_embeds,
409
+ negative_prompt_embeds,
410
+ pooled_prompt_embeds,
411
+ negative_pooled_prompt_embeds,
412
+ ) = self.pipe.encode_prompt(
413
+ prompt,
414
+ num_images_per_prompt=num_samples,
415
+ do_classifier_free_guidance=True,
416
+ negative_prompt=negative_prompt,
417
+ )
418
+ prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
419
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
420
+
421
+ generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
422
+ images = self.pipe(
423
+ prompt_embeds=prompt_embeds,
424
+ negative_prompt_embeds=negative_prompt_embeds,
425
+ pooled_prompt_embeds=pooled_prompt_embeds,
426
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
427
+ num_inference_steps=num_inference_steps,
428
+ generator=generator,
429
+ **kwargs,
430
+ ).images
431
+
432
+ return images
433
+
434
+
435
+ def StyleProcessor(style_image, device):
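+ # Center-crops the style reference to 512x512, splits it into patches at three granularities
+ # (high / middle / low), shuffles the patch order, and returns the three stacks as float32 tensors.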
436
+ transform = transforms.Compose([
437
+ transforms.ToTensor(),
438
+ transforms.Normalize([0.5], [0.5]),
439
+ ])
440
+ # centercrop for style condition
441
+ crop = transforms.Compose(
442
+ [
443
+ transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
444
+ transforms.CenterCrop(512),
445
+ ]
446
+ )
447
+ style_image = crop(style_image)
448
+ high_style_patch, middle_style_patch, low_style_patch = pre_processing(style_image.convert("RGB"), transform)
449
+ # shuffling
450
+ high_style_patch, middle_style_patch, low_style_patch = (high_style_patch[torch.randperm(high_style_patch.shape[0])],
451
+ middle_style_patch[torch.randperm(middle_style_patch.shape[0])],
452
+ low_style_patch[torch.randperm(low_style_patch.shape[0])])
453
+ return (high_style_patch.to(device, dtype=torch.float32), middle_style_patch.to(device, dtype=torch.float32), low_style_patch.to(device, dtype=torch.float32))
454
+
455
+
456
+ class StyleShot(torch.nn.Module):
457
+ """StyleShot generation"""
458
+ def __init__(self, device, pipe, ip_ckpt, style_aware_encoder_ckpt, transformer_patch):
459
+ super().__init__()
460
+ self.num_tokens = 6
461
+ self.device = device
462
+ self.pipe = pipe
463
+
464
+ self.set_ip_adapter(device)
465
+ self.ip_ckpt = ip_ckpt
466
+
467
+ self.style_aware_encoder = Style_Aware_Encoder(CLIPVisionModelWithProjection.from_pretrained(transformer_patch)).to(self.device, dtype=torch.float32)
468
+ self.style_aware_encoder.load_state_dict(torch.load(style_aware_encoder_ckpt))
469
+
470
+ self.style_image_proj_modules = self.init_proj()
471
+
472
+ self.load_ip_adapter()
473
+ self.pipe = self.pipe.to(self.device, dtype=torch.float32)
474
+
475
+ def init_proj(self):
476
+ style_image_proj_modules = torch.nn.ModuleList([
477
+ ImageProjModel(
478
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
479
+ clip_embeddings_dim=self.style_aware_encoder.projection_dim,
480
+ clip_extra_context_tokens=2,
481
+ ),
482
+ ImageProjModel(
483
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
484
+ clip_embeddings_dim=self.style_aware_encoder.projection_dim,
485
+ clip_extra_context_tokens=2,
486
+ ),
487
+ ImageProjModel(
488
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
489
+ clip_embeddings_dim=self.style_aware_encoder.projection_dim,
490
+ clip_extra_context_tokens=2,
491
+ )])
492
+ return style_image_proj_modules.to(self.device, dtype=torch.float32)
493
+
494
+ def load_ip_adapter(self):
495
+ sd = torch.load(self.ip_ckpt, map_location="cpu")
496
+ style_image_proj_sd = {}
497
+ ip_sd = {}
498
+ controlnet_sd = {}
499
+ for k in sd:
500
+ if k.startswith("unet"):
501
+ pass
502
+ elif k.startswith("style_image_proj_modules"):
503
+ style_image_proj_sd[k.replace("style_image_proj_modules.", "")] = sd[k]
504
+ elif k.startswith("adapter_modules"):
505
+ ip_sd[k.replace("adapter_modules.", "")] = sd[k]
506
+ elif k.startswith("controlnet"):
507
+ controlnet_sd[k.replace("controlnet.", "")] = sd[k]
508
+ # Load state dict for image_proj_model and adapter_modules
509
+ self.style_image_proj_modules.load_state_dict(style_image_proj_sd, strict=True)
510
+ ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
511
+ if hasattr(self.pipe, "controlnet") and isinstance(self.pipe, StyleContentStableDiffusionControlNetPipeline):
512
+ self.pipe.controlnet.load_state_dict(controlnet_sd, strict=True)
513
+ ip_layers.load_state_dict(ip_sd, strict=True)
514
+
515
+ def set_ip_adapter(self, device):
516
+ unet = self.pipe.unet
517
+ attn_procs = {}
518
+ for name in unet.attn_processors.keys():
519
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
520
+ if name.startswith("mid_block"):
521
+ hidden_size = unet.config.block_out_channels[-1]
522
+ elif name.startswith("up_blocks"):
523
+ block_id = int(name[len("up_blocks.")])
524
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
525
+ elif name.startswith("down_blocks"):
526
+ block_id = int(name[len("down_blocks.")])
527
+ hidden_size = unet.config.block_out_channels[block_id]
528
+ if cross_attention_dim is None:
529
+ attn_procs[name] = AttnProcessor()
530
+ else:
531
+ attn_procs[name] = IPAttnProcessor(
532
+ hidden_size=hidden_size,
533
+ cross_attention_dim=cross_attention_dim,
534
+ scale=1.0,
535
+ num_tokens=self.num_tokens,
536
+ ).to(device, dtype=torch.float16)
537
+ if hasattr(self.pipe, "controlnet") and not isinstance(self.pipe, StyleContentStableDiffusionControlNetPipeline):
538
+ if isinstance(self.pipe.controlnet, MultiControlNetModel):
539
+ for controlnet in self.pipe.controlnet.nets:
540
+ controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
541
+ else:
542
+ self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
543
+ unet.set_attn_processor(attn_procs)
544
+
545
+ @torch.inference_mode()
546
+ def get_image_embeds(self, style_image=None):
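+ # The style-aware encoder yields one embedding per patch granularity; each is projected to two
+ # tokens, giving 3 * 2 = 6 style tokens (matching self.num_tokens), plus an unconditional
+ # counterpart built from zeroed embeddings for classifier-free guidance.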
547
+ style_image = StyleProcessor(style_image, self.device)
548
+ style_embeds = self.style_aware_encoder(style_image).to(self.device, dtype=torch.float32)
549
+ style_ip_tokens = []
550
+ uncond_style_ip_tokens = []
551
+ for idx, style_embed in enumerate([style_embeds[:, 0, :], style_embeds[:, 1, :], style_embeds[:, 2, :]]):
552
+ style_ip_tokens.append(self.style_image_proj_modules[idx](style_embed))
553
+ uncond_style_ip_tokens.append(self.style_image_proj_modules[idx](torch.zeros_like(style_embed)))
554
+ style_ip_tokens = torch.cat(style_ip_tokens, dim=1)
555
+ uncond_style_ip_tokens = torch.cat(uncond_style_ip_tokens, dim=1)
556
+ return style_ip_tokens, uncond_style_ip_tokens
557
+
558
+ def set_scale(self, scale):
559
+ for attn_processor in self.pipe.unet.attn_processors.values():
560
+ if isinstance(attn_processor, IPAttnProcessor):
561
+ attn_processor.scale = scale
562
+
563
+ def samples(self, image_prompt_embeds, uncond_image_prompt_embeds, num_samples, device, prompt, negative_prompt,
564
+ seed, guidance_scale, num_inference_steps, content_image, **kwargs, ):
565
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
566
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
567
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
568
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
569
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
570
+ with torch.inference_mode():
571
+ prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
572
+ prompt,
573
+ device=device,
574
+ num_images_per_prompt=num_samples,
575
+ do_classifier_free_guidance=True,
576
+ negative_prompt=negative_prompt,
577
+ )
578
+ prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
579
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
580
+
581
+ generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
582
+ if content_image is None:
583
+ images = self.pipe(
584
+ prompt_embeds=prompt_embeds,
585
+ negative_prompt_embeds=negative_prompt_embeds,
586
+ guidance_scale=guidance_scale,
587
+ num_inference_steps=num_inference_steps,
588
+ generator=generator,
589
+ **kwargs,
590
+ ).images
591
+ else:
592
+ images = self.pipe(
593
+ prompt_embeds=prompt_embeds,
594
+ negative_prompt_embeds=negative_prompt_embeds,
595
+ guidance_scale=guidance_scale,
596
+ num_inference_steps=num_inference_steps,
597
+ generator=generator,
598
+ image=content_image,
599
+ style_embeddings=image_prompt_embeds,
600
+ negative_style_embeddings=uncond_image_prompt_embeds,
601
+ **kwargs,
602
+ ).images
603
+ return images
604
+
605
+ def generate(
606
+ self,
607
+ style_image=None,
608
+ prompt=None,
609
+ negative_prompt=None,
610
+ scale=1.0,
611
+ num_samples=1,
612
+ seed=42,
613
+ guidance_scale=7.5,
614
+ num_inference_steps=50,
615
+ content_image=None,
616
+ **kwargs,
617
+ ):
618
+ self.set_scale(scale)
619
+
620
+ num_prompts = 1
621
+
622
+ if prompt is None:
623
+ prompt = "best quality, high quality"
624
+ if negative_prompt is None:
625
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
626
+
627
+ if not isinstance(prompt, List):
628
+ prompt = [prompt] * num_prompts
629
+ if not isinstance(negative_prompt, List):
630
+ negative_prompt = [negative_prompt] * num_prompts
631
+
632
+ style_ip_tokens, uncond_style_ip_tokens = self.get_image_embeds(style_image)
633
+ generate_images = []
634
+ for p in prompt:
635
+ images = self.samples(style_ip_tokens, uncond_style_ip_tokens, num_samples, self.device, p * num_prompts, negative_prompt, seed, guidance_scale, num_inference_steps, content_image, **kwargs, )
636
+ generate_images.append(images)
637
+ return generate_images
638
+
639
+
640
+ class StyleContentStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
641
+ @torch.no_grad()
642
+ def __call__(
643
+ self,
644
+ prompt: Union[str, List[str]] = None,
645
+ image: PipelineImageInput = None,
646
+ height: Optional[int] = None,
647
+ width: Optional[int] = None,
648
+ num_inference_steps: int = 50,
649
+ timesteps: List[int] = None,
650
+ guidance_scale: float = 7.5,
651
+ negative_prompt: Optional[Union[str, List[str]]] = None,
652
+ num_images_per_prompt: Optional[int] = 1,
653
+ eta: float = 0.0,
654
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
655
+ latents: Optional[torch.FloatTensor] = None,
656
+ prompt_embeds: Optional[torch.FloatTensor] = None,
657
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
658
+ ip_adapter_image: Optional[PipelineImageInput] = None,
659
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
660
+ output_type: Optional[str] = "pil",
661
+ return_dict: bool = True,
662
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
663
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
664
+ guess_mode: bool = False,
665
+ control_guidance_start: Union[float, List[float]] = 0.0,
666
+ control_guidance_end: Union[float, List[float]] = 1.0,
667
+ clip_skip: Optional[int] = None,
668
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
669
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
670
+ style_embeddings: Optional[torch.FloatTensor] = None,
671
+ negative_style_embeddings: Optional[torch.FloatTensor] = None,
672
+ **kwargs,
673
+ ):
674
+ r"""
675
+ The call function to the pipeline for generation.
676
+
677
+ Args:
678
+ prompt (`str` or `List[str]`, *optional*):
679
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
680
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
681
+ `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
682
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
683
+ specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
684
+ accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
685
+ and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
686
+ `init`, images must be passed as a list such that each element of the list can be correctly batched for
687
+ input to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single ControlNet,
688
+ each will be paired with each prompt in the `prompt` list. This also applies to multiple ControlNets,
689
+ where a list of image lists can be passed to batch for each prompt and each ControlNet.
690
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
691
+ The height in pixels of the generated image.
692
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
693
+ The width in pixels of the generated image.
694
+ num_inference_steps (`int`, *optional*, defaults to 50):
695
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
696
+ expense of slower inference.
697
+ timesteps (`List[int]`, *optional*):
698
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
699
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
700
+ passed will be used. Must be in descending order.
701
+ guidance_scale (`float`, *optional*, defaults to 7.5):
702
+ A higher guidance scale value encourages the model to generate images closely linked to the text
703
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
704
+ negative_prompt (`str` or `List[str]`, *optional*):
705
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
706
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
707
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
708
+ The number of images to generate per prompt.
709
+ eta (`float`, *optional*, defaults to 0.0):
710
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
711
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
712
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
713
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
714
+ generation deterministic.
715
+ latents (`torch.FloatTensor`, *optional*):
716
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
717
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
718
+ tensor is generated by sampling using the supplied random `generator`.
719
+ prompt_embeds (`torch.FloatTensor`, *optional*):
720
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
721
+ provided, text embeddings are generated from the `prompt` input argument.
722
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
723
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
724
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
725
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
726
+ ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
727
+ Pre-generated image embeddings for IP-Adapter. It should be a list whose length equals the number of IP-Adapters.
728
+ Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
729
+ if `do_classifier_free_guidance` is set to `True`.
730
+ If not provided, embeddings are computed from the `ip_adapter_image` input argument.
731
+ output_type (`str`, *optional*, defaults to `"pil"`):
732
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
733
+ return_dict (`bool`, *optional*, defaults to `True`):
734
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
735
+ plain tuple.
736
+ callback (`Callable`, *optional*):
737
+ A function that is called every `callback_steps` steps during inference. The function is called with the
738
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
739
+ callback_steps (`int`, *optional*, defaults to 1):
740
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
741
+ every step.
742
+ cross_attention_kwargs (`dict`, *optional*):
743
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
744
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
745
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
746
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
747
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
748
+ the corresponding scale as a list.
749
+ guess_mode (`bool`, *optional*, defaults to `False`):
750
+ The ControlNet encoder tries to recognize the content of the input image even if you remove all
751
+ prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
752
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
753
+ The percentage of total steps at which the ControlNet starts applying.
754
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
755
+ The percentage of total steps at which the ControlNet stops applying.
756
+ clip_skip (`int`, *optional*):
757
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
758
+ the output of the pre-final layer will be used for computing the prompt embeddings.
759
+ callback_on_step_end (`Callable`, *optional*):
760
+ A function that is called at the end of each denoising step during inference. The function is called
761
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
762
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
763
+ `callback_on_step_end_tensor_inputs`.
764
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
765
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
766
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
767
+ `._callback_tensor_inputs` attribute of your pipeline class.
768
+
769
+ Examples:
770
+
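+ A minimal sketch (not part of the original file): this pipeline is normally driven through
+ `StyleShot` above, which supplies `style_embeddings` / `negative_style_embeddings`; checkpoint
+ paths are placeholders.
+
+ >>> styleshot = StyleShot(device, pipe, "path/to/ip.bin", "path/to/style_aware_encoder.bin", "path/to/transformer_patch")
+ >>> images = styleshot.generate(style_image=style_img, prompt=["a cat"], content_image=content_img)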
771
+ Returns:
772
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
773
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
774
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
775
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
776
+ "not-safe-for-work" (nsfw) content.
777
+ """
778
+
779
+ callback = kwargs.pop("callback", None)
780
+ callback_steps = kwargs.pop("callback_steps", None)
781
+
782
+ if callback is not None:
783
+ deprecate(
784
+ "callback",
785
+ "1.0.0",
786
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
787
+ )
788
+ if callback_steps is not None:
789
+ deprecate(
790
+ "callback_steps",
791
+ "1.0.0",
792
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
793
+ )
794
+
795
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
796
+
797
+ # align format for control guidance
798
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
799
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
800
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
801
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
802
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
803
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
804
+ control_guidance_start, control_guidance_end = (
805
+ mult * [control_guidance_start],
806
+ mult * [control_guidance_end],
807
+ )
808
+
809
+ # 1. Check inputs. Raise error if not correct
810
+ self.check_inputs(
811
+ prompt,
812
+ image,
813
+ callback_steps,
814
+ negative_prompt,
815
+ prompt_embeds,
816
+ negative_prompt_embeds,
817
+ ip_adapter_image,
818
+ ip_adapter_image_embeds,
819
+ controlnet_conditioning_scale,
820
+ control_guidance_start,
821
+ control_guidance_end,
822
+ callback_on_step_end_tensor_inputs,
823
+ )
824
+
825
+ self._guidance_scale = guidance_scale
826
+ self._clip_skip = clip_skip
827
+ self._cross_attention_kwargs = cross_attention_kwargs
828
+
829
+ # 2. Define call parameters
830
+ if prompt is not None and isinstance(prompt, str):
831
+ batch_size = 1
832
+ elif prompt is not None and isinstance(prompt, list):
833
+ batch_size = len(prompt)
834
+ else:
835
+ batch_size = prompt_embeds.shape[0]
836
+
837
+ device = self._execution_device
838
+
839
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
840
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
841
+
842
+ global_pool_conditions = (
843
+ controlnet.config.global_pool_conditions
844
+ if isinstance(controlnet, ControlNetModel)
845
+ else controlnet.nets[0].config.global_pool_conditions
846
+ )
847
+ guess_mode = guess_mode or global_pool_conditions
848
+
849
+ # 3. Encode input prompt
850
+ text_encoder_lora_scale = (
851
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
852
+ )
853
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
854
+ prompt,
855
+ device,
856
+ num_images_per_prompt,
857
+ self.do_classifier_free_guidance,
858
+ negative_prompt,
859
+ prompt_embeds=prompt_embeds,
860
+ negative_prompt_embeds=negative_prompt_embeds,
861
+ lora_scale=text_encoder_lora_scale,
862
+ clip_skip=self.clip_skip,
863
+ )
864
+ # For classifier free guidance, we need to do two forward passes.
865
+ # Here we concatenate the unconditional and text embeddings into a single batch
866
+ # to avoid doing two forward passes
867
+ if self.do_classifier_free_guidance:
868
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
869
+
870
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
871
+ image_embeds = self.prepare_ip_adapter_image_embeds(
872
+ ip_adapter_image,
873
+ ip_adapter_image_embeds,
874
+ device,
875
+ batch_size * num_images_per_prompt,
876
+ self.do_classifier_free_guidance,
877
+ )
878
+
879
+ # 4. Prepare image
880
+ if isinstance(controlnet, ControlNetModel):
881
+ image = self.prepare_image(
882
+ image=image,
883
+ width=width,
884
+ height=height,
885
+ batch_size=batch_size * num_images_per_prompt,
886
+ num_images_per_prompt=num_images_per_prompt,
887
+ device=device,
888
+ dtype=controlnet.dtype,
889
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
890
+ guess_mode=guess_mode,
891
+ )
892
+ height, width = image.shape[-2:]
893
+
894
+ elif isinstance(controlnet, MultiControlNetModel):
895
+ images = []
896
+
897
+ # Nested lists as ControlNet condition
898
+ if isinstance(image[0], list):
899
+ # Transpose the nested image list
900
+ image = [list(t) for t in zip(*image)]
901
+
902
+ for image_ in image:
903
+ image_ = self.prepare_image(
904
+ image=image_,
905
+ width=width,
906
+ height=height,
907
+ batch_size=batch_size * num_images_per_prompt,
908
+ num_images_per_prompt=num_images_per_prompt,
909
+ device=device,
910
+ dtype=controlnet.dtype,
911
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
912
+ guess_mode=guess_mode,
913
+ )
914
+
915
+ images.append(image_)
916
+
917
+ image = images
918
+ height, width = image[0].shape[-2:]
919
+ else:
920
+ assert False
921
+
922
+ # 5. Prepare timesteps
923
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
924
+ self._num_timesteps = len(timesteps)
925
+
926
+ # 6. Prepare latent variables
927
+ num_channels_latents = self.unet.config.in_channels
928
+ latents = self.prepare_latents(
929
+ batch_size * num_images_per_prompt,
930
+ num_channels_latents,
931
+ height,
932
+ width,
933
+ prompt_embeds.dtype,
934
+ device,
935
+ generator,
936
+ latents,
937
+ )
938
+
939
+ # 6.5 Optionally get Guidance Scale Embedding
940
+ timestep_cond = None
941
+ if self.unet.config.time_cond_proj_dim is not None:
942
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
943
+ timestep_cond = self.get_guidance_scale_embedding(
944
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
945
+ ).to(device=device, dtype=latents.dtype)
946
+
947
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
948
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
949
+
950
+ # 7.1 Add image embeds for IP-Adapter
951
+ added_cond_kwargs = (
952
+ {"image_embeds": image_embeds}
953
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None
954
+ else None
955
+ )
956
+
957
+ # 7.2 Create tensor stating which controlnets to keep
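+ # keeps[j] is 1.0 while step i falls inside [control_guidance_start, control_guidance_end] for
+ # ControlNet j and 0.0 outside that window; with a single ControlNet only the scalar is stored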
958
+ controlnet_keep = []
959
+ for i in range(len(timesteps)):
960
+ keeps = [
961
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
962
+ for s, e in zip(control_guidance_start, control_guidance_end)
963
+ ]
964
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
965
+
966
+ # 8. Denoising loop
967
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
968
+ is_unet_compiled = is_compiled_module(self.unet)
969
+ is_controlnet_compiled = is_compiled_module(self.controlnet)
970
+ is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
971
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
972
+ for i, t in enumerate(timesteps):
973
+ # Relevant thread:
974
+ # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
975
+ if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
976
+ torch._inductor.cudagraph_mark_step_begin()
977
+ # expand the latents if we are doing classifier free guidance
978
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
979
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
980
+
981
+ # controlnet(s) inference
982
+ if guess_mode and self.do_classifier_free_guidance:
983
+ # Infer ControlNet only for the conditional batch.
984
+ control_model_input = latents
985
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
986
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
987
+ else:
988
+ control_model_input = latent_model_input
989
+ controlnet_prompt_embeds = prompt_embeds
990
+
991
+ if isinstance(controlnet_keep[i], list):
992
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
993
+ else:
994
+ controlnet_cond_scale = controlnet_conditioning_scale
995
+ if isinstance(controlnet_cond_scale, list):
996
+ controlnet_cond_scale = controlnet_cond_scale[0]
997
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
998
+
999
+ if self.do_classifier_free_guidance:
1000
+ style_embeddings_input = torch.cat([negative_style_embeddings, style_embeddings])
1001
+
1002
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
1003
+ control_model_input,
1004
+ t,
1005
+ encoder_hidden_states=style_embeddings_input,
1006
+ controlnet_cond=image,
1007
+ conditioning_scale=cond_scale,
1008
+ guess_mode=guess_mode,
1009
+ return_dict=False,
1010
+ )
1011
+
1012
+ if guess_mode and self.do_classifier_free_guidance:
1013
+ # Infered ControlNet only for the conditional batch.
1014
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
1015
+ # add 0 to the unconditional batch to keep it unchanged.
1016
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
1017
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
1018
+
1019
+ # predict the noise residual
1020
+ noise_pred = self.unet(
1021
+ latent_model_input,
1022
+ t,
1023
+ encoder_hidden_states=prompt_embeds,
1024
+ timestep_cond=timestep_cond,
1025
+ cross_attention_kwargs=self.cross_attention_kwargs,
1026
+ down_block_additional_residuals=down_block_res_samples,
1027
+ mid_block_additional_residual=mid_block_res_sample,
1028
+ added_cond_kwargs=added_cond_kwargs,
1029
+ return_dict=False,
1030
+ )[0]
1031
+
1032
+ # perform guidance
1033
+ if self.do_classifier_free_guidance:
1034
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1035
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1036
+
1037
+ # compute the previous noisy sample x_t -> x_t-1
1038
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1039
+
1040
+ if callback_on_step_end is not None:
1041
+ callback_kwargs = {}
1042
+ for k in callback_on_step_end_tensor_inputs:
1043
+ callback_kwargs[k] = locals()[k]
1044
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1045
+
1046
+ latents = callback_outputs.pop("latents", latents)
1047
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1048
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1049
+
1050
+ # call the callback, if provided
1051
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1052
+ progress_bar.update()
1053
+ if callback is not None and i % callback_steps == 0:
1054
+ step_idx = i // getattr(self.scheduler, "order", 1)
1055
+ callback(step_idx, t, latents)
1056
+
1057
+ # If we do sequential model offloading, let's offload unet and controlnet
1058
+ # manually for max memory savings
1059
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1060
+ self.unet.to("cpu")
1061
+ self.controlnet.to("cpu")
1062
+ torch.cuda.empty_cache()
1063
+
1064
+ if not output_type == "latent":
1065
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
1066
+ 0
1067
+ ]
1068
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1069
+ else:
1070
+ image = latents
1071
+ has_nsfw_concept = None
1072
+
1073
+ if has_nsfw_concept is None:
1074
+ do_denormalize = [True] * image.shape[0]
1075
+ else:
1076
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1077
+
1078
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1079
+
1080
+ # Offload all models
1081
+ self.maybe_free_model_hooks()
1082
+
1083
+ if not return_dict:
1084
+ return (image, has_nsfw_concept)
1085
+
1086
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
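For reference, the `controlnet_keep` bookkeeping computed above maps each ControlNet's `control_guidance_start`/`control_guidance_end` window onto a per-step multiplier that later scales `controlnet_conditioning_scale`. The standalone sketch below uses made-up step counts and windows (it is not part of the commit) to show the resulting schedule.

# Hypothetical illustration of the controlnet_keep schedule: a ControlNet is "kept"
# (multiplier 1.0) only for steps whose relative position lies inside its
# [control_guidance_start, control_guidance_end] window.
num_steps = 10
control_guidance_start = [0.0, 0.5]  # two ControlNets, as in a MultiControlNetModel
control_guidance_end = [0.8, 1.0]

controlnet_keep = []
for i in range(num_steps):
    keeps = [
        1.0 - float(i / num_steps < s or (i + 1) / num_steps > e)
        for s, e in zip(control_guidance_start, control_guidance_end)
    ]
    controlnet_keep.append(keeps)

print(controlnet_keep[0])   # [1.0, 0.0] -> only the first ControlNet is active at step 0
print(controlnet_keep[-1])  # [0.0, 1.0] -> only the second ControlNet is active at the last step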
ip_adapter/resampler.py ADDED
@@ -0,0 +1,158 @@
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+ # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
+
+ import math
+
+ import torch
+ import torch.nn as nn
+ from einops import rearrange
+ from einops.layers.torch import Rearrange
+
+
+ # FFN
+ def FeedForward(dim, mult=4):
+     inner_dim = int(dim * mult)
+     return nn.Sequential(
+         nn.LayerNorm(dim),
+         nn.Linear(dim, inner_dim, bias=False),
+         nn.GELU(),
+         nn.Linear(inner_dim, dim, bias=False),
+     )
+
+
+ def reshape_tensor(x, heads):
+     bs, length, width = x.shape
+     # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
+     x = x.view(bs, length, heads, -1)
+     # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+     x = x.transpose(1, 2)
+     # keep the 4-D (bs, n_heads, length, dim_per_head) layout, made contiguous for the attention matmul
+     x = x.reshape(bs, heads, length, -1)
+     return x
+
+
+ class PerceiverAttention(nn.Module):
+     def __init__(self, *, dim, dim_head=64, heads=8):
+         super().__init__()
+         self.scale = dim_head**-0.5
+         self.dim_head = dim_head
+         self.heads = heads
+         inner_dim = dim_head * heads
+
+         self.norm1 = nn.LayerNorm(dim)
+         self.norm2 = nn.LayerNorm(dim)
+
+         self.to_q = nn.Linear(dim, inner_dim, bias=False)
+         self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+         self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+     def forward(self, x, latents):
+         """
+         Args:
+             x (torch.Tensor): image features, shape (b, n1, D)
+             latents (torch.Tensor): latent features, shape (b, n2, D)
+         """
+         x = self.norm1(x)
+         latents = self.norm2(latents)
+
+         b, l, _ = latents.shape
+
+         q = self.to_q(latents)
+         kv_input = torch.cat((x, latents), dim=-2)
+         k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+         q = reshape_tensor(q, self.heads)
+         k = reshape_tensor(k, self.heads)
+         v = reshape_tensor(v, self.heads)
+
+         # attention
+         scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+         weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
+         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+         out = weight @ v
+
+         out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+
+         return self.to_out(out)
+
+
+ class Resampler(nn.Module):
+     def __init__(
+         self,
+         dim=1024,
+         depth=8,
+         dim_head=64,
+         heads=16,
+         num_queries=8,
+         embedding_dim=768,
+         output_dim=1024,
+         ff_mult=4,
+         max_seq_len: int = 257,  # CLIP tokens + CLS token
+         apply_pos_emb: bool = False,
+         num_latents_mean_pooled: int = 0,  # number of latents derived from the mean-pooled representation of the sequence
+     ):
+         super().__init__()
+         self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
+
+         self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
+
+         self.proj_in = nn.Linear(embedding_dim, dim)
+
+         self.proj_out = nn.Linear(dim, output_dim)
+         self.norm_out = nn.LayerNorm(output_dim)
+
+         self.to_latents_from_mean_pooled_seq = (
+             nn.Sequential(
+                 nn.LayerNorm(dim),
+                 nn.Linear(dim, dim * num_latents_mean_pooled),
+                 Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
+             )
+             if num_latents_mean_pooled > 0
+             else None
+         )
+
+         self.layers = nn.ModuleList([])
+         for _ in range(depth):
+             self.layers.append(
+                 nn.ModuleList(
+                     [
+                         PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+                         FeedForward(dim=dim, mult=ff_mult),
+                     ]
+                 )
+             )
+
+     def forward(self, x):
+         if self.pos_emb is not None:
+             n, device = x.shape[1], x.device
+             pos_emb = self.pos_emb(torch.arange(n, device=device))
+             x = x + pos_emb
+
+         latents = self.latents.repeat(x.size(0), 1, 1)
+
+         x = self.proj_in(x)
+
+         if self.to_latents_from_mean_pooled_seq:
+             meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
+             meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
+             latents = torch.cat((meanpooled_latents, latents), dim=-2)
+
+         for attn, ff in self.layers:
+             latents = attn(x, latents) + latents
+             latents = ff(latents) + latents
+
+         latents = self.proj_out(latents)
+         return self.norm_out(latents)
+
+
+ def masked_mean(t, *, dim, mask=None):
+     if mask is None:
+         return t.mean(dim=dim)
+
+     denom = mask.sum(dim=dim, keepdim=True)
+     mask = rearrange(mask, "b n -> b n 1")
+     masked_t = t.masked_fill(~mask, 0.0)
+
+     return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
ip_adapter/style_encoder.py ADDED
@@ -0,0 +1,246 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from collections import OrderedDict
+ from typing import List, Tuple
+
+
+ def conv_nd(dims, *args, **kwargs):
+     """
+     Create a 1D, 2D, or 3D convolution module.
+     """
+     if dims == 1:
+         return nn.Conv1d(*args, **kwargs)
+     elif dims == 2:
+         return nn.Conv2d(*args, **kwargs)
+     elif dims == 3:
+         return nn.Conv3d(*args, **kwargs)
+     raise ValueError(f"unsupported dimensions: {dims}")
+
+
+ def avg_pool_nd(dims, *args, **kwargs):
+     """
+     Create a 1D, 2D, or 3D average pooling module.
+     """
+     if dims == 1:
+         return nn.AvgPool1d(*args, **kwargs)
+     elif dims == 2:
+         return nn.AvgPool2d(*args, **kwargs)
+     elif dims == 3:
+         return nn.AvgPool3d(*args, **kwargs)
+     raise ValueError(f"unsupported dimensions: {dims}")
+
+
+ def get_parameter_dtype(parameter: torch.nn.Module):
+     try:
+         params = tuple(parameter.parameters())
+         if len(params) > 0:
+             return params[0].dtype
+
+         buffers = tuple(parameter.buffers())
+         if len(buffers) > 0:
+             return buffers[0].dtype
+
+     except StopIteration:
+         # For torch.nn.DataParallel compatibility in PyTorch 1.5
+
+         def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, torch.Tensor]]:
+             tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+             return tuples
+
+         gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+         first_tuple = next(gen)
+         return first_tuple[1].dtype
+
+
+ class Downsample(nn.Module):
+     """
+     A downsampling layer with an optional convolution.
+
+     :param channels: channels in the inputs and outputs.
+     :param use_conv: a bool determining if a convolution is applied.
+     :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+         downsampling occurs in the inner-two dimensions.
+     """
+
+     def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_conv = use_conv
+         self.dims = dims
+         stride = 2 if dims != 3 else (1, 2, 2)
+         if use_conv:
+             self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
+         else:
+             assert self.channels == self.out_channels
+             self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+     def forward(self, x):
+         assert x.shape[1] == self.channels
+         return self.op(x)
+
+
+ class ResnetBlock(nn.Module):
+
+     def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
+         super().__init__()
+         ps = ksize // 2
+         if in_c != out_c or sk == False:
+             self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
+         else:
+             self.in_conv = None
+         self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
+         self.act = nn.ReLU()
+         self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
+         if sk == False:
+             self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
+         else:
+             self.skep = None
+
+         self.down = down
+         if self.down == True:
+             self.down_opt = Downsample(in_c, use_conv=use_conv)
+
+     def forward(self, x):
+         if self.down == True:
+             x = self.down_opt(x)
+         if self.in_conv is not None:  # edit
+             x = self.in_conv(x)
+
+         h = self.block1(x)
+         h = self.act(h)
+         h = self.block2(h)
+         if self.skep is not None:
+             return h + self.skep(x)
+         else:
+             return h + x
+
+
+ class Low_CNN(nn.Module):
+     def __init__(self, cin=192, ksize=1, sk=False, use_conv=True):
+         super(Low_CNN, self).__init__()
+         self.unshuffle = nn.PixelUnshuffle(8)
+         self.body = nn.Sequential(
+             ResnetBlock(320, 320, down=False, ksize=ksize, sk=sk, use_conv=use_conv),
+             ResnetBlock(320, 640, down=False, ksize=ksize, sk=sk, use_conv=use_conv),
+             ResnetBlock(640, 1280, down=True, ksize=ksize, sk=sk, use_conv=use_conv),
+             ResnetBlock(1280, 1280, down=False, ksize=ksize, sk=sk, use_conv=use_conv),
+         )
+         self.conv_in = nn.Conv2d(cin, 320, 3, 1, 1)
+         self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
+         self.adapter = nn.Linear(1280, 1280)
+
+     @property
+     def dtype(self) -> torch.dtype:
+         """
+         `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+         """
+         return get_parameter_dtype(self)
+
+     def forward(self, x):
+         x = self.unshuffle(x)
+         x = self.conv_in(x)
+         x = self.body(x)
+         x = self.pool(x)
+         x = x.flatten(start_dim=1, end_dim=-1)
+         x = self.adapter(x)
+         return x
+
+
+ class Middle_CNN(nn.Module):
+     def __init__(self, cin=192, ksize=1, sk=False, use_conv=True):
+         super(Middle_CNN, self).__init__()
+         self.unshuffle = nn.PixelUnshuffle(8)
+         self.body = nn.Sequential(
+             ResnetBlock(320, 320, down=False, ksize=ksize, sk=sk, use_conv=use_conv),
+             ResnetBlock(320, 640, down=False, ksize=ksize, sk=sk, use_conv=use_conv),
+             ResnetBlock(640, 640, down=True, ksize=ksize, sk=sk, use_conv=use_conv),
+             ResnetBlock(640, 1280, down=True, ksize=ksize, sk=sk, use_conv=use_conv),
+             ResnetBlock(1280, 1280, down=False, ksize=ksize, sk=sk, use_conv=use_conv),
+         )
+         self.conv_in = nn.Conv2d(cin, 320, 3, 1, 1)
+         self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
+         self.adapter = nn.Linear(1280, 1280)
+
+     @property
+     def dtype(self) -> torch.dtype:
+         """
+         `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+         """
+         return get_parameter_dtype(self)
+
+     def forward(self, x):
+         x = self.unshuffle(x)
+         x = self.conv_in(x)
+         x = self.body(x)
+         x = self.pool(x)
+         x = x.flatten(start_dim=1, end_dim=-1)
+         x = self.adapter(x)
+         return x
+
+
+ class High_CNN(nn.Module):
+     def __init__(self, cin=192, ksize=1, sk=False, use_conv=True):
+         super(High_CNN, self).__init__()
+         self.unshuffle = nn.PixelUnshuffle(8)
+         self.body = nn.Sequential(
+             ResnetBlock(320, 320, down=False, ksize=ksize, sk=sk, use_conv=use_conv),
+             ResnetBlock(320, 640, down=False, ksize=ksize, sk=sk, use_conv=use_conv),
+             ResnetBlock(640, 640, down=True, ksize=ksize, sk=sk, use_conv=use_conv),
+             ResnetBlock(640, 640, down=True, ksize=ksize, sk=sk, use_conv=use_conv),
+             ResnetBlock(640, 1280, down=True, ksize=ksize, sk=sk, use_conv=use_conv),
+             ResnetBlock(1280, 1280, down=False, ksize=ksize, sk=sk, use_conv=use_conv),
+         )
+         self.conv_in = nn.Conv2d(cin, 320, 3, 1, 1)
+         self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
+         self.adapter = nn.Linear(1280, 1280)
+
+     @property
+     def dtype(self) -> torch.dtype:
+         """
+         `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+         """
+         return get_parameter_dtype(self)
+
+     def forward(self, x):
+         x = self.unshuffle(x)
+         x = self.conv_in(x)
+         x = self.body(x)
+         x = self.pool(x)
+         x = x.flatten(start_dim=1, end_dim=-1)
+         x = self.adapter(x)
+         return x
+
+
+ class Style_Aware_Encoder(torch.nn.Module):
+     def __init__(self, image_encoder):
+         super().__init__()
+         self.image_encoder = image_encoder
+         self.projection_dim = self.image_encoder.config.projection_dim
+         self.num_positions = 59
+         self.embed_dim = 1280
+         self.cnn = nn.ModuleList(
+             [High_CNN(sk=True, use_conv=False),
+              Middle_CNN(sk=True, use_conv=False),
+              Low_CNN(sk=True, use_conv=False)]
+         )
+         self.style_embeddings = nn.ParameterList(
+             [nn.Parameter(torch.randn(self.embed_dim)),
+              nn.Parameter(torch.randn(self.embed_dim)),
+              nn.Parameter(torch.randn(self.embed_dim))]
+         )
+         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+         self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
+
+     def forward(self, inputs, batch_size=1):
+         embeddings = []
+         for idx, x in enumerate(inputs):
+             class_embed = self.style_embeddings[idx].expand(batch_size, 1, -1)
+             patch_embed = self.cnn[idx](x)
+             patch_embed = patch_embed.view(batch_size, -1, patch_embed.shape[1])
+             embedding = torch.cat([class_embed, patch_embed], dim=1)
+             embeddings.append(embedding)
+         embeddings = torch.cat(embeddings, dim=1)
+         embeddings = embeddings + self.position_embedding(self.position_ids)  # [B, num_positions, 1280] + [1, num_positions, 1280]
+         # note: "pre_layrnorm" is the attribute name used by transformers' CLIPVisionTransformer
+         embeddings = self.image_encoder.vision_model.pre_layrnorm(embeddings)
+         encoder_outputs = self.image_encoder.vision_model.encoder(
+             inputs_embeds=embeddings,
+             output_attentions=None,
+             output_hidden_states=None,
+             return_dict=None,
+         )
+         last_hidden_state = encoder_outputs[0]
+         # keep only the three style class tokens (high / middle / low level)
+         pooled_output = last_hidden_state[:, [0, 9, 26], :]
+         pooled_output = self.image_encoder.vision_model.post_layernorm(pooled_output)
+         out = self.image_encoder.visual_projection(pooled_output)
+         return out
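For orientation: the 59 positions used by `Style_Aware_Encoder` appear to decompose into three learned style class tokens plus 8 high-level, 16 middle-level and 32 low-level patch embeddings (the patch counts produced by `pre_processing` in `ip_adapter/tools.py` below), which would also explain the pooled indices `[0, 9, 26]`. A small sanity-check sketch of that layout, derived from the code rather than stated in the commit:

# Pure arithmetic, no model needed.
num_patches = {"high": 8, "middle": 16, "low": 32}        # per-level patch counts from pre_processing
num_positions = sum(1 + n for n in num_patches.values())  # one style class token per level
assert num_positions == 59                                # matches Style_Aware_Encoder.num_positions

# Position of each class token in the concatenated (high, middle, low) sequence:
class_token_indices = [0, 1 + 8, 1 + 8 + 1 + 16]
assert class_token_indices == [0, 9, 26]                  # matches last_hidden_state[:, [0, 9, 26], :]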
ip_adapter/test_resampler.py ADDED
@@ -0,0 +1,44 @@
+ import torch
+ from resampler import Resampler
+ from transformers import CLIPVisionModel
+
+ BATCH_SIZE = 2
+ OUTPUT_DIM = 1280
+ NUM_QUERIES = 8
+ NUM_LATENTS_MEAN_POOLED = 4  # 0 for no mean pooling (previous behavior)
+ APPLY_POS_EMB = True  # False for no positional embeddings (previous behavior)
+ IMAGE_ENCODER_NAME_OR_PATH = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
+
+
+ def main():
+     image_encoder = CLIPVisionModel.from_pretrained(IMAGE_ENCODER_NAME_OR_PATH)
+     embedding_dim = image_encoder.config.hidden_size
+     print("image_encoder hidden size:", embedding_dim)
+
+     image_proj_model = Resampler(
+         dim=1024,
+         depth=2,
+         dim_head=64,
+         heads=16,
+         num_queries=NUM_QUERIES,
+         embedding_dim=embedding_dim,
+         output_dim=OUTPUT_DIM,
+         ff_mult=2,
+         max_seq_len=257,
+         apply_pos_emb=APPLY_POS_EMB,
+         num_latents_mean_pooled=NUM_LATENTS_MEAN_POOLED,
+     )
+
+     dummy_images = torch.randn(BATCH_SIZE, 3, 224, 224)
+     with torch.no_grad():
+         image_embeds = image_encoder(dummy_images, output_hidden_states=True).hidden_states[-2]
+     print("image_embeds shape:", image_embeds.shape)
+
+     with torch.no_grad():
+         ip_tokens = image_proj_model(image_embeds)
+     print("ip_tokens shape:", ip_tokens.shape)
+     assert ip_tokens.shape == (BATCH_SIZE, NUM_QUERIES + NUM_LATENTS_MEAN_POOLED, OUTPUT_DIM)
+
+
+ if __name__ == "__main__":
+     main()
ip_adapter/tools.py ADDED
@@ -0,0 +1,31 @@
+ import random
+ import cv2
+ import numpy as np
+ import torch
+ from PIL import Image
+
+
+ def crop_4_patches(image):
+     # Split a square PIL image into its four quadrants:
+     # top-left, bottom-left, top-right, bottom-right.
+     crop_size = int(image.size[0] / 2)
+     return (image.crop((0, 0, crop_size, crop_size)), image.crop((0, crop_size, crop_size, 2 * crop_size)),
+             image.crop((crop_size, 0, 2 * crop_size, crop_size)), image.crop((crop_size, crop_size, 2 * crop_size, 2 * crop_size)))
+
+
+ def pre_processing(image, transform):
+     # Recursively split the image into 8 high-level, 16 middle-level and 32 low-level
+     # patches (56 patches in total) and apply `transform` to each of them.
+     high_level = []
+     middle_level = []
+     low_level = []
+     crops_4 = crop_4_patches(image)
+     for c_4 in crops_4:
+         crops_8 = crop_4_patches(c_4)
+         high_level.append(transform(crops_8[0]))
+         high_level.append(transform(crops_8[3]))
+         for c_8 in [crops_8[1], crops_8[2]]:
+             crops_16 = crop_4_patches(c_8)
+             middle_level.append(transform(crops_16[0]))
+             middle_level.append(transform(crops_16[3]))
+             for c_16 in [crops_16[1], crops_16[2]]:
+                 crops_32 = crop_4_patches(c_16)
+                 low_level.append(transform(crops_32[0]))
+                 low_level.append(transform(crops_32[3]))
+     return torch.stack(high_level), torch.stack(middle_level), torch.stack(low_level)
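A minimal usage sketch tying `pre_processing` to `Style_Aware_Encoder` (not part of the commit). The 512x512 reference size is inferred from the stride arithmetic of the three CNNs, which needs 128/64/32-pixel patches to reach a 1x1 map before the final `nn.Linear(1280, 1280)`; the transform, the file name and the CLIP checkpoint are assumptions.

import torch
from PIL import Image
from torchvision import transforms
from transformers import CLIPVisionModelWithProjection

from ip_adapter.style_encoder import Style_Aware_Encoder
from ip_adapter.tools import pre_processing

# Assumed preprocessing; the repository may use different normalization statistics.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Assumed checkpoint (ViT-H/14 has hidden size 1280, matching embed_dim in style_encoder.py).
image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
style_encoder = Style_Aware_Encoder(image_encoder).eval()

style_image = Image.open("style.png").convert("RGB").resize((512, 512))  # hypothetical input file
high, middle, low = pre_processing(style_image, transform)  # (8,3,128,128), (16,3,64,64), (32,3,32,32)

with torch.no_grad():
    style_embeds = style_encoder((high, middle, low), batch_size=1)
print(style_embeds.shape)  # (1, 3, projection_dim)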
ip_adapter/utils.py ADDED
@@ -0,0 +1,5 @@
+ import torch.nn.functional as F
+
+
+ def is_torch2_available():
+     return hasattr(F, "scaled_dot_product_attention")
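`is_torch2_available` mirrors the check used upstream in IP-Adapter to decide whether the fused PyTorch 2.x attention kernel can be used. A self-contained sketch of that kind of dispatch (the `attention` helper below is illustrative only, not part of this repo):

import torch
import torch.nn.functional as F

from ip_adapter.utils import is_torch2_available


def attention(q, k, v):
    # Use the fused PyTorch 2.x kernel when present, otherwise fall back to
    # an explicit softmax(Q K^T / sqrt(d)) V.
    if is_torch2_available():
        return F.scaled_dot_product_attention(q, k, v)
    scale = q.shape[-1] ** -0.5
    weights = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    return weights @ v


q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, tokens, head_dim)
print(attention(q, k, v).shape)        # torch.Size([1, 8, 16, 64])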