abc
committed on
Commit
·
0380cfd
1
Parent(s):
72f0b64
Upload 7 files
Browse files
- lycoris/kohya.py +37 -1
- lycoris/locon.py +8 -19
- lycoris/utils.py +72 -4
lycoris/kohya.py
CHANGED
@@ -70,6 +70,12 @@ class LycorisNetwork(torch.nn.Module):
|
|
70 |
"Downsample2D",
|
71 |
"Upsample2D"
|
72 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
74 |
LORA_PREFIX_UNET = 'lora_unet'
|
75 |
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
@@ -102,7 +108,12 @@ class LycorisNetwork(torch.nn.Module):
|
|
102 |
self.dropout = dropout
|
103 |
|
104 |
# create module instances
|
105 |
-
def create_modules(
|
|
|
|
|
|
|
|
|
|
|
106 |
print('Create LyCORIS Module')
|
107 |
loras = []
|
108 |
for name, module in root_module.named_modules():
|
@@ -132,6 +143,31 @@ class LycorisNetwork(torch.nn.Module):
|
|
132 |
else:
|
133 |
continue
|
134 |
loras.append(lora)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
return loras
|
136 |
|
137 |
self.text_encoder_loras = create_modules(
|
|
|
70 |
"Downsample2D",
|
71 |
"Upsample2D"
|
72 |
]
|
73 |
+
UNET_TARGET_REPLACE_NAME = [
|
74 |
+
"conv_in",
|
75 |
+
"conv_out",
|
76 |
+
"time_embedding.linear_1",
|
77 |
+
"time_embedding.linear_2",
|
78 |
+
]
|
79 |
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
80 |
LORA_PREFIX_UNET = 'lora_unet'
|
81 |
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
|
|
108 |
self.dropout = dropout
|
109 |
|
110 |
# create module instances
|
111 |
+
def create_modules(
|
112 |
+
prefix,
|
113 |
+
root_module: torch.nn.Module,
|
114 |
+
target_replace_modules,
|
115 |
+
target_replace_names = []
|
116 |
+
) -> List[network_module]:
|
117 |
print('Create LyCORIS Module')
|
118 |
loras = []
|
119 |
for name, module in root_module.named_modules():
|
|
|
143 |
else:
|
144 |
continue
|
145 |
loras.append(lora)
|
146 |
+
elif name in target_replace_names:
|
147 |
+
lora_name = prefix + '.' + name
|
148 |
+
lora_name = lora_name.replace('.', '_')
|
149 |
+
if module.__class__.__name__ == 'Linear' and lora_dim>0:
|
150 |
+
lora = network_module(
|
151 |
+
lora_name, module, self.multiplier,
|
152 |
+
self.lora_dim, self.alpha, self.dropout, use_cp
|
153 |
+
)
|
154 |
+
elif module.__class__.__name__ == 'Conv2d':
|
155 |
+
k_size, *_ = module.kernel_size
|
156 |
+
if k_size==1 and lora_dim>0:
|
157 |
+
lora = network_module(
|
158 |
+
lora_name, module, self.multiplier,
|
159 |
+
self.lora_dim, self.alpha, self.dropout, use_cp
|
160 |
+
)
|
161 |
+
elif conv_lora_dim>0:
|
162 |
+
lora = network_module(
|
163 |
+
lora_name, module, self.multiplier,
|
164 |
+
self.conv_lora_dim, self.conv_alpha, self.dropout, use_cp
|
165 |
+
)
|
166 |
+
else:
|
167 |
+
continue
|
168 |
+
else:
|
169 |
+
continue
|
170 |
+
loras.append(lora)
|
171 |
return loras
|
172 |
|
173 |
self.text_encoder_loras = create_modules(
|
lycoris/locon.py
CHANGED
@@ -38,18 +38,11 @@ class LoConModule(nn.Module):
|
|
38 |
else:
|
39 |
self.lora_down = nn.Conv2d(in_dim, lora_dim, k_size, stride, padding, bias=False)
|
40 |
self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
|
41 |
-
self.op = F.conv2d
|
42 |
-
self.extra_args = {
|
43 |
-
'stride': stride,
|
44 |
-
'padding': padding
|
45 |
-
}
|
46 |
else:
|
47 |
in_dim = org_module.in_features
|
48 |
out_dim = org_module.out_features
|
49 |
self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
|
50 |
self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
|
51 |
-
self.op = F.linear
|
52 |
-
self.extra_args = {}
|
53 |
self.shape = org_module.weight.shape
|
54 |
|
55 |
if dropout:
|
@@ -66,6 +59,8 @@ class LoConModule(nn.Module):
|
|
66 |
# same as microsoft's
|
67 |
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
68 |
torch.nn.init.zeros_(self.lora_up.weight)
|
|
|
|
|
69 |
|
70 |
self.multiplier = multiplier
|
71 |
self.org_module = [org_module]
|
@@ -81,16 +76,10 @@ class LoConModule(nn.Module):
|
|
81 |
|
82 |
def forward(self, x):
|
83 |
if self.cp:
|
84 |
-
return self.dropout(
|
85 |
-
self.
|
86 |
-
|
87 |
-
) * self.multiplier * self.scale
|
88 |
else:
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
(self.org_module[0].weight.data
|
93 |
-
+ self.dropout(self.make_weight()) * self.multiplier * self.scale),
|
94 |
-
bias,
|
95 |
-
**self.extra_args,
|
96 |
-
)
|
|
|
38 |
else:
|
39 |
self.lora_down = nn.Conv2d(in_dim, lora_dim, k_size, stride, padding, bias=False)
|
40 |
self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
|
|
|
|
|
|
|
|
|
|
|
41 |
else:
|
42 |
in_dim = org_module.in_features
|
43 |
out_dim = org_module.out_features
|
44 |
self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
|
45 |
self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
|
|
|
|
|
46 |
self.shape = org_module.weight.shape
|
47 |
|
48 |
if dropout:
|
|
|
59 |
# same as microsoft's
|
60 |
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
61 |
torch.nn.init.zeros_(self.lora_up.weight)
|
62 |
+
if self.cp:
|
63 |
+
torch.nn.init.kaiming_uniform_(self.lora_mid.weight, a=math.sqrt(5))
|
64 |
|
65 |
self.multiplier = multiplier
|
66 |
self.org_module = [org_module]
|
|
|
76 |
|
77 |
def forward(self, x):
|
78 |
if self.cp:
|
79 |
+
return self.org_forward(x) + self.dropout(
|
80 |
+
self.lora_up(self.lora_mid(self.lora_down(x)))* self.multiplier * self.scale
|
81 |
+
)
|
|
|
82 |
else:
|
83 |
+
return self.org_forward(x) + self.dropout(
|
84 |
+
self.lora_up(self.lora_down(x))* self.multiplier * self.scale
|
85 |
+
)
|
|
|
|
|
|
|
|
|
|
lycoris/utils.py
CHANGED
@@ -164,6 +164,12 @@ def extract_diff(
|
|
164 |
"Downsample2D",
|
165 |
"Upsample2D"
|
166 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
168 |
LORA_PREFIX_UNET = 'lora_unet'
|
169 |
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
@@ -171,10 +177,12 @@ def extract_diff(
|
|
171 |
prefix,
|
172 |
root_module: torch.nn.Module,
|
173 |
target_module: torch.nn.Module,
|
174 |
-
target_replace_modules
|
|
|
175 |
):
|
176 |
loras = {}
|
177 |
temp = {}
|
|
|
178 |
|
179 |
for name, module in root_module.named_modules():
|
180 |
if module.__class__.__name__ in target_replace_modules:
|
@@ -183,6 +191,8 @@ def extract_diff(
|
|
183 |
if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
|
184 |
continue
|
185 |
temp[name][child_name] = child_module.weight
|
|
|
|
|
186 |
|
187 |
for name, module in tqdm(list(target_module.named_modules())):
|
188 |
if name in temp:
|
@@ -221,7 +231,7 @@ def extract_diff(
|
|
221 |
diff = child_module.weight - torch.einsum(
|
222 |
'i j k l, j r, p i -> p r k l',
|
223 |
extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1)
|
224 |
-
)
|
225 |
del extract_c
|
226 |
else:
|
227 |
continue
|
@@ -231,7 +241,7 @@ def extract_diff(
|
|
231 |
|
232 |
if use_bias:
|
233 |
diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
|
234 |
-
sparse_diff = make_sparse(diff, sparsity).to_sparse()
|
235 |
|
236 |
indices = sparse_diff.indices().to(torch.int16)
|
237 |
values = sparse_diff.values().half()
|
@@ -239,6 +249,63 @@ def extract_diff(
|
|
239 |
loras[f'{lora_name}.bias_values'] = values
|
240 |
loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16)
|
241 |
del extract_a, extract_b, diff
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
242 |
return loras
|
243 |
|
244 |
text_encoder_loras = make_state_dict(
|
@@ -250,7 +317,8 @@ def extract_diff(
|
|
250 |
unet_loras = make_state_dict(
|
251 |
LORA_PREFIX_UNET,
|
252 |
base_model[2], db_model[2],
|
253 |
-
UNET_TARGET_REPLACE_MODULE
|
|
|
254 |
)
|
255 |
print(len(text_encoder_loras), len(unet_loras))
|
256 |
return text_encoder_loras|unet_loras
|
|
|
164 |
"Downsample2D",
|
165 |
"Upsample2D"
|
166 |
]
|
167 |
+
UNET_TARGET_REPLACE_NAME = [
|
168 |
+
"conv_in",
|
169 |
+
"conv_out",
|
170 |
+
"time_embedding.linear_1",
|
171 |
+
"time_embedding.linear_2",
|
172 |
+
]
|
173 |
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
174 |
LORA_PREFIX_UNET = 'lora_unet'
|
175 |
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
|
|
177 |
prefix,
|
178 |
root_module: torch.nn.Module,
|
179 |
target_module: torch.nn.Module,
|
180 |
+
target_replace_modules,
|
181 |
+
target_replace_names = []
|
182 |
):
|
183 |
loras = {}
|
184 |
temp = {}
|
185 |
+
temp_name = {}
|
186 |
|
187 |
for name, module in root_module.named_modules():
|
188 |
if module.__class__.__name__ in target_replace_modules:
|
|
|
191 |
if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
|
192 |
continue
|
193 |
temp[name][child_name] = child_module.weight
|
194 |
+
elif name in target_replace_names:
|
195 |
+
temp_name[name] = module.weight
|
196 |
|
197 |
for name, module in tqdm(list(target_module.named_modules())):
|
198 |
if name in temp:
|
|
|
231 |
diff = child_module.weight - torch.einsum(
|
232 |
'i j k l, j r, p i -> p r k l',
|
233 |
extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1)
|
234 |
+
).detach().cpu().contiguous()
|
235 |
del extract_c
|
236 |
else:
|
237 |
continue
|
|
|
241 |
|
242 |
if use_bias:
|
243 |
diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
|
244 |
+
sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()
|
245 |
|
246 |
indices = sparse_diff.indices().to(torch.int16)
|
247 |
values = sparse_diff.values().half()
|
|
|
249 |
loras[f'{lora_name}.bias_values'] = values
|
250 |
loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16)
|
251 |
del extract_a, extract_b, diff
|
252 |
+
elif name in temp_name:
|
253 |
+
weight = temp_name[name]
|
254 |
+
lora_name = prefix + '.' + name
|
255 |
+
lora_name = lora_name.replace('.', '_')
|
256 |
+
|
257 |
+
if weight.size(0)<32 or weight.size(1)<32:
|
258 |
+
loras[f'{lora_name}.diff'] = module.weight - weight
|
259 |
+
continue
|
260 |
+
|
261 |
+
layer = module.__class__.__name__
|
262 |
+
if layer == 'Linear':
|
263 |
+
extract_a, extract_b, diff = extract_linear(
|
264 |
+
(module.weight - weight),
|
265 |
+
mode,
|
266 |
+
linear_mode_param,
|
267 |
+
device = extract_device,
|
268 |
+
)
|
269 |
+
elif layer == 'Conv2d':
|
270 |
+
is_linear = (module.weight.shape[2] == 1
|
271 |
+
and module.weight.shape[3] == 1)
|
272 |
+
extract_a, extract_b, diff = extract_conv(
|
273 |
+
(module.weight - weight),
|
274 |
+
mode,
|
275 |
+
linear_mode_param if is_linear else conv_mode_param,
|
276 |
+
device = extract_device,
|
277 |
+
)
|
278 |
+
if small_conv and not is_linear:
|
279 |
+
dim = extract_a.size(0)
|
280 |
+
extract_c, extract_a, _ = extract_conv(
|
281 |
+
extract_a.transpose(0, 1),
|
282 |
+
'fixed', dim,
|
283 |
+
extract_device
|
284 |
+
)
|
285 |
+
extract_a = extract_a.transpose(0, 1)
|
286 |
+
extract_c = extract_c.transpose(0, 1)
|
287 |
+
loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half()
|
288 |
+
diff = module.weight - torch.einsum(
|
289 |
+
'i j k l, j r, p i -> p r k l',
|
290 |
+
extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1)
|
291 |
+
).detach().cpu().contiguous()
|
292 |
+
del extract_c
|
293 |
+
else:
|
294 |
+
continue
|
295 |
+
loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half()
|
296 |
+
loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half()
|
297 |
+
loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half()
|
298 |
+
|
299 |
+
if use_bias:
|
300 |
+
diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
|
301 |
+
sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()
|
302 |
+
|
303 |
+
indices = sparse_diff.indices().to(torch.int16)
|
304 |
+
values = sparse_diff.values().half()
|
305 |
+
loras[f'{lora_name}.bias_indices'] = indices
|
306 |
+
loras[f'{lora_name}.bias_values'] = values
|
307 |
+
loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16)
|
308 |
+
del extract_a, extract_b, diff
|
309 |
return loras
|
310 |
|
311 |
text_encoder_loras = make_state_dict(
|
|
|
317 |
unet_loras = make_state_dict(
|
318 |
LORA_PREFIX_UNET,
|
319 |
base_model[2], db_model[2],
|
320 |
+
UNET_TARGET_REPLACE_MODULE,
|
321 |
+
UNET_TARGET_REPLACE_NAME
|
322 |
)
|
323 |
print(len(text_encoder_loras), len(unet_loras))
|
324 |
return text_encoder_loras|unet_loras
|