kaczmarj commited on
Commit
1a4c404
1 Parent(s): b93463e

add torchscript and safetensors versions of retccl model

Browse files

This commit adds TorchScript and Safetensors versions of the RetCCL model. Torchscript allows the model to be used without code defining its implementation and without a Python runtime. This can help others incorporate RetCCL into other applications. In fact I was planning to upload RetCCL to HuggingFace but I found this repository first. I am hoping to use the Torchscript version of this model in several applications and to pull it from this repository.

Safetensors is a file format developed by HuggingFace to deal with some drawbacks of the PyTorch pickle-based format. I have uploaded a safetensors version here.

Here is the code I used to create the two files here. First I cloned the RetCCL GitHub repo, and then made minor changes to ResNet.py to satisfy TorchScript requirements. Namely, I set `self.instDis = nn.Identity()` and `self.groupDis = nn.Identity()` if those attributes were `None`.

```python
import numpy as np
from safetensors.torch import save_file
import torch
from torch import nn
import ResNet

model = ResNet.resnet50(num_classes=128,mlp=False, two_branch=False, normlinear=True)
pretext_model = torch.load("/home/jakub/Downloads/retccl.pth", map_location="cpu")
model.fc = nn.Identity()
model.load_state_dict(pretext_model, strict=True)
model.eval()

# Save torchscript model.
model_jit = torch.jit.script(model, example_inputs=[(torch.ones(1, 3, 224, 224),)])
torch.jit.save(model_jit, "retccl_torchscript.pth")

# Save safetensors weights
save_file(pretext_model, "retccl.safetensors")

# Ensure model outputs are the same in JIT and original model.
x = torch.ones(1, 3, 224, 224)
with torch.no_grad():
orig = model(x)
new = model_jit(x)

assert np.array_equal(orig, new)
```

Files changed (2) hide show
  1. retccl.safetensors +3 -0
  2. retccl_torchscript.pth +3 -0
retccl.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e253cc0ef27f6e0bf0fa8481f4268c7005406591e4ba236c455af8ee8e76e96c
3
+ size 94273408
retccl_torchscript.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:867ee6d5b2d2b4702cb35f6f5e55a80f02f6bfbb9b4e6e25b1dd87f8608176ba
3
+ size 94418045