Can you provide the script used to convert the weights to the safetensors format?

#1
by rkingzhong - opened

Could you share the script you used to convert the weights to the safetensors format?
My results differ from yours; I converted the model with the safetensors conversion tool.

Here is the very hacky script that I used to create the weight files.

import torch
from safetensors.torch import load_file
from safetensors.torch import save_file

# Ordered (old, new) substring rewrites that map YOLOv8 checkpoint keys
# ("model.<idx>.…") onto the names expected by the model that loads the
# resulting safetensors file (net.b*, fpn.n*, head).
# Order matters: every "model.N.m." rule must run before its bare "model.N."
# rule, otherwise "model.N." would fire first and leave the ".m." segment
# un-renamed. Indices not listed (e.g. 10, 11, 13) presumably hold no
# parameters — NOTE(review): verify against the checkpoint's state_dict.
_RENAMES = (
    ("model.0.", "net.b1.0."),
    ("model.1.", "net.b1.1."),
    ("model.2.m.", "net.b2.0.bottleneck."),
    ("model.2.", "net.b2.0."),
    ("model.3.", "net.b2.1."),
    ("model.4.m.", "net.b2.2.bottleneck."),
    ("model.4.", "net.b2.2."),
    ("model.5.", "net.b3.0."),
    ("model.6.m.", "net.b3.1.bottleneck."),
    ("model.6.", "net.b3.1."),
    ("model.7.", "net.b4.0."),
    ("model.8.m.", "net.b4.1.bottleneck."),
    ("model.8.", "net.b4.1."),
    ("model.9.", "net.b5.0."),
    ("model.12.m.", "fpn.n1.bottleneck."),
    ("model.12.", "fpn.n1."),
    ("model.15.m.", "fpn.n2.bottleneck."),
    ("model.15.", "fpn.n2."),
    ("model.16.", "fpn.n3."),
    ("model.18.m.", "fpn.n4.bottleneck."),
    ("model.18.", "fpn.n4."),
    ("model.19.", "fpn.n5."),
    ("model.21.m.", "fpn.n6.bottleneck."),
    ("model.21.", "fpn.n6."),
    ("model.22.", "head."),
)


def rename(name: str) -> str:
    """Translate a YOLOv8 checkpoint parameter name into the target naming.

    Every rule in ``_RENAMES`` is applied in order with ``str.replace``, so a
    pattern is rewritten wherever it occurs in ``name``. Names matching no rule
    are returned unchanged.

    Args:
        name: A state_dict key such as ``"model.2.m.0.cv1.conv.weight"``.

    Returns:
        The renamed key, e.g. ``"net.b2.0.bottleneck.0.cv1.conv.weight"``.
    """
    for old, new in _RENAMES:
        name = name.replace(old, new)
    return name

# Load the checkpoint. data["model"] is a pickled object exposing .state_dict()
# (not a plain tensor dict), so weights_only=False is needed: PyTorch >= 2.6
# defaults to weights_only=True and refuses such checkpoints (the keyword
# itself needs PyTorch >= 1.13). Unpickling can run arbitrary code, so only
# load checkpoints from a trusted source.
# map_location="cpu" lets the conversion run on machines without a GPU.
data = torch.load("/tmp/yolov8s-pose.pt", map_location="cpu", weights_only=False)
# For the detection model, load "/tmp/yolov8s.pt" instead.
print(data)
state_dict = data["model"].state_dict()
# safetensors' save_file rejects non-contiguous tensors, so force a contiguous
# layout (a no-op for tensors that already are).
tensors = {rename(k): t.contiguous() for k, t in state_dict.items()}
# If two checkpoint keys were renamed to the same name, the dict above would
# have silently dropped a tensor; fail loudly instead.
if len(tensors) != len(state_dict):
    raise ValueError("rename() mapped several checkpoint keys onto the same name")
print(data["model"])
save_file(tensors, "/tmp/model.safetensors")
for k, v in tensors.items():
    print(str(k), v.shape)

Thanks for your script.

Thank you for your script; it works very well. I look forward to your model supporting classification tasks.

Sign up or log in to comment