Ruslan-DS commited on
Commit
b40a0e4
1 Parent(s): a70ae06

Update models/BertTunning.py

Browse files
Files changed (1) hide show
  1. models/BertTunning.py +2 -0
models/BertTunning.py CHANGED
@@ -4,6 +4,7 @@ from models.preprocess_stage.bert_model import model
4
  from models.preprocess_stage.bert_model import preprocess_bert
5
 
6
  MAX_LEN = 100
 
7
 
8
  class BertTunnig(nn.Module):
9
  def __init__(self, bert_model):
@@ -40,6 +41,7 @@ def predict_2(text):
40
  preprocessed_text, attention_mask = preprocess_bert(text, MAX_LEN=MAX_LEN)
41
  preprocessed_text, attention_mask = torch.tensor(preprocessed_text).unsqueeze(0), torch.tensor([attention_mask])
42
 
 
43
  with torch.inference_mode():
44
 
45
  predict = round(model_tunning(preprocessed_text, attention_mask=attention_mask).item())
 
4
  from models.preprocess_stage.bert_model import preprocess_bert
5
 
6
  MAX_LEN = 100
7
+ DEVICE='cpu'
8
 
9
  class BertTunnig(nn.Module):
10
  def __init__(self, bert_model):
 
41
  preprocessed_text, attention_mask = preprocess_bert(text, MAX_LEN=MAX_LEN)
42
  preprocessed_text, attention_mask = torch.tensor(preprocessed_text).unsqueeze(0), torch.tensor([attention_mask])
43
 
44
+ model_tunning.to(DEVICE)
45
  with torch.inference_mode():
46
 
47
  predict = round(model_tunning(preprocessed_text, attention_mask=attention_mask).item())