ehottl commited on
Commit
732c49e
1 Parent(s): 3732e13

safetensor file

Browse files
Files changed (3) hide show
  1. cifar_net.pth +1 -1
  2. model.safetensors +3 -0
  3. pytorch_classifier_gen.py +12 -2
cifar_net.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:bf87afeab91ee70d644d0f22733711e058dfa3d20edb740b01cb1abe0ddd2b4e
3
  size 251167
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c6818fe3fd14557084d0ddaf577fcdaf2ada6d32356000dae47022f9b0d2b01
3
  size 251167
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e248d017220b608973ba016c97fde2e0bd63f76e28bc852d588693c5afd651a
3
+ size 248760
pytorch_classifier_gen.py CHANGED
@@ -4,6 +4,8 @@ import torchvision.transforms as transforms
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
  import torch.optim as optim
 
 
7
 
8
  # 데이터셋 불러오기
9
  transform = transforms.Compose(
@@ -77,6 +79,14 @@ for epoch in range(2): # 데이터셋을 수차례 반복합니다.
77
 
78
  print('Finished Training')
79
 
80
- # 모델 저장하기
81
  PATH = './cifar_net.pth'
82
- torch.save(net.state_dict(), PATH)
 
 
 
 
 
 
 
 
 
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
  import torch.optim as optim
7
+ from safetensors import safe_open
8
+ from safetensors.torch import save_file
9
 
10
  # 데이터셋 불러오기
11
  transform = transforms.Compose(
 
79
 
80
  print('Finished Training')
81
 
82
+ # 모델 저장하기
83
  PATH = './cifar_net.pth'
84
+ torch.save(net.state_dict(), PATH) # Not safe way
85
+
86
+ save_file(net.state_dict(), "model.safetensors")
87
+
88
+ # 모델 불러오기
89
+ tensors = {}
90
+ with safe_open("model.safetensors", framework="pt", device="cpu") as f:
91
+ for key in f.keys():
92
+ tensors[key] = f.get_tensor(key)