farrosalferro24 commited on
Commit
4f1298a
1 Parent(s): 979164d

Update script.py

Browse files
Files changed (1) hide show
  1. script.py +3 -3
script.py CHANGED
@@ -28,9 +28,9 @@ class PytorchWorker:
28
  "cuda:0" if torch.cuda.is_available() else "cpu")
29
  print(f"Using devide: {self.device}")
30
 
31
- from torchvision.models import vit_b_16, ViT_B_16_Weights
32
 
33
- model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1)
34
 
35
  model.heads.head = torch.nn.Linear(model.heads.head.in_features,
36
  number_of_categories)
@@ -52,7 +52,7 @@ class PytorchWorker:
52
  self.model = _load_model(model_name, model_path)
53
 
54
  self.transforms = T.Compose([
55
- T.Resize((384, 384)),
56
  T.ToTensor(),
57
  T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
58
  ])
 
28
  "cuda:0" if torch.cuda.is_available() else "cpu")
29
  print(f"Using devide: {self.device}")
30
 
31
+ from torchvision.models import vit_b_16
32
 
33
+ model = vit_b_16(weights=None)
34
 
35
  model.heads.head = torch.nn.Linear(model.heads.head.in_features,
36
  number_of_categories)
 
52
  self.model = _load_model(model_name, model_path)
53
 
54
  self.transforms = T.Compose([
55
+ T.Resize((224, 224)),
56
  T.ToTensor(),
57
  T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
58
  ])