beweinreich commited on
Commit
b617b65
1 Parent(s): 99061c5

use gpu if available

Browse files
Files changed (1) hide show
  1. similarity_fast.py +2 -1
similarity_fast.py CHANGED
@@ -19,7 +19,8 @@ category_pickle_file_path = f'./embeddings/fast/{filename}-categories.pkl'
19
  class SimilarityFast:
20
  def __init__(self, db_cursor):
21
  self.db_cursor = db_cursor
22
- self.model = SentenceTransformer(model_name)
 
23
 
24
  self.db_cursor.execute("SELECT description FROM dictionary")
25
  dictionary = self.db_cursor.fetchall()
 
19
  class SimilarityFast:
20
  def __init__(self, db_cursor):
21
  self.db_cursor = db_cursor
22
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ self.model = SentenceTransformer(model_name).to(self.device)
24
 
25
  self.db_cursor.execute("SELECT description FROM dictionary")
26
  dictionary = self.db_cursor.fetchall()