File size: 592 Bytes
60a2954
 
 
 
4718e6a
 
 
60a2954
 
 
 
4718e6a
60a2954
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import numpy as np
import torch
from transformers import BertModel, BertTokenizer

# Identifier handed to both `from_pretrained` calls below.
CHECKPOINT = 'DeepPavlov/rubert-base-cased'
# Load the tokenizer first, then the encoder, once at import time so every
# later call in this module reuses the same objects.
tokenizer, model = (
    BertTokenizer.from_pretrained(CHECKPOINT),
    BertModel.from_pretrained(CHECKPOINT),
)


def preprocess_bert(text, MAX_LEN):
    """Tokenize *text* and pad it to a fixed length for the BERT encoder.

    Args:
        text: Input string to encode.
        MAX_LEN: Target sequence length. The encoded ids (special tokens
            included) are truncated to at most this many tokens and then
            right-padded up to exactly this length.

    Returns:
        A pair ``(padded_text, attention_mask)`` of 1-D numpy arrays of
        length ``MAX_LEN``. ``padded_text`` holds the token ids followed by
        pad ids. ``attention_mask`` is 1 at real-token positions and 0 at
        padding positions.
    """
    token_ids = tokenizer.encode(
        text=text,
        add_special_tokens=True,
        truncation=True,
        max_length=MAX_LEN,
    )

    # Pad with the tokenizer's own pad id; fall back to 0 (the value this
    # function always used) if the tokenizer does not define one.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = 0

    n_real = len(token_ids)
    padded_text = np.array(token_ids + [pad_id] * (MAX_LEN - n_real))

    # Derive the mask from the count of real tokens rather than by comparing
    # ids against 0: a genuine token whose id equals the pad id would
    # otherwise be silently masked out.
    attention_mask = np.zeros_like(padded_text)
    attention_mask[:n_real] = 1

    return padded_text, attention_mask