# -*- coding: utf-8 -*- import argparse import torch import torch.nn as nn import numpy as np from transformers import AutoTokenizer, AutoModelForTokenClassification import re import string from typing import List, Optional TOKEN_RE = re.compile(r'-?\d*\.\d+|[a-zа-яё]+|-?[\d\+\(\)\-]+|\S', re.I) """ Регулярка, для того чтобы выделять в отдельные токены знаки препинания, числа и слова. А именно: - Числа с плавающей точкой вида 123.23 выделяются в один токен. Десятичным разделителем рассматривается только точка - Число может быть отрицательным: иметь знак -123.4 - Целой части числа может вовсе не быть: последовательности -0.15 и −.15 означают одно и то же число. - При этом числа с нулевой дробной частью не допускаются: строка "12345." будет разделена на два токена "12345" и "." - Идущие подряд знаки препинания выделяются каждый в отдельный токен. - Телефонные номера выделяются в один токен +7(999)164-20-69 - Множество букв в словах ограничивается только кириллическим и англ алфавитом (33 буквы и 26 cоотв). """ # Прогнозируемые знаки препинания PUNK_MAPPING = {'.': 'PERIOD', ',': 'COMMA', '?': 'QUESTION'} # Прогнозируемый регистр LOWER - нижний регистр, UPPER - верхний регистр для первого символа, UPPER_TOTAL - верхний регистр для всех символов LABELS_CASE = ['LOWER', 'UPPER', 'UPPER_TOTAL'] # Добавим в пунктуацию метку O означающий отсутсвие пунктуации LABELS_PUNC = ['O'] + list(PUNK_MAPPING.values()) # Сформируем метки на основе комбинаций регистра и пунктуации LABELS_list = [] for case in LABELS_CASE: for punc in LABELS_PUNC: LABELS_list.append(f'{case}_{punc}') LABELS = {label: i+1 for i, label in enumerate(LABELS_list)} LABELS['O'] = -100 INVERSE_LABELS = {i: label for label, i in LABELS.items()} LABEL_TO_PUNC_LABEL = {label: label.split('_')[-1] for label in LABELS.keys() if label != 'O'} LABEL_TO_CASE_LABEL = {label: '_'.join(label.split('_')[:-1]) for label in LABELS.keys() if label != 'O'} def token_to_label(token, label): if type(label) == int: label = INVERSE_LABELS[label] if label == 'LOWER_O': return token if label == 'LOWER_PERIOD': return token + '.' if label == 'LOWER_COMMA': return token + ',' if label == 'LOWER_QUESTION': return token + '?' if label == 'UPPER_O': return token.capitalize() if label == 'UPPER_PERIOD': return token.capitalize() + '.' if label == 'UPPER_COMMA': return token.capitalize() + ',' if label == 'UPPER_QUESTION': return token.capitalize() + '?' if label == 'UPPER_TOTAL_O': return token.upper() if label == 'UPPER_TOTAL_PERIOD': return token.upper() + '.' if label == 'UPPER_TOTAL_COMMA': return token.upper() + ',' if label == 'UPPER_TOTAL_QUESTION': return token.upper() + '?' if label == 'O': return token def decode_label(label, classes='all'): if classes == 'punc': return LABEL_TO_PUNC_LABEL[INVERSE_LABELS[label]] if classes == 'case': return LABEL_TO_CASE_LABEL[INVERSE_LABELS[label]] else: return INVERSE_LABELS[label] def make_labeling(text: str): # Разобъем предложение на слова и знаки препинания tokens = TOKEN_RE.findall(text) # Предобработаем слова, удалим знаки препинания и зададим метки preprocessed_tokens = [] token_labels: List[List[str]] = [] # Убираем всю пунктуацию в начале предложения while tokens[0] in string.punctuation: tokens.pop(0) for token in tokens: if token in string.punctuation: # Если встретился знак препинания который мы прогнозируем изменим метку предыдущего слова, иначе проигнорируем его if token in PUNK_MAPPING: token_labels[-1][1] = PUNK_MAPPING[token] else: # Если встретилось слово, то укажем метку регистра и добавим в список предобработанных слов в нижнем регистре if sum(char.isupper() for char in token) > 1: token_labels.append(['UPPER_TOTAL', 'O']) elif token[0].isupper(): token_labels.append(['UPPER', 'O']) else: token_labels.append(['LOWER', 'O']) preprocessed_tokens.append(token.lower()) token_labels_merged = ['_'.join(label) for label in token_labels] token_labels_ids = [LABELS[label] for label in token_labels_merged] return dict(words=preprocessed_tokens, labels=token_labels_merged, label_ids=token_labels_ids) def align_labels(label_ids: list[int], word_ids: list[Optional[int]]): aligned_label_ids = [] previous_id = None for word_id in word_ids: if word_id is None or word_id == previous_id: aligned_label_ids.append(LABELS['O']) else: aligned_label_ids.append(label_ids.pop(0)) previous_id = word_id return aligned_label_ids MODEL_REPO = "kontur-ai/sbert-punc-case-ru" class SbertPuncCase(nn.Module): def __init__(self): super().__init__() self.tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO, revision="sbert", use_auth_token=True, strip_accents=False) self.model = AutoModelForTokenClassification.from_pretrained(MODEL_REPO, revision="sbert", use_auth_token=True ) self.model.eval() def forward(self, input_ids, attention_mask): return self.model(input_ids=input_ids, attention_mask=attention_mask) def punctuate(self, text): text = text.strip().lower() # preprocess words_with_labels = make_labeling(text) words = words_with_labels['words'] label_ids = words_with_labels['label_ids'] tokenizer_output = self.tokenizer(words, is_split_into_words=True) aligned_label_ids = [align_labels(label_ids, tokenizer_output.word_ids())] result = dict(tokenizer_output) result.update({'labels': aligned_label_ids}) if len(result['input_ids']) > 512: return ' '.join([self.punctuate(' '.join(text_part)) for text_part in np.array_split(words, 2)]) predictions = self(torch.tensor([result['input_ids']], device=self.model.device), torch.tensor([result['attention_mask']], device=self.model.device)).logits.cpu().data.numpy() predictions = np.argmax(predictions, axis=2) # decode punctuation and casing splitted_text = [] word_ids = tokenizer_output.word_ids() for i, word in enumerate(words): label_pos = word_ids.index(i) label_id = predictions[0][label_pos] label = decode_label(label_id) splitted_text.append(token_to_label(word, label)) capitalized_text = ' '.join(splitted_text) return capitalized_text if __name__ == '__main__': parser = argparse.ArgumentParser("Punctuation and case restoration model sbert-punc-case-ru") parser.add_argument("-i", "--input", type=str, help="text to restore", default='SbertPuncCase расставляет точки запятые и знаки вопроса вам нравится') parser.add_argument("-d", "--device", type=str, help="run model on cpu or gpu", choices=['cpu', 'cuda'], default='cpu') args = parser.parse_args() print(f"Source text: {args.input}\n") sbertpunc = SbertPuncCase().to(args.device) punctuated_text = sbertpunc.punctuate(args.input) print(f"Restored text: {punctuated_text}")