Simma7 commited on
Commit
0a6c6ed
·
verified ·
1 Parent(s): 7555365

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +20 -0
model.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import timm
3
+ from safetensors.torch import load_file
4
+
5
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
6
+
7
+ def load_model(path, arch="vit_base_patch16_224"):
8
+ state_dict = load_file(path)
9
+
10
+ # auto detect output classes
11
+ last_key = list(state_dict.keys())[-1]
12
+ out_features = state_dict[last_key].shape[0]
13
+
14
+ model = timm.create_model(arch, pretrained=False, num_classes=out_features)
15
+ model.load_state_dict(state_dict, strict=False)
16
+
17
+ model.to(DEVICE)
18
+ model.eval()
19
+
20
+ return model