File size: 5,278 Bytes
f61e618
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import argparse
import torch
from safetensors.torch import load_file, save_file
from collections import defaultdict

def convert_comfy_to_wan_lora_final_fp16(lora_path, output_path):
    """
    Convert a ComfyUI-style LoRA to the format expected by 'wan.modules.model'.

    - Keeps the 'diffusion_model.' prefix.
    - Renames 'lora_A' -> 'lora_down' and 'lora_B' -> 'lora_up'.
    - Skips per-layer '.alpha' keys. LoRA loaders conventionally scale the
      update by alpha/rank, so a warning is printed for layers whose dropped
      alpha differs from their rank (their effective strength may change).
    - Skips keys under 'diffusion_model.img_emb.'.
    - Casts floating-point LoRA weight tensors to float16 (non-float tensors
      are kept unchanged with a warning).

    Args:
        lora_path (str): Path to the input ComfyUI LoRA .safetensors file.
        output_path (str): Path to save the converted LoRA .safetensors file.

    Returns:
        bool: True if a converted file was written; False if loading failed,
        no keys could be converted, or saving failed. Errors are reported by
        printing, never raised.
    """
    try:
        source_state_dict = load_file(lora_path)
    except Exception as e:  # CLI boundary: report and signal failure via the return value
        print(f"Error loading LoRA file '{lora_path}': {e}")
        return False

    diffusers_state_dict = {}
    print(f"Loaded {len(source_state_dict)} tensors from {lora_path}")

    source_comfy_prefix = "diffusion_model."
    target_wan_prefix = "diffusion_model."
    # (source suffix, target suffix) for the two LoRA factor matrices.
    suffix_renames = (
        (".lora_A.weight", ".lora_down.weight"),
        (".lora_B.weight", ".lora_up.weight"),
    )

    converted_count = 0
    skipped_alpha_keys_count = 0
    skipped_img_emb_keys_count = 0
    problematic_keys = []
    # Bookkeeping used only to warn when a dropped alpha differs from the rank.
    dropped_alphas = {}  # module base name -> scalar alpha value
    ranks = {}           # module base name -> dim 0 of the lora_A (lora_down) tensor

    for original_key, tensor in source_state_dict.items():
        if not original_key.startswith(source_comfy_prefix):
            problematic_keys.append(f"{original_key} (Key does not start with expected prefix '{source_comfy_prefix}')")
            continue

        module_and_lora_part = original_key[len(source_comfy_prefix):]

        if module_and_lora_part.startswith("img_emb."):
            skipped_img_emb_keys_count += 1
            continue

        if module_and_lora_part.endswith(".alpha"):
            skipped_alpha_keys_count += 1
            if tensor.numel() == 1:
                dropped_alphas[module_and_lora_part[:-len(".alpha")]] = float(tensor.item())
            continue

        for source_suffix, target_suffix in suffix_renames:
            if module_and_lora_part.endswith(source_suffix):
                module_base = module_and_lora_part[:-len(source_suffix)]
                break
        else:
            problematic_keys.append(f"{original_key} (Unknown LoRA suffix or non-LoRA key within '{source_comfy_prefix}' structure: '...{module_and_lora_part[-25:]}')")
            continue

        if source_suffix == ".lora_A.weight" and tensor.dim() > 0:
            ranks[module_base] = tensor.shape[0]

        new_key = target_wan_prefix + module_base + target_suffix
        if tensor.is_floating_point():
            diffusers_state_dict[new_key] = tensor.to(torch.float16)
        else:
            # Not expected for LoRA factors; keep the tensor rather than fail.
            diffusers_state_dict[new_key] = tensor
            print(f"Warning: Tensor {original_key} was not floating point, dtype not changed.")
        converted_count += 1

    print("\nKey conversion finished.")
    print(f"Successfully processed and converted {converted_count} LoRA weight keys (to float16).")
    if skipped_alpha_keys_count > 0:
        print(f"Skipped {skipped_alpha_keys_count} '.alpha' keys.")
    mismatched_alpha = sorted(
        base for base, alpha in dropped_alphas.items()
        if base in ranks and alpha != ranks[base]
    )
    if mismatched_alpha:
        print(
            f"Warning: {len(mismatched_alpha)} module(s) had alpha != rank "
            f"(e.g. '{mismatched_alpha[0]}'). Those alpha values were dropped, so the "
            "effective LoRA strength (alpha/rank) may differ unless the target loader compensates."
        )
    if skipped_img_emb_keys_count > 0:
        print(f"Skipped {skipped_img_emb_keys_count} 'diffusion_model.img_emb.' related keys.")
    if problematic_keys:
        print(f"Found {len(problematic_keys)} other keys that were also skipped (see details below):")
        for pkey in problematic_keys:
            print(f"  - {pkey}")

    if not diffusers_state_dict:
        if source_state_dict:
            print("\nNo keys were converted. Check input LoRA format and skipped key counts.")
        else:
            print("\nInput LoRA file seems empty or could not be loaded. No conversion performed.")
        return False

    print(f"Output dictionary has {len(diffusers_state_dict)} keys.")
    print(f"Now attempting to save the file to: {output_path} (This might take a while for large files)...")
    try:
        save_file(diffusers_state_dict, output_path)
    except Exception as e:  # CLI boundary: report and signal failure via the return value
        print(f"Error saving converted LoRA file '{output_path}': {e}")
        return False
    print(f"\nSuccessfully saved converted LoRA to: {output_path}")
    return True

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert ComfyUI-style LoRA to 'wan.modules.model' format, converting weights to float16.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("lora_path", type=str, help="Path to the input ComfyUI LoRA (.safetensors) file.")
    parser.add_argument("output_path", type=str, help="Path to save the converted LoRA (.safetensors) file.")
    args = parser.parse_args()

    # An explicit False return means the conversion failed; map it to a
    # non-zero exit status so shell scripts can detect the failure.
    # Any other return value (None/True) is treated as success.
    if convert_comfy_to_wan_lora_final_fp16(args.lora_path, args.output_path) is False:
        raise SystemExit(1)