Ruslan-DS commited on
Commit
af4c8a8
1 Parent(s): a0bdf1c

Update models/LogReg.py

Browse files
Files changed (1) hide show
  1. models/LogReg.py +2 -0
models/LogReg.py CHANGED
@@ -6,6 +6,7 @@ from models.preprocess_stage.bert_model import preprocess_bert
6
  from models.preprocess_stage.bert_model import model
7
 
8
  MAX_LEN = 100 # позже добавлю способ пользователю самому выбирать масимальную длину
 
9
 
10
  logreg = joblib.load('models/weights/LogRegBestWeights.sav')
11
 
@@ -14,6 +15,7 @@ def predict_1(text):
14
  preprocessed_text, attention_mask = preprocess_bert(text, MAX_LEN=MAX_LEN)
15
  preprocessed_text, attention_mask = torch.tensor(preprocessed_text).unsqueeze(0), torch.tensor([attention_mask])
16
 
 
17
  with torch.inference_mode():
18
  vector = model(preprocessed_text, attention_mask=attention_mask)[0][:, 0, :]
19
 
 
6
  from models.preprocess_stage.bert_model import model
7
 
8
  MAX_LEN = 100 # позже добавлю способ пользователю самому выбирать масимальную длину
9
+ DEVICE='cpu'
10
 
11
  logreg = joblib.load('models/weights/LogRegBestWeights.sav')
12
 
 
15
  preprocessed_text, attention_mask = preprocess_bert(text, MAX_LEN=MAX_LEN)
16
  preprocessed_text, attention_mask = torch.tensor(preprocessed_text).unsqueeze(0), torch.tensor([attention_mask])
17
 
18
+ model.to(DEVICE)
19
  with torch.inference_mode():
20
  vector = model(preprocessed_text, attention_mask=attention_mask)[0][:, 0, :]
21