|
def convert_weight(old_state_dict, new_state_dict, model_size: int = 38):
    """Copy weights from an old-layout state dict into a new-layout one.

    For every layer index ``idx`` in ``range(model_size)``, the keys of
    ``new_state_dict`` whose first dotted component equals ``str(idx)`` are
    collected. So are the entries of ``old_state_dict`` whose name contains
    ``f"model.{idx + shift}."``. ``shift`` starts at 1 and becomes 2 once a
    ``cv4``/``cv5`` weight has been remapped (see below).

    When both groups have the same size, values are copied positionally,
    in dict iteration order. Otherwise the old group is remapped by name:

    - Entries whose name contains ``"dfl"`` are skipped.
    - Every other name ``<a>.<b>.<cvN>.<i>.<rest...>`` is stored under
      ``<layer>.heads.<i>.<task>.<rest...>``.
    - ``<layer>`` is 22 for ``cv4``/``cv5`` and 37 otherwise.
    - ``<task>`` is ``anchor_conv`` for ``cv2``/``cv4`` and ``class_conv``
      for ``cv3``/``cv5``.

    Args:
        old_state_dict: Mapping of old parameter names to values.
        new_state_dict: Mapping whose keys define the target layout. It is
            updated in place.
        model_size: Number of layer indices to walk.

    Returns:
        The same ``new_state_dict`` object, after updating.

    Raises:
        ValueError: If a name being remapped by name has a conv branch other
            than ``cv2``, ``cv3``, ``cv4`` or ``cv5``. Previously this either
            crashed with ``UnboundLocalError`` or silently reused the task
            of the preceding weight.
    """
    # Target head sub-module for each old conv branch.
    conv_tasks = {
        "cv2": "anchor_conv",
        "cv3": "class_conv",
        "cv4": "anchor_conv",
        "cv5": "class_conv",
    }

    shift = 1
    for idx in range(model_size):
        prefix = str(idx)
        marker = f"model.{idx + shift}."

        new_keys = [name for name in new_state_dict if name.split(".")[0] == prefix]
        old_items = [
            (name, value) for name, value in old_state_dict.items() if marker in name
        ]

        if len(new_keys) == len(old_items):
            # Same number of entries: pair positionally. This relies on both
            # dicts iterating in matching order.
            for new_name, (_, value) in zip(new_keys, old_items):
                new_state_dict[new_name] = value
            continue

        # Group sizes differ: remap each old entry into the "<layer>.heads" layout.
        for old_name, value in old_items:
            if "dfl" in old_name:
                continue

            _, _, conv_name, conv_idx, *details = old_name.split(".")

            conv_task = conv_tasks.get(conv_name)
            if conv_task is None:
                raise ValueError(
                    f"unexpected conv branch {conv_name!r} in {old_name!r}; "
                    "expected one of cv2, cv3, cv4, cv5"
                )

            if conv_name in ("cv4", "cv5"):
                layer_idx = 22
                # Later layer indices are read from old index idx + 2 instead
                # of idx + 1. NOTE(review): presumably because the old model
                # has one extra layer here; confirm against the layouts.
                shift = 2
            else:
                layer_idx = 37

            new_name = ".".join([str(layer_idx), "heads", conv_idx, conv_task, *details])
            new_state_dict[new_name] = value

    return new_state_dict
|
|