Almira's picture
Add wrapper for punctuation
946dc24
raw
history blame
8.92 kB
# -*- coding: utf-8 -*-
import argparse
import torch
import torch.nn as nn
import numpy as np
from transformers import AutoTokenizer, AutoModelForTokenClassification
import re
import string
from typing import List, Optional
TOKEN_RE = re.compile(r'-?\d*\.\d+|[a-zа-яё]+|-?[\d\+\(\)\-]+|\S', re.I)
"""
Регулярка, для того чтобы выделять в отдельные токены знаки препинания, числа и слова. А именно:
- Числа с плавающей точкой вида 123.23 выделяются в один токен. Десятичным разделителем рассматривается только точка
- Число может быть отрицательным: иметь знак -123.4
- Целой части числа может вовсе не быть: последовательности -0.15 и −.15 означают одно и то же число.
- При этом числа с нулевой дробной частью не допускаются: строка "12345." будет разделена на два токена "12345" и "."
- Идущие подряд знаки препинания выделяются каждый в отдельный токен.
- Телефонные номера выделяются в один токен +7(999)164-20-69
- Множество букв в словах ограничивается только кириллическим и англ алфавитом (33 буквы и 26 cоотв).
"""
# Прогнозируемые знаки препинания
PUNK_MAPPING = {'.': 'PERIOD', ',': 'COMMA', '?': 'QUESTION'}
# Прогнозируемый регистр LOWER - нижний регистр, UPPER - верхний регистр для первого символа, UPPER_TOTAL - верхний регистр для всех символов
LABELS_CASE = ['LOWER', 'UPPER', 'UPPER_TOTAL']
# Добавим в пунктуацию метку O означающий отсутсвие пунктуации
LABELS_PUNC = ['O'] + list(PUNK_MAPPING.values())
# Сформируем метки на основе комбинаций регистра и пунктуации
LABELS_list = []
for case in LABELS_CASE:
for punc in LABELS_PUNC:
LABELS_list.append(f'{case}_{punc}')
LABELS = {label: i+1 for i, label in enumerate(LABELS_list)}
LABELS['O'] = -100
INVERSE_LABELS = {i: label for label, i in LABELS.items()}
LABEL_TO_PUNC_LABEL = {label: label.split('_')[-1] for label in LABELS.keys() if label != 'O'}
LABEL_TO_CASE_LABEL = {label: '_'.join(label.split('_')[:-1]) for label in LABELS.keys() if label != 'O'}
def token_to_label(token, label):
if type(label) == int:
label = INVERSE_LABELS[label]
if label == 'LOWER_O':
return token
if label == 'LOWER_PERIOD':
return token + '.'
if label == 'LOWER_COMMA':
return token + ','
if label == 'LOWER_QUESTION':
return token + '?'
if label == 'UPPER_O':
return token.capitalize()
if label == 'UPPER_PERIOD':
return token.capitalize() + '.'
if label == 'UPPER_COMMA':
return token.capitalize() + ','
if label == 'UPPER_QUESTION':
return token.capitalize() + '?'
if label == 'UPPER_TOTAL_O':
return token.upper()
if label == 'UPPER_TOTAL_PERIOD':
return token.upper() + '.'
if label == 'UPPER_TOTAL_COMMA':
return token.upper() + ','
if label == 'UPPER_TOTAL_QUESTION':
return token.upper() + '?'
if label == 'O':
return token
def decode_label(label, classes='all'):
if classes == 'punc':
return LABEL_TO_PUNC_LABEL[INVERSE_LABELS[label]]
if classes == 'case':
return LABEL_TO_CASE_LABEL[INVERSE_LABELS[label]]
else:
return INVERSE_LABELS[label]
def make_labeling(text: str):
# Разобъем предложение на слова и знаки препинания
tokens = TOKEN_RE.findall(text)
# Предобработаем слова, удалим знаки препинания и зададим метки
preprocessed_tokens = []
token_labels: List[List[str]] = []
# Убираем всю пунктуацию в начале предложения
while tokens[0] in string.punctuation:
tokens.pop(0)
for token in tokens:
if token in string.punctuation:
# Если встретился знак препинания который мы прогнозируем изменим метку предыдущего слова, иначе проигнорируем его
if token in PUNK_MAPPING:
token_labels[-1][1] = PUNK_MAPPING[token]
else:
# Если встретилось слово, то укажем метку регистра и добавим в список предобработанных слов в нижнем регистре
if sum(char.isupper() for char in token) > 1:
token_labels.append(['UPPER_TOTAL', 'O'])
elif token[0].isupper():
token_labels.append(['UPPER', 'O'])
else:
token_labels.append(['LOWER', 'O'])
preprocessed_tokens.append(token.lower())
token_labels_merged = ['_'.join(label) for label in token_labels]
token_labels_ids = [LABELS[label] for label in token_labels_merged]
return dict(words=preprocessed_tokens, labels=token_labels_merged, label_ids=token_labels_ids)
def align_labels(label_ids: list[int], word_ids: list[Optional[int]]):
aligned_label_ids = []
previous_id = None
for word_id in word_ids:
if word_id is None or word_id == previous_id:
aligned_label_ids.append(LABELS['O'])
else:
aligned_label_ids.append(label_ids.pop(0))
previous_id = word_id
return aligned_label_ids
MODEL_REPO = "kontur-ai/sbert-punc-case-ru"
class SbertPuncCase(nn.Module):
def __init__(self):
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO,
revision="sbert",
use_auth_token=True,
strip_accents=False)
self.model = AutoModelForTokenClassification.from_pretrained(MODEL_REPO,
revision="sbert",
use_auth_token=True
)
self.model.eval()
def forward(self, input_ids, attention_mask):
return self.model(input_ids=input_ids,
attention_mask=attention_mask)
def punctuate(self, text):
text = text.strip().lower()
# preprocess
words_with_labels = make_labeling(text)
words = words_with_labels['words']
label_ids = words_with_labels['label_ids']
tokenizer_output = self.tokenizer(words, is_split_into_words=True)
aligned_label_ids = [align_labels(label_ids, tokenizer_output.word_ids())]
result = dict(tokenizer_output)
result.update({'labels': aligned_label_ids})
if len(result['input_ids']) > 512:
return ' '.join([self.punctuate(' '.join(text_part)) for text_part in np.array_split(words, 2)])
predictions = self(torch.tensor([result['input_ids']], device=self.model.device),
torch.tensor([result['attention_mask']], device=self.model.device)).logits.cpu().data.numpy()
predictions = np.argmax(predictions, axis=2)
# decode punctuation and casing
splitted_text = []
word_ids = tokenizer_output.word_ids()
for i, word in enumerate(words):
label_pos = word_ids.index(i)
label_id = predictions[0][label_pos]
label = decode_label(label_id)
splitted_text.append(token_to_label(word, label))
capitalized_text = ' '.join(splitted_text)
return capitalized_text
if __name__ == '__main__':
parser = argparse.ArgumentParser("Punctuation and case restoration model sbert-punc-case-ru")
parser.add_argument("-i", "--input", type=str, help="text to restore", default='SbertPuncCase расставляет точки запятые и знаки вопроса вам нравится')
parser.add_argument("-d", "--device", type=str, help="run model on cpu or gpu", choices=['cpu', 'cuda'], default='cpu')
args = parser.parse_args()
print(f"Source text: {args.input}\n")
sbertpunc = SbertPuncCase().to(args.device)
punctuated_text = sbertpunc.punctuate(args.input)
print(f"Restored text: {punctuated_text}")