Transformers
PyTorch
Russian
bert
Inference Endpoints
Edit model card

RuTextSegModel

Model for Russian text segmentation, trained on wiki and news corpora

Model description

This model is a top-level part of HierBERT model and solves the problem of text segmentation as a token classification at the sentence level. The ai-forever/sbert_large_nlu_ru with max pooling is used as a low-level model (sentence embedding generator). It's recommended to use this model only with specified low-level model with defined pooling for embeddings.

Intended uses & limitations

How to use

Here is how to use this model in PyTorch:

import torch
import torch.nn as nn
from transformers import BertForTokenClassification, AutoModel, AutoTokenizer
from razdel import sentenize

class BertForTextSegmentationEmbeddings(nn.Module):
    def __init__(self, config, embeddings_dim=768):
        super(BertForTextSegmentationEmbeddings, self).__init__()

        self.config = config
        self.position_embeddings = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size)

        self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)

    def forward(self, inputs_embeds, position_ids=None, input_ids=None, token_type_ids=None, past_key_values_length=None):
        input_shape = inputs_embeds.size()[:-1]
        seq_length = input_shape[1]
        device = inputs_embeds.device

        assert seq_length <= self.config.max_position_embeddings, \
            f"Too long sequence is passed {seq_length}. Maximum allowed sequence length is {self.config.max_position_embeddings}"

        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).expand(input_shape)

        position_embeddings = self.position_embeddings(position_ids)

        embeddings = inputs_embeds + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

class BertForTextSegmentation(BertForTokenClassification):
    def __init__(self, config):
        super(BertForTextSegmentation, self).__init__(config)

        self.bert.base_model.embeddings = BertForTextSegmentationEmbeddings(config)

        self.init_weights()
        
def max_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    token_embeddings[input_mask_expanded == 0] = -1e9  # Set padding tokens to large negative value
    return torch.max(token_embeddings, 1)[0]

def create_embeddings(sentences, tokenizer, model):
    # Tokenize sentences
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input.to(device))
    # Perform pooling. In this case, max pooling.
    sentence_embeddings = max_pooling(model_output, encoded_input['attention_mask'])
    
    return sentence_embeddings

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

emb_tokenizer = AutoTokenizer.from_pretrained("ai-forever/sbert_large_nlu_ru")
emb_model = AutoModel.from_pretrained("ai-forever/sbert_large_nlu_ru")
model = BertForTextSegmentation.from_pretrained("mlenjoyneer/RuTextSegModel")

emb_model.to(device)
model.to(device)

text = """В Норильске за годы работы телефона доверия консультанты приняли в общей сложности порядка 75 тысяч обращений, сообщает «Заполярная Правда». Служба психологической помощи появилась в 2000 году. Руководитель службы профилактики наркомании Елена Слатвицкая рассказала журналистам, что в Заполярье настал период, когда ухудшается психо– эмоциональное состояние населения. Это происходит на входе в полярную ночь и на выходе из нее. Осень является кризисным моментом. Сейчас на телефоне доверия работают 15 специалистов. Каждый — под своим псевдонимом. Тему беседы определяет звонящий. Это могут быть наркомания и алкоголизм, ВИЧ–инфекция и прочие заболевания и зависимости, кризисы семейных отношений и многое другое. Сотрудники службы отметили, что больше стало звонков по поводу суицидальных намерений. Наибольшее количество обращений по суицидам пришлось на октябрь — ноябрь. Много звонков как от мужчин, так и от женщин с вопросами об одиночестве. Лидерами по количеству обращений пока остаются женщины. В сентябре в Норильске обнаружили тело девятиклассницы. По версии следствия, девочка сбросилась с крыши. В январе подросток нанес себе порезы стеклом от разбитой бутылки, пытаясь покончить с собой. Мальчик поссорился с матерью и в ходе ссоры нанес себе несколько порезов. Проводится расследование."""

input_embeds = create_embeddings([s.text for s in sentenize(text)], emb_tokenizer, emb_model).unsqueeze(0)
outputs = model(inputs_embeds=input_embeds)

logits = outputs.logits.cpu()
preds = logits.argmax(axis=2).tolist()[0]  # [0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0]

# true_labels = [0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0]

Training data

Model trained on mlenjoyneer/RuTextSegNews and mlenjoyneer/RuTextSegWiki datasets.

Evaluation results

Train Dataset Test Dataset F1_total F1_1 Pk Pk_5 WinDiff WinDiff_5
News+Wiki News 0.88 0.80 0.16 0.11 0.20 0.35
News+Wiki Wiki 0.89 0.80 0.18 0.16 0.09 0.19

Citation info

In progress
Downloads last month
6
Inference API
Unable to determine this model’s pipeline type. Check the docs .

Datasets used to train mlenjoyneer/RuTextSegModel