Almira's picture
Split words by spaces rather than regexp
a45aa62
raw
history blame
No virus
5.23 kB
# -*- coding: utf-8 -*-
import argparse
import torch
import torch.nn as nn
import numpy as np
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Прогнозируемые знаки препинания
PUNK_MAPPING = {'.': 'PERIOD', ',': 'COMMA', '?': 'QUESTION'}
# Прогнозируемый регистр LOWER - нижний регистр, UPPER - верхний регистр для первого символа,
# UPPER_TOTAL - верхний регистр для всех символов
LABELS_CASE = ['LOWER', 'UPPER', 'UPPER_TOTAL']
# Добавим в пунктуацию метку O означающий отсутсвие пунктуации
LABELS_PUNC = ['O'] + list(PUNK_MAPPING.values())
# Сформируем метки на основе комбинаций регистра и пунктуации
LABELS_list = []
for case in LABELS_CASE:
for punc in LABELS_PUNC:
LABELS_list.append(f'{case}_{punc}')
LABELS = {label: i+1 for i, label in enumerate(LABELS_list)}
LABELS['O'] = -100
INVERSE_LABELS = {i: label for label, i in LABELS.items()}
LABEL_TO_PUNC_LABEL = {label: label.split('_')[-1] for label in LABELS.keys() if label != 'O'}
LABEL_TO_CASE_LABEL = {label: '_'.join(label.split('_')[:-1]) for label in LABELS.keys() if label != 'O'}
def token_to_label(token, label):
if type(label) == int:
label = INVERSE_LABELS[label]
if label == 'LOWER_O':
return token
if label == 'LOWER_PERIOD':
return token + '.'
if label == 'LOWER_COMMA':
return token + ','
if label == 'LOWER_QUESTION':
return token + '?'
if label == 'UPPER_O':
return token.capitalize()
if label == 'UPPER_PERIOD':
return token.capitalize() + '.'
if label == 'UPPER_COMMA':
return token.capitalize() + ','
if label == 'UPPER_QUESTION':
return token.capitalize() + '?'
if label == 'UPPER_TOTAL_O':
return token.upper()
if label == 'UPPER_TOTAL_PERIOD':
return token.upper() + '.'
if label == 'UPPER_TOTAL_COMMA':
return token.upper() + ','
if label == 'UPPER_TOTAL_QUESTION':
return token.upper() + '?'
if label == 'O':
return token
def decode_label(label, classes='all'):
if classes == 'punc':
return LABEL_TO_PUNC_LABEL[INVERSE_LABELS[label]]
if classes == 'case':
return LABEL_TO_CASE_LABEL[INVERSE_LABELS[label]]
else:
return INVERSE_LABELS[label]
MODEL_REPO = "kontur-ai/sbert-punc-case-ru"
class SbertPuncCase(nn.Module):
def __init__(self):
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO,
revision="sbert",
use_auth_token=True,
strip_accents=False)
self.model = AutoModelForTokenClassification.from_pretrained(MODEL_REPO,
revision="sbert",
use_auth_token=True
)
self.model.eval()
def forward(self, input_ids, attention_mask):
return self.model(input_ids=input_ids,
attention_mask=attention_mask)
def punctuate(self, text):
text = text.strip().lower()
# Разобъем предложение на слова
words = text.split()
tokenizer_output = self.tokenizer(words, is_split_into_words=True)
if len(tokenizer_output.input_ids) > 512:
return ' '.join([self.punctuate(' '.join(text_part)) for text_part in np.array_split(words, 2)])
predictions = self(torch.tensor([tokenizer_output.input_ids], device=self.model.device),
torch.tensor([tokenizer_output.attention_mask], device=self.model.device)).logits.cpu().data.numpy()
predictions = np.argmax(predictions, axis=2)
# decode punctuation and casing
splitted_text = []
word_ids = tokenizer_output.word_ids()
for i, word in enumerate(words):
label_pos = word_ids.index(i)
label_id = predictions[0][label_pos]
label = decode_label(label_id)
splitted_text.append(token_to_label(word, label))
capitalized_text = ' '.join(splitted_text)
return capitalized_text
if __name__ == '__main__':
parser = argparse.ArgumentParser("Punctuation and case restoration model sbert-punc-case-ru")
parser.add_argument("-i", "--input", type=str, help="text to restore", default='sbert punc case расставляет точки запятые и знаки вопроса вам нравится')
parser.add_argument("-d", "--device", type=str, help="run model on cpu or gpu", choices=['cpu', 'cuda'], default='cpu')
args = parser.parse_args()
print(f"Source text: {args.input}\n")
sbertpunc = SbertPuncCase().to(args.device)
punctuated_text = sbertpunc.punctuate(args.input)
print(f"Restored text: {punctuated_text}")