Ruslan-DS commited on
Commit
16aa87c
1 Parent(s): e20906a

Update models/LogReg.py

Browse files
Files changed (1) hide show
  1. models/LogReg.py +2 -2
models/LogReg.py CHANGED
@@ -6,7 +6,7 @@ from models.preprocess_stage.bert_model import preprocess_bert
6
  from models.preprocess_stage.bert_model import model
7
 
8
  MAX_LEN = 100 # позже добавлю способ пользователю самому выбирать масимальную длину
9
- DEVICE='cpu'
10
 
11
  logreg = joblib.load('models/weights/LogRegBestWeights.sav')
12
 
@@ -15,7 +15,7 @@ def predict_1(text):
15
  preprocessed_text, attention_mask = preprocess_bert(text, MAX_LEN=MAX_LEN)
16
  preprocessed_text, attention_mask = torch.tensor(preprocessed_text).unsqueeze(0), torch.tensor([attention_mask])
17
 
18
- model.to(DEVICE)
19
  with torch.inference_mode():
20
  vector = model(preprocessed_text, attention_mask=attention_mask)[0][:, 0, :]
21
 
 
6
  from models.preprocess_stage.bert_model import model
7
 
8
  MAX_LEN = 100 # позже добавлю способ пользователю самому выбирать масимальную длину
9
+ # DEVICE='cpu'
10
 
11
  logreg = joblib.load('models/weights/LogRegBestWeights.sav')
12
 
 
15
  preprocessed_text, attention_mask = preprocess_bert(text, MAX_LEN=MAX_LEN)
16
  preprocessed_text, attention_mask = torch.tensor(preprocessed_text).unsqueeze(0), torch.tensor([attention_mask])
17
 
18
+ # model.to(DEVICE)
19
  with torch.inference_mode():
20
  vector = model(preprocessed_text, attention_mask=attention_mask)[0][:, 0, :]
21