kaczmarj commited on
Commit
f2937b7
1 Parent(s): a2a789e

Delete convert_pt.py

Browse files
Files changed (1) hide show
  1. convert_pt.py +0 -59
convert_pt.py DELETED
@@ -1,59 +0,0 @@
1
- from safetensors.torch import save_file
2
- import torch
3
- from torchvision.models import resnet34
4
-
5
- model_path = "RESNET_34_cancer_350px_lr_1e-2_decay_5_jitter_val6slides_harder_tcga_none_0403_0204_0.9826153355179645_16.t7"
6
-
7
- orig_model = torch.load(model_path, map_location="cpu")
8
- state_dict = orig_model["model"].module.state_dict()
9
- keys_missing = [
10
- "bn1.num_batches_tracked",
11
- "layer1.0.bn1.num_batches_tracked",
12
- "layer1.0.bn2.num_batches_tracked",
13
- "layer1.1.bn1.num_batches_tracked",
14
- "layer1.1.bn2.num_batches_tracked",
15
- "layer1.2.bn1.num_batches_tracked",
16
- "layer1.2.bn2.num_batches_tracked",
17
- "layer2.0.bn1.num_batches_tracked",
18
- "layer2.0.bn2.num_batches_tracked",
19
- "layer2.0.downsample.1.num_batches_tracked",
20
- "layer2.1.bn1.num_batches_tracked",
21
- "layer2.1.bn2.num_batches_tracked",
22
- "layer2.2.bn1.num_batches_tracked",
23
- "layer2.2.bn2.num_batches_tracked",
24
- "layer2.3.bn1.num_batches_tracked",
25
- "layer2.3.bn2.num_batches_tracked",
26
- "layer3.0.bn1.num_batches_tracked",
27
- "layer3.0.bn2.num_batches_tracked",
28
- "layer3.0.downsample.1.num_batches_tracked",
29
- "layer3.1.bn1.num_batches_tracked",
30
- "layer3.1.bn2.num_batches_tracked",
31
- "layer3.2.bn1.num_batches_tracked",
32
- "layer3.2.bn2.num_batches_tracked",
33
- "layer3.3.bn1.num_batches_tracked",
34
- "layer3.3.bn2.num_batches_tracked",
35
- "layer3.4.bn1.num_batches_tracked",
36
- "layer3.4.bn2.num_batches_tracked",
37
- "layer3.5.bn1.num_batches_tracked",
38
- "layer3.5.bn2.num_batches_tracked",
39
- "layer4.0.bn1.num_batches_tracked",
40
- "layer4.0.bn2.num_batches_tracked",
41
- "layer4.0.downsample.1.num_batches_tracked",
42
- "layer4.1.bn1.num_batches_tracked",
43
- "layer4.1.bn2.num_batches_tracked",
44
- "layer4.2.bn1.num_batches_tracked",
45
- "layer4.2.bn2.num_batches_tracked",
46
- ]
47
- assert not any(
48
- key in state_dict.keys() for key in keys_missing
49
- ), "key present that should be missing"
50
- for key in keys_missing:
51
- state_dict[key] = torch.as_tensor(0)
52
- torch.save(state_dict, "pytorch_model.pt")
53
- save_file(state_dict, "model.safetensors")
54
-
55
- model = resnet34(weights=None)
56
- model.fc = torch.nn.Linear(model.fc.in_features, out_features=5, bias=True)
57
- model.load_state_dict(state_dict)
58
- model_jit = torch.jit.script(model, example_inputs=[(torch.ones(1, 3, 224, 224),)])
59
- torch.jit.save(model_jit, "torchscript_model.bin")