chychiu commited on
Commit
2e47c02
1 Parent(s): 0fc6977

Update script.py

Browse files
Files changed (1) hide show
  1. script.py +3 -3
script.py CHANGED
@@ -161,7 +161,7 @@ def generate_embeddings(metadata_file_path, root_dir):
161
 
162
  loader = DataLoader(test_dataset, batch_size=3, shuffle=False)
163
 
164
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
165
  model = timm.create_model(
166
  "timm/vit_large_patch14_reg4_dinov2.lvd142m", pretrained=False
167
  )
@@ -225,7 +225,7 @@ class FungiMEEModel(nn.Module):
225
  super().__init__()
226
 
227
  print("Setting up Pytorch Model")
228
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
229
  print(f"Using devide: {self.device}")
230
 
231
  self.date_embedding = MlpHead(
@@ -279,7 +279,7 @@ class FungiEnsembleModel(nn.Module):
279
  super().__init__()
280
 
281
  self.models = nn.ModuleList()
282
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
283
 
284
  for model in models:
285
  model = model.to(self.device)
 
161
 
162
  loader = DataLoader(test_dataset, batch_size=3, shuffle=False)
163
 
164
+ device = torch.device('cpu')
165
  model = timm.create_model(
166
  "timm/vit_large_patch14_reg4_dinov2.lvd142m", pretrained=False
167
  )
 
225
  super().__init__()
226
 
227
  print("Setting up Pytorch Model")
228
+ self.device = torch.device('cpu')
229
  print(f"Using devide: {self.device}")
230
 
231
  self.date_embedding = MlpHead(
 
279
  super().__init__()
280
 
281
  self.models = nn.ModuleList()
282
+ self.device = torch.device('cpu')
283
 
284
  for model in models:
285
  model = model.to(self.device)