Sifal commited on
Commit
e5528e5
1 Parent(s): 5737312

handle cpu

Browse files
Files changed (1) hide show
  1. README.md +8 -2
README.md CHANGED
@@ -30,12 +30,18 @@ from model import BertClassifier
30
  from transformers import PreTrainedTokenizerFast
31
  import torch
32
 
33
- MODEL_PATH = "./pytorch_model.bin"
 
 
 
 
34
  TOKENIZER_PATH = "./tokenizer.json"
35
 
 
36
  dzarashield = BertClassifier()
37
- dzarashield.load_state_dict(torch.load(MODEL_PATH))
38
 
 
39
  tokenizer = PreTrainedTokenizerFast(tokenizer_file=TOKENIZER_PATH)
40
 
41
  ```
 
30
  from transformers import PreTrainedTokenizerFast
31
  import torch
32
 
33
+ # Check if a GPU is available
34
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
35
+
36
+ # Specify paths
37
+ MODEL_PATH = "./model.pth"
38
  TOKENIZER_PATH = "./tokenizer.json"
39
 
40
+ # Load the model with the appropriate map_location
41
  dzarashield = BertClassifier()
42
+ dzarashield.load_state_dict(torch.load(MODEL_PATH, map_location=device))
43
 
44
+ # Load the tokenizer
45
  tokenizer = PreTrainedTokenizerFast(tokenizer_file=TOKENIZER_PATH)
46
 
47
  ```