File size: 8,923 Bytes
946dc24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
# -*- coding: utf-8 -*-

import argparse
import torch
import torch.nn as nn
import numpy as np

from transformers import AutoTokenizer, AutoModelForTokenClassification
import re
import string
from typing import List, Optional


TOKEN_RE = re.compile(r'-?\d*\.\d+|[a-zа-яё]+|-?[\d\+\(\)\-]+|\S', re.I)
"""
Регулярка, для того чтобы выделять в отдельные токены знаки препинания, числа и слова. А именно:
- Числа с плавающей точкой вида 123.23 выделяются в один токен. Десятичным разделителем рассматривается только точка
- Число может быть отрицательным: иметь знак -123.4
- Целой части числа может вовсе не быть: последовательности  -0.15 и −.15   означают одно и то же число.
- При этом числа с нулевой дробной частью не допускаются:  строка "12345." будет разделена на два токена "12345" и "."
- Идущие подряд знаки препинания выделяются каждый в отдельный токен.
- Телефонные номера выделяются в один токен +7(999)164-20-69
- Множество букв в словах ограничивается только кириллическим и англ алфавитом (33 буквы и 26 cоотв).
"""

# Прогнозируемые знаки препинания
PUNK_MAPPING = {'.': 'PERIOD', ',': 'COMMA', '?': 'QUESTION'}

# Прогнозируемый регистр LOWER - нижний регистр, UPPER - верхний регистр для первого символа, UPPER_TOTAL - верхний регистр для всех символов
LABELS_CASE = ['LOWER', 'UPPER', 'UPPER_TOTAL']
# Добавим в пунктуацию метку O означающий отсутсвие пунктуации
LABELS_PUNC = ['O'] + list(PUNK_MAPPING.values())

# Сформируем метки на основе комбинаций регистра и пунктуации
LABELS_list = []
for case in LABELS_CASE:
    for punc in LABELS_PUNC:
        LABELS_list.append(f'{case}_{punc}')
LABELS = {label: i+1 for i, label in enumerate(LABELS_list)}
LABELS['O'] = -100
INVERSE_LABELS = {i: label for label, i in LABELS.items()}

LABEL_TO_PUNC_LABEL = {label: label.split('_')[-1] for label in LABELS.keys() if label != 'O'}
LABEL_TO_CASE_LABEL = {label: '_'.join(label.split('_')[:-1]) for label in LABELS.keys() if label != 'O'}


def token_to_label(token, label):
    if type(label) == int:
        label = INVERSE_LABELS[label]
    if label == 'LOWER_O':
        return token
    if label == 'LOWER_PERIOD':
        return token + '.'
    if label == 'LOWER_COMMA':
        return token + ','
    if label == 'LOWER_QUESTION':
        return token + '?'
    if label == 'UPPER_O':
        return token.capitalize()
    if label == 'UPPER_PERIOD':
        return token.capitalize() + '.'
    if label == 'UPPER_COMMA':
        return token.capitalize() + ','
    if label == 'UPPER_QUESTION':
        return token.capitalize() + '?'
    if label == 'UPPER_TOTAL_O':
        return token.upper()
    if label == 'UPPER_TOTAL_PERIOD':
        return token.upper() + '.'
    if label == 'UPPER_TOTAL_COMMA':
        return token.upper() + ','
    if label == 'UPPER_TOTAL_QUESTION':
        return token.upper() + '?'
    if label == 'O':
        return token


def decode_label(label, classes='all'):
    if classes == 'punc':
        return LABEL_TO_PUNC_LABEL[INVERSE_LABELS[label]]
    if classes == 'case':
        return LABEL_TO_CASE_LABEL[INVERSE_LABELS[label]]
    else:
        return INVERSE_LABELS[label]


