|
def match_lora(lora, to_load):
    """Match adapter tensors in ``lora`` against the target keys in ``to_load``.

    For each ``prefix -> target_key`` pair in ``to_load`` this looks for one of
    several adapter layouts stored under ``prefix`` in ``lora`` and records a
    patch for ``target_key``:

    * ``"fooocus"`` - ``lora`` holds ``target_key`` itself; that tensor is used
      directly and no other layout is checked for this pair.
    * ``"lora"``    - ``.lora_up/.lora_down(/.lora_mid).weight``,
      ``_lora.up/_lora.down.weight`` or ``.lora_linear_layer.up/down.weight``.
    * ``"loha"``    - ``.hada_w1_a/.hada_w1_b/.hada_w2_a/.hada_w2_b``
      (+ optional ``.hada_t1/.hada_t2``).
    * ``"lokr"``    - any of ``.lokr_w1``, ``.lokr_w2``, ``.lokr_w1_a``,
      ``.lokr_w2_a`` (plus optional ``.lokr_w1_b``, ``.lokr_w2_b``, ``.lokr_t2``).
    * ``"glora"``   - ``.a1/.a2/.b1/.b2.weight``.
    * ``"diff"``    - ``.w_norm`` (+ ``.b_norm``), ``.diff`` and ``.diff_b``; the
      bias variants are stored under the matching ``*.bias`` key.

    Layouts are checked in the order above; when several are present for the
    same target key, the later one overwrites the earlier entry.

    Args:
        lora: Mapping of checkpoint key -> tensor. An optional
            ``<prefix>.alpha`` entry must support ``.item()``; its scalar is
            passed along in the lora/loha/lokr/glora payloads.
        to_load: Mapping of adapter key prefix -> target key (normally a
            ``*.weight`` parameter name).

    Returns:
        ``(patch_dict, remaining_dict)``: ``patch_dict`` maps target keys to
        ``(kind, payload)`` tuples; ``remaining_dict`` holds every ``lora``
        entry that was not consumed.

    Raises:
        KeyError: if a layout's marker key is present but a required companion
            key is missing (e.g. ``lora_up`` without ``lora_down``).
    """
    patch_dict = {}
    loaded_keys = set()

    for prefix, target_key in to_load.items():
        # A tensor stored under the target key itself short-circuits all
        # adapter layouts for this pair.
        if target_key in lora:
            patch_dict[target_key] = ("fooocus", lora[target_key])
            loaded_keys.add(target_key)
            continue

        alpha_name = "{}.alpha".format(prefix)
        alpha = None
        if alpha_name in lora:
            alpha = lora[alpha_name].item()
            loaded_keys.add(alpha_name)

        # Order matters: a later layout overwrites an earlier one for the
        # same target key.
        for matcher in (_match_lora_up_down, _match_loha, _match_lokr, _match_glora):
            patch = matcher(lora, prefix, alpha, loaded_keys)
            if patch is not None:
                patch_dict[target_key] = patch

        _match_diff_patches(lora, prefix, target_key, patch_dict, loaded_keys)

    remaining_dict = {k: v for k, v in lora.items() if k not in loaded_keys}
    return patch_dict, remaining_dict


def _bias_key_for(weight_key):
    """Return the ``*.bias`` key paired with ``weight_key``.

    A trailing ``.weight`` is replaced by ``.bias``; keys without that suffix
    get ``.bias`` appended rather than losing their last seven characters.
    """
    suffix = ".weight"
    if weight_key.endswith(suffix):
        weight_key = weight_key[:-len(suffix)]
    return "{}.bias".format(weight_key)


def _match_lora_up_down(lora, prefix, alpha, loaded_keys):
    """Return a ``("lora", (up, down, alpha, mid))`` patch for ``prefix``, or None."""
    # (up, down, mid) key templates; the first layout whose up key exists wins.
    layouts = (
        ("{}.lora_up.weight", "{}.lora_down.weight", "{}.lora_mid.weight"),
        ("{}_lora.up.weight", "{}_lora.down.weight", None),
        ("{}.lora_linear_layer.up.weight", "{}.lora_linear_layer.down.weight", None),
    )
    for up_fmt, down_fmt, mid_fmt in layouts:
        up_name = up_fmt.format(prefix)
        if up_name not in lora:
            continue
        down_name = down_fmt.format(prefix)
        mid = None
        if mid_fmt is not None:
            mid_name = mid_fmt.format(prefix)
            if mid_name in lora:
                mid = lora[mid_name]
                loaded_keys.add(mid_name)
        patch = ("lora", (lora[up_name], lora[down_name], alpha, mid))
        loaded_keys.update((up_name, down_name))
        return patch
    return None


