Ruslan-DS commited on
Commit
ca126f2
1 Parent(s): 3c80cef

Update models/LogReg.py

Browse files
Files changed (1) hide show
  1. models/LogReg.py +23 -0
models/LogReg.py CHANGED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import torch
3
+ import joblib
4
+ from transformers import BertModel, BertTokenizer
5
+
6
+ from models.preprocess_stage.bert_model import preprocess_bert
7
+ from models.preprocess_stage.bert_model import model
8
+
9
+ MAX_LEN = 100 # позже добавлю способ пользователю самому выбирать масимальную длину
10
+
11
+ logreg = joblib.load('models/weights/LogRegBestWeights.sav')
12
+
13
+ def predict_1(text):
14
+
15
+ preprocessed_text, attention_mask = preprocess_bert(text, MAX_LEN=MAX_LEN)
16
+ preprocessed_text, attention_mask = torch.tensor(preprocessed_text).unsqueeze(0), torch.tensor([attention_mask])
17
+
18
+ with torch.inference_mode():
19
+ vector = model(preprocessed_text, attention_mask=attention_mask)[0][:, 0, :]
20
+
21
+ predict = logreg.predict(vector)
22
+
23
+ return predict[-1]