def make_labeling(text: str):
    # Разобъем предложение на слова и знаки препинания
    tokens = TOKEN_RE.findall(text)
    # Предобработаем слова, удалим знаки препинания и зададим метки

    preprocessed_tokens = []
    token_labels: List[List[str]] = []

    # Убираем всю пунктуацию в начале предложения
    while tokens[0] in string.punctuation:
        tokens.pop(0)

    for token in tokens:
        if token in string.punctuation:
            # Если встретился знак препинания который мы прогнозируем изменим метку предыдущего слова, иначе проигнорируем его
            if token in PUNK_MAPPING:
                token_labels[-1][1] = PUNK_MAPPING[token]
        else:
            # Если встретилось слово, то укажем метку регистра и добавим в список предобработанных слов в нижнем регистре
            if sum(char.isupper() for char in token) > 1:
                token_labels.append(['UPPER_TOTAL', 'O'])
            elif token[0].isupper():
                token_labels.append(['UPPER', 'O'])
            else:
                token_labels.append(['LOWER', 'O'])
            preprocessed_tokens.append(token.lower())
    token_labels_merged = ['_'.join(label) for label in token_labels]
    token_labels_ids = [LABELS[label] for label in token_labels_merged]
    return dict(words=preprocessed_tokens, labels=token_labels_merged, label_ids=token_labels_ids)


def align_labels(label_ids: list[int], word_ids: list[Optional[int]]):
    aligned_label_ids = []
    previous_id = None
    for word_id in word_ids:
        if word_id is None or word_id == previous_id:
            aligned_label_ids.append(LABELS['O'])
        else:
            aligned_label_ids.append(label_ids.pop(0))
        previous_id = word_id
    return aligned_label_ids


MODEL_REPO = "kontur-ai/sbert-punc-case-ru"


class SbertPuncCase(nn.Module):
    def __init__(self):
        super().__init__()

        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO,
                                                       revision="sbert",
                                                       use_auth_token=True,
                                                       strip_accents=False)
        self.model = AutoModelForTokenClassification.from_pretrained(MODEL_REPO,
                                                                     revision="sbert",
                                                                     use_auth_token=True
                                                                     )
        self.model.eval()

    def forward(self, input_ids, attention_mask):
        return self.model(input_ids=input_ids,
                          attention_mask=attention_mask)

    def punctuate(self, text):
        text = text.strip().lower()

        # preprocess
        words_with_labels = make_labeling(text)
        words = words_with_labels['words']
        label_ids = words_with_labels['label_ids']

        tokenizer_output = self.tokenizer(words, is_split_into_words=True)
        aligned_label_ids = [align_labels(label_ids, tokenizer_output.word_ids())]

        result = dict(tokenizer_output)
        result.update({'labels': aligned_label_ids})

        if len(result['input_ids']) > 512:
            return ' '.join([self.punctuate(' '.join(text_part)) for text_part in np.array_split(words, 2)])

        predictions = self(torch.tensor([result['input_ids']], device=self.model.device),
                           torch.tensor([result['attention_mask']], device=self.model.device)).logits.cpu().data.numpy()
        predictions = np.argmax(predictions, axis=2)

        # decode punctuation and casing
        splitted_text = []
        word_ids = tokenizer_output.word_ids()
        for i, word in enumerate(words):
            label_pos = word_ids.index(i)
            label_id = predictions[0][label_pos]
            label = decode_label(label_id)
            splitted_text.append(token_to_label(word, label))
        capitalized_text = ' '.join(splitted_text)
        return capitalized_text


if __name__ == '__main__':
    parser = argparse.ArgumentParser("Punctuation and case restoration model sbert-punc-case-ru")
    parser.add_argument("-i", "--input", type=str, help="text to restore", default='SbertPuncCase расставляет точки запятые и знаки вопроса вам нравится')
    parser.add_argument("-d", "--device", type=str, help="run model on cpu or gpu", choices=['cpu', 'cuda'], default='cpu')
    args = parser.parse_args()
    print(f"Source text:   {args.input}\n")
    sbertpunc = SbertPuncCase().to(args.device)
    punctuated_text = sbertpunc.punctuate(args.input)
    print(f"Restored text: {punctuated_text}")