def _match_loha(lora, prefix, alpha, loaded_keys):
    """Return a ``("loha", ...)`` patch for ``prefix``, or None if ``hada_w1_a`` is absent."""
    w1_a_name = "{}.hada_w1_a".format(prefix)
    if w1_a_name not in lora:
        return None
    w1_b_name = "{}.hada_w1_b".format(prefix)
    w2_a_name = "{}.hada_w2_a".format(prefix)
    w2_b_name = "{}.hada_w2_b".format(prefix)
    t1_name = "{}.hada_t1".format(prefix)
    t2_name = "{}.hada_t2".format(prefix)

    t1 = t2 = None
    if t1_name in lora:
        # hada_t2 is read unconditionally once hada_t1 is present.
        t1 = lora[t1_name]
        t2 = lora[t2_name]
        loaded_keys.update((t1_name, t2_name))

    patch = ("loha", (lora[w1_a_name], lora[w1_b_name], alpha,
                      lora[w2_a_name], lora[w2_b_name], t1, t2))
    loaded_keys.update((w1_a_name, w1_b_name, w2_a_name, w2_b_name))
    return patch


def _match_lokr(lora, prefix, alpha, loaded_keys):
    """Return a ``("lokr", ...)`` patch for ``prefix``, or None.

    A patch is produced when any of ``lokr_w1``, ``lokr_w2``, ``lokr_w1_a`` or
    ``lokr_w2_a`` is present. Only then are the found keys marked as loaded,
    so stray factors (e.g. a lone ``lokr_t2``) stay in ``remaining_dict``.
    """
    suffixes = ("lokr_w1", "lokr_w2", "lokr_w1_a", "lokr_w1_b",
                "lokr_w2_a", "lokr_w2_b", "lokr_t2")
    names = ["{}.{}".format(prefix, s) for s in suffixes]
    w1, w2, w1_a, w1_b, w2_a, w2_b, t2 = (lora[n] if n in lora else None for n in names)
    if w1 is None and w2 is None and w1_a is None and w2_a is None:
        return None
    loaded_keys.update(n for n in names if n in lora)
    return ("lokr", (w1, w2, alpha, w1_a, w1_b, w2_a, w2_b, t2))


def _match_glora(lora, prefix, alpha, loaded_keys):
    """Return a ``("glora", (a1, a2, b1, b2, alpha))`` patch for ``prefix``, or None."""
    a1_name, a2_name, b1_name, b2_name = (
        "{}.{}.weight".format(prefix, part) for part in ("a1", "a2", "b1", "b2")
    )
    if a1_name not in lora:
        return None
    patch = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
    loaded_keys.update((a1_name, a2_name, b1_name, b2_name))
    return patch


def _match_diff_patches(lora, prefix, target_key, patch_dict, loaded_keys):
    """Record ``("diff", ...)`` patches from ``w_norm``/``b_norm``/``diff``/``diff_b``.

    Weight diffs are written under ``target_key``; bias diffs are written under
    the ``*.bias`` key derived from it. ``diff``/``diff_b`` overwrite any
    ``w_norm``/``b_norm`` entry for the same key.
    """
    w_norm_name = "{}.w_norm".format(prefix)
    b_norm_name = "{}.b_norm".format(prefix)
    w_norm = lora.get(w_norm_name, None)
    b_norm = lora.get(b_norm_name, None)
    if w_norm is not None:
        loaded_keys.add(w_norm_name)
        patch_dict[target_key] = ("diff", (w_norm,))
        # b_norm is only honoured alongside w_norm.
        if b_norm is not None:
            loaded_keys.add(b_norm_name)
            patch_dict[_bias_key_for(target_key)] = ("diff", (b_norm,))

    diff_name = "{}.diff".format(prefix)
    diff_weight = lora.get(diff_name, None)
    if diff_weight is not None:
        patch_dict[target_key] = ("diff", (diff_weight,))
        loaded_keys.add(diff_name)

    diff_bias_name = "{}.diff_b".format(prefix)
    diff_bias = lora.get(diff_bias_name, None)
    if diff_bias is not None:
        patch_dict[_bias_key_for(target_key)] = ("diff", (diff_bias,))
        loaded_keys.add(diff_bias_name)
|
|