Almira commited on
Commit
946dc24
1 Parent(s): 9f57b59

Add wrapper for punctuation

Browse files
sbert-punc-case-ru/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .sbertpunccase import SbertPuncCase
sbert-punc-case-ru/sbertpunccase.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import argparse
4
+ import torch
5
+ import torch.nn as nn
6
+ import numpy as np
7
+
8
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
9
+ import re
10
+ import string
11
+ from typing import List, Optional
12
+
13
+
14
+ TOKEN_RE = re.compile(r'-?\d*\.\d+|[a-zа-яё]+|-?[\d\+\(\)\-]+|\S', re.I)
15
+ """
16
+ Регулярка, для того чтобы выделять в отдельные токены знаки препинания, числа и слова. А именно:
17
+ - Числа с плавающей точкой вида 123.23 выделяются в один токен. Десятичным разделителем рассматривается только точка
18
+ - Число может быть отрицательным: иметь знак -123.4
19
+ - Целой части числа может вовсе не быть: последовательности -0.15 и −.15 означают одно и то же число.
20
+ - При этом числа с нулевой дробной частью не допускаются: строка "12345." будет разделена на два токена "12345" и "."
21
+ - Идущие подряд знаки препинания выделяются каждый в отдельный токен.
22
+ - Телефонные номера выделяются в один токен +7(999)164-20-69
23
+ - Множество букв в словах ограничивается только кириллическим и англ алфавитом (33 буквы и 26 cоотв).
24
+ """
25
+
26
+ # Прогнозируемые знаки препинания
27
+ PUNK_MAPPING = {'.': 'PERIOD', ',': 'COMMA', '?': 'QUESTION'}
28
+
29
+ # Прогнозируемый регистр LOWER - нижний регистр, UPPER - верхний регистр для первого символа, UPPER_TOTAL - верхний регистр для всех символов
30
+ LABELS_CASE = ['LOWER', 'UPPER', 'UPPER_TOTAL']
31
+ # Добавим в пунктуацию метку O означающий отсутсвие пунктуации
32
+ LABELS_PUNC = ['O'] + list(PUNK_MAPPING.values())
33
+
34
+ # Сформируем метки на основе комбинаций регистра и пунктуации
35
+ LABELS_list = []
36
+ for case in LABELS_CASE:
37
+ for punc in LABELS_PUNC:
38
+ LABELS_list.append(f'{case}_{punc}')
39
+ LABELS = {label: i+1 for i, label in enumerate(LABELS_list)}
40
+ LABELS['O'] = -100
41
+ INVERSE_LABELS = {i: label for label, i in LABELS.items()}
42
+
43
+ LABEL_TO_PUNC_LABEL = {label: label.split('_')[-1] for label in LABELS.keys() if label != 'O'}
44
+ LABEL_TO_CASE_LABEL = {label: '_'.join(label.split('_')[:-1]) for label in LABELS.keys() if label != 'O'}
45
+
46
+
47
+ def token_to_label(token, label):
48
+ if type(label) == int:
49
+ label = INVERSE_LABELS[label]
50
+ if label == 'LOWER_O':
51
+ return token
52
+ if label == 'LOWER_PERIOD':
53
+ return token + '.'
54
+ if label == 'LOWER_COMMA':
55
+ return token + ','
56
+ if label == 'LOWER_QUESTION':
57
+ return token + '?'
58
+ if label == 'UPPER_O':
59
+ return token.capitalize()
60
+ if label == 'UPPER_PERIOD':
61
+ return token.capitalize() + '.'
62
+ if label == 'UPPER_COMMA':
63
+ return token.capitalize() + ','
64
+ if label == 'UPPER_QUESTION':
65
+ return token.capitalize() + '?'
66
+ if label == 'UPPER_TOTAL_O':
67
+ return token.upper()
68
+ if label == 'UPPER_TOTAL_PERIOD':
69
+ return token.upper() + '.'
70
+ if label == 'UPPER_TOTAL_COMMA':
71
+ return token.upper() + ','
72
+ if label == 'UPPER_TOTAL_QUESTION':
73
+ return token.upper() + '?'
74
+ if label == 'O':
75
+ return token
76
+
77
+
78
+ def decode_label(label, classes='all'):
79
+ if classes == 'punc':
80
+ return LABEL_TO_PUNC_LABEL[INVERSE_LABELS[label]]
81
+ if classes == 'case':
82
+ return LABEL_TO_CASE_LABEL[INVERSE_LABELS[label]]
83
+ else:
84
+ return INVERSE_LABELS[label]
85
+
86
+
87
+ def make_labeling(text: str):
88
+ # Разобъем предложение на слова и знаки препинания
89
+ tokens = TOKEN_RE.findall(text)
90
+ # Предобработаем слова, удалим знаки препинания и зададим метки
91
+
92
+ preprocessed_tokens = []
93
+ token_labels: List[List[str]] = []
94
+
95
+ # Убираем всю пунктуацию в начале предложения
96
+ while tokens[0] in string.punctuation:
97
+ tokens.pop(0)
98
+
99
+ for token in tokens:
100
+ if token in string.punctuation:
101
+ # Если встретился знак препинания который мы прогнозируем изменим метку предыдущего слова, иначе проигнорируем его
102
+ if token in PUNK_MAPPING:
103
+ token_labels[-1][1] = PUNK_MAPPING[token]
104
+ else:
105
+ # Если встретилось слово, то укажем метку регистра и добавим в список предобработанных слов в ни��нем регистре
106
+ if sum(char.isupper() for char in token) > 1:
107
+ token_labels.append(['UPPER_TOTAL', 'O'])
108
+ elif token[0].isupper():
109
+ token_labels.append(['UPPER', 'O'])
110
+ else:
111
+ token_labels.append(['LOWER', 'O'])
112
+ preprocessed_tokens.append(token.lower())
113
+ token_labels_merged = ['_'.join(label) for label in token_labels]
114
+ token_labels_ids = [LABELS[label] for label in token_labels_merged]
115
+ return dict(words=preprocessed_tokens, labels=token_labels_merged, label_ids=token_labels_ids)
116
+
117
+
118
+ def align_labels(label_ids: list[int], word_ids: list[Optional[int]]):
119
+ aligned_label_ids = []
120
+ previous_id = None
121
+ for word_id in word_ids:
122
+ if word_id is None or word_id == previous_id:
123
+ aligned_label_ids.append(LABELS['O'])
124
+ else:
125
+ aligned_label_ids.append(label_ids.pop(0))
126
+ previous_id = word_id
127
+ return aligned_label_ids
128
+
129
+
130
+ MODEL_REPO = "kontur-ai/sbert-punc-case-ru"
131
+
132
+
133
+ class SbertPuncCase(nn.Module):
134
+ def __init__(self):
135
+ super().__init__()
136
+
137
+ self.tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO,
138
+ revision="sbert",
139
+ use_auth_token=True,
140
+ strip_accents=False)
141
+ self.model = AutoModelForTokenClassification.from_pretrained(MODEL_REPO,
142
+ revision="sbert",
143
+ use_auth_token=True
144
+ )
145
+ self.model.eval()
146
+
147
+ def forward(self, input_ids, attention_mask):
148
+ return self.model(input_ids=input_ids,
149
+ attention_mask=attention_mask)
150
+
151
+ def punctuate(self, text):
152
+ text = text.strip().lower()
153
+
154
+ # preprocess
155
+ words_with_labels = make_labeling(text)
156
+ words = words_with_labels['words']
157
+ label_ids = words_with_labels['label_ids']
158
+
159
+ tokenizer_output = self.tokenizer(words, is_split_into_words=True)
160
+ aligned_label_ids = [align_labels(label_ids, tokenizer_output.word_ids())]
161
+
162
+ result = dict(tokenizer_output)
163
+ result.update({'labels': aligned_label_ids})
164
+
165
+ if len(result['input_ids']) > 512:
166
+ return ' '.join([self.punctuate(' '.join(text_part)) for text_part in np.array_split(words, 2)])
167
+
168
+ predictions = self(torch.tensor([result['input_ids']], device=self.model.device),
169
+ torch.tensor([result['attention_mask']], device=self.model.device)).logits.cpu().data.numpy()
170
+ predictions = np.argmax(predictions, axis=2)
171
+
172
+ # decode punctuation and casing
173
+ splitted_text = []
174
+ word_ids = tokenizer_output.word_ids()
175
+ for i, word in enumerate(words):
176
+ label_pos = word_ids.index(i)
177
+ label_id = predictions[0][label_pos]
178
+ label = decode_label(label_id)
179
+ splitted_text.append(token_to_label(word, label))
180
+ capitalized_text = ' '.join(splitted_text)
181
+ return capitalized_text
182
+
183
+
184
+ if __name__ == '__main__':
185
+ parser = argparse.ArgumentParser("Punctuation and case restoration model sbert-punc-case-ru")
186
+ parser.add_argument("-i", "--input", type=str, help="text to restore", default='SbertPuncCase расставляет точки запятые и знаки вопроса вам нравится')
187
+ parser.add_argument("-d", "--device", type=str, help="run model on cpu or gpu", choices=['cpu', 'cuda'], default='cpu')
188
+ args = parser.parse_args()
189
+ print(f"Source text: {args.input}\n")
190
+ sbertpunc = SbertPuncCase().to(args.device)
191
+ punctuated_text = sbertpunc.punctuate(args.input)
192
+ print(f"Restored text: {punctuated_text}")
setup.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from distutils.core import setup
2
+
3
+ setup(name='sbert-punc-case-ru',
4
+ version='0.1',
5
+ description='Punctuation and Case Restoration model based on https://huggingface.co/sberbank-ai/sbert_large_nlu_ru',
6
+ author='Almira Murtazina',
7
+ author_email='ar.murtazina@skbkontur.ru',
8
+ packages=['sbert-punc-case-ru'],
9
+ install_requires=['transformers>=4.18.3'],
10
+ classifiers=[
11
+ "Operating System :: OS Independent",
12
+ "Programming Language :: Python :: 3",
13
+ "Programming Language :: Python :: 3.6",
14
+ "Programming Language :: Python :: 3.7",
15
+ "Programming Language :: Python :: 3.8",
16
+ "Programming Language :: Python :: 3.9",
17
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
18
+ ]
19
+ )