# Author: Ruslan-DS
# Commit 4718e6a: "Update models/preprocess_stage/bert_model.py"
import numpy as np
import torch
from transformers import BertModel, BertTokenizer
# Hugging Face checkpoint identifier shared by the tokenizer and model below.
CHECKPOINT = 'DeepPavlov/rubert-base-cased'

# Both objects are created once, at import time, from the same checkpoint;
# the tokenizer is built first, then the model (tuple items evaluate in order).
tokenizer, model = (
    BertTokenizer.from_pretrained(CHECKPOINT),
    BertModel.from_pretrained(CHECKPOINT),
)
def preprocess_bert(text, MAX_LEN, *, bert_tokenizer=None):
    """Encode ``text`` into fixed-length token ids plus an attention mask.

    Args:
        text: Input string to encode.
        MAX_LEN: Target sequence length. The encoded ids (special tokens
            included) are truncated to at most ``MAX_LEN`` by the tokenizer
            and then right-padded up to ``MAX_LEN``.
        bert_tokenizer: Optional tokenizer to use instead of the module-level
            ``tokenizer``. It must provide ``encode(text=..., ...)`` returning
            a list of ints, and may provide ``pad_token_id``.

    Returns:
        A ``(padded_text, attention_mask)`` pair of 1-D NumPy integer arrays
        of equal length. ``attention_mask`` is 1 at positions holding real
        tokens and 0 at padding positions.
    """
    tok = tokenizer if bert_tokenizer is None else bert_tokenizer
    token_ids = tok.encode(
        text=text,
        add_special_tokens=True,
        truncation=True,
        max_length=MAX_LEN,
    )

    # Pad with the tokenizer's own pad id; fall back to 0, the value this
    # function has always used, when the tokenizer does not define one.
    pad_id = getattr(tok, "pad_token_id", None)
    if pad_id is None:
        pad_id = 0

    n_real = len(token_ids)
    padded_text = np.array(token_ids + [pad_id] * (MAX_LEN - n_real))

    # Derive the mask from how many real tokens there are rather than by
    # comparing ids with the pad id: a genuine token can carry that same id
    # (e.g. special-token text appearing in the input) and must not be masked.
    attention_mask = (np.arange(len(padded_text)) < n_real).astype(int)
    return padded_text, attention_mask