Spaces:
Sleeping
Sleeping
🔨 [Add] wegith transform for v9seg model
Browse files
yolo/tools/format_converters.py
CHANGED
|
@@ -83,3 +83,55 @@ def convert_weight_v7(old_state_dict, new_state_dict):
|
|
| 83 |
assert new_shape == old_shape, "Weight Shape Mismatch!! {old_key_name}"
|
| 84 |
new_state_dict[new_key_name] = old_state_dict[old_key_name]
|
| 85 |
return new_state_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
assert new_shape == old_shape, "Weight Shape Mismatch!! {old_key_name}"
|
| 84 |
new_state_dict[new_key_name] = old_state_dict[old_key_name]
|
| 85 |
return new_state_dict
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
replace_dict = {"cv": "conv", ".m.": ".bottleneck."}
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def convert_weight_seg(old_state_dict, new_state_dict):
|
| 92 |
+
diff = -1
|
| 93 |
+
for old_weight_name in old_state_dict.keys():
|
| 94 |
+
old_idx = int(old_weight_name.split(".")[1])
|
| 95 |
+
if old_idx == 23:
|
| 96 |
+
diff = 3
|
| 97 |
+
elif old_idx == 41:
|
| 98 |
+
diff = -19
|
| 99 |
+
new_idx = old_idx + diff
|
| 100 |
+
new_weight_name = old_weight_name.replace(f".{old_idx}.", f".{new_idx}.")
|
| 101 |
+
for key, val in replace_dict.items():
|
| 102 |
+
new_weight_name = new_weight_name.replace(key, val)
|
| 103 |
+
|
| 104 |
+
if new_weight_name not in new_state_dict.keys():
|
| 105 |
+
heads = "heads"
|
| 106 |
+
_, _, conv_name, conv_idx, *details = old_weight_name.split(".")
|
| 107 |
+
if "proto" in conv_name:
|
| 108 |
+
conv_idx = "3"
|
| 109 |
+
new_weight_name = ".".join(["model", str(layer_idx), heads, conv_task, *details])
|
| 110 |
+
continue
|
| 111 |
+
if "dfl" in old_weight_name:
|
| 112 |
+
continue
|
| 113 |
+
if conv_name == "cv2" or conv_name == "cv3" or conv_name == "cv6":
|
| 114 |
+
layer_idx = 44
|
| 115 |
+
heads = "detect.heads"
|
| 116 |
+
if conv_name == "cv4" or conv_name == "cv5" or conv_name == "cv7":
|
| 117 |
+
layer_idx = 25
|
| 118 |
+
heads = "detect.heads"
|
| 119 |
+
|
| 120 |
+
if conv_name == "cv2" or conv_name == "cv4":
|
| 121 |
+
conv_task = "anchor_conv"
|
| 122 |
+
if conv_name == "cv3" or conv_name == "cv5":
|
| 123 |
+
conv_task = "class_conv"
|
| 124 |
+
if conv_name == "cv6" or conv_name == "cv7":
|
| 125 |
+
conv_task = "mask_conv"
|
| 126 |
+
heads = "heads"
|
| 127 |
+
|
| 128 |
+
new_weight_name = ".".join(["model", str(layer_idx), heads, conv_idx, conv_task, *details])
|
| 129 |
+
|
| 130 |
+
if (
|
| 131 |
+
new_weight_name not in new_state_dict.keys()
|
| 132 |
+
or new_state_dict[new_weight_name].shape != old_state_dict[old_weight_name].shape
|
| 133 |
+
):
|
| 134 |
+
print(f"new: {new_weight_name}, old: {old_weight_name}")
|
| 135 |
+
print(f"{new_state_dict[new_weight_name].shape} {old_state_dict[old_weight_name].shape}")
|
| 136 |
+
new_state_dict[new_weight_name] = old_state_dict[old_weight_name]
|
| 137 |
+
return new_state_dict
|