TeraSpace commited on
Commit
4cd2bd2
1 Parent(s): ed20cbd

Update ruaccent/omograph_model.py

Browse files
Files changed (1) hide show
  1. ruaccent/omograph_model.py +2 -2
ruaccent/omograph_model.py CHANGED
@@ -2,8 +2,8 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer
2
  import torch
3
 
4
  class OmographModel:
5
- def __init__(self) -> None:
6
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7
 
8
  def load(self, path):
9
  self.nli_model = AutoModelForSequenceClassification.from_pretrained(path, torch_dtype=torch.bfloat16).to(self.device)
 
2
  import torch
3
 
4
  class OmographModel:
5
+ def __init__(self, allow_cuda=True) -> None:
6
+ self.device = torch.device('cuda' if torch.cuda.is_available() and allow_cuda else 'cpu')
7
 
8
  def load(self, path):
9
  self.nli_model = AutoModelForSequenceClassification.from_pretrained(path, torch_dtype=torch.bfloat16).to(self.device)