add torchscript and safetensors versions of retccl model

#1
by kaczmarj - opened

This commit adds TorchScript and Safetensors versions of the RetCCL model. TorchScript allows the model to be used without the code that defines its implementation and without a Python runtime, which can help others incorporate RetCCL into other applications. In fact, I was planning to upload RetCCL to HuggingFace myself, but I found this repository first. I hope to use the TorchScript version of this model in several applications and to pull it from this repository.
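For anyone who wants to try the TorchScript file, here is a minimal sketch of loading a scripted model with torch.jit.load, which does not need ResNet.py. The TinyNet class and the load_scripted helper are hypothetical stand-ins so the snippet runs without downloading anything; with the real file you would pass the path to retccl_torchscript.pth instead.

```python
import os
import tempfile

import torch
from torch import nn


def load_scripted(path: str) -> torch.jit.ScriptModule:
    """Load a TorchScript archive; the model's source code is not needed."""
    model = torch.jit.load(path, map_location="cpu")
    model.eval()
    return model


# Hypothetical stand-in for RetCCL so the sketch is self-contained.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Global average pooling, then flatten to (batch, channels).
        return torch.flatten(self.pool(self.conv(x)), 1)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tiny_torchscript.pth")
    # Script and save the stand-in, then load it back from disk.
    torch.jit.save(torch.jit.script(TinyNet().eval()), path)
    loaded = load_scripted(path)

# Run inference on a dummy 224x224 RGB image.
with torch.no_grad():
    features = loaded(torch.ones(1, 3, 224, 224))
print(features.shape)
```

The same load_scripted call should work on the file in this repository, assuming it has been downloaded locally first.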

Safetensors is a file format developed by HuggingFace to address some drawbacks of the PyTorch pickle-based format, notably that loading a pickle file can execute arbitrary code. I have uploaded a safetensors version here.

Here is the code I used to create the two files. First, I cloned the RetCCL GitHub repo and made minor changes to ResNet.py to satisfy TorchScript requirements: I set self.instDis = nn.Identity() and self.groupDis = nn.Identity() if those attributes were not set.

import torch
from safetensors.torch import save_file
from torch import nn

import ResNet  # from the RetCCL GitHub repo, with the nn.Identity changes described above

# Build the backbone and load the pretrained RetCCL weights.
model = ResNet.resnet50(num_classes=128, mlp=False, two_branch=False, normlinear=True)
pretext_model = torch.load("/home/jakub/Downloads/retccl.pth", map_location="cpu")
# Drop the final fully connected layer so the model returns features.
model.fc = nn.Identity()
model.load_state_dict(pretext_model, strict=True)
model.eval()

# Save the TorchScript model.
model_jit = torch.jit.script(model, example_inputs=[(torch.ones(1, 3, 224, 224),)])
torch.jit.save(model_jit, "retccl_torchscript.pth")

# Save the safetensors weights.
save_file(pretext_model, "retccl.safetensors")

# Ensure model outputs are the same in JIT and original model.
x = torch.ones(1, 3, 224, 224)
with torch.no_grad():
    orig = model(x)
    new = model_jit(x)

assert torch.equal(orig, new)
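The ResNet.py change mentioned above can be illustrated with a small sketch. As far as I understand, torch.jit.script compiles every branch of forward, so a submodule referenced only in an unused branch still has to exist as an attribute. The TinyEncoder class below is hypothetical and is not the RetCCL code; it only shows the pattern of falling back to nn.Identity() when an optional submodule is not configured.

```python
import torch
from torch import nn


class TinyEncoder(nn.Module):
    """Hypothetical module with an optional projection head."""

    def __init__(self, use_head: bool = False):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.use_head = use_head
        # Always define self.head so the attribute exists when scripting,
        # even if the branch that uses it never runs.
        self.head = nn.Linear(4, 2) if use_head else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.backbone(x)
        if self.use_head:
            x = self.head(x)
        return x


model = TinyEncoder(use_head=False).eval()
scripted = torch.jit.script(model)

# The scripted and eager models should agree on the same input.
x = torch.ones(2, 8)
with torch.no_grad():
    eager_out = model(x)
    scripted_out = scripted(x)
print(torch.equal(eager_out, scripted_out))
```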

Thanks for the contribution! I will review and merge later today or tomorrow.

jamesdolezal changed pull request status to merged
