# -*- coding: utf-8 -*- import argparse import torch import torch.nn as nn import numpy as np from transformers import AutoTokenizer, AutoModelForTokenClassification # Прогнозируемые знаки препинания PUNK_MAPPING = {'.': 'PERIOD', ',': 'COMMA', '?': 'QUESTION'} # Прогнозируемый регистр LOWER - нижний регистр, UPPER - верхний регистр для первого символа, # UPPER_TOTAL - верхний регистр для всех символов LABELS_CASE = ['LOWER', 'UPPER', 'UPPER_TOTAL'] # Добавим в пунктуацию метку O означающий отсутсвие пунктуации LABELS_PUNC = ['O'] + list(PUNK_MAPPING.values()) # Сформируем метки на основе комбинаций регистра и пунктуации LABELS_list = [] for case in LABELS_CASE: for punc in LABELS_PUNC: LABELS_list.append(f'{case}_{punc}') LABELS = {label: i+1 for i, label in enumerate(LABELS_list)} LABELS['O'] = -100 INVERSE_LABELS = {i: label for label, i in LABELS.items()} LABEL_TO_PUNC_LABEL = {label: label.split('_')[-1] for label in LABELS.keys() if label != 'O'} LABEL_TO_CASE_LABEL = {label: '_'.join(label.split('_')[:-1]) for label in LABELS.keys() if label != 'O'} def token_to_label(token, label): if type(label) == int: label = INVERSE_LABELS[label] if label == 'LOWER_O': return token if label == 'LOWER_PERIOD': return token + '.' if label == 'LOWER_COMMA': return token + ',' if label == 'LOWER_QUESTION': return token + '?' if label == 'UPPER_O': return token.capitalize() if label == 'UPPER_PERIOD': return token.capitalize() + '.' if label == 'UPPER_COMMA': return token.capitalize() + ',' if label == 'UPPER_QUESTION': return token.capitalize() + '?' if label == 'UPPER_TOTAL_O': return token.upper() if label == 'UPPER_TOTAL_PERIOD': return token.upper() + '.' if label == 'UPPER_TOTAL_COMMA': return token.upper() + ',' if label == 'UPPER_TOTAL_QUESTION': return token.upper() + '?' if label == 'O': return token def decode_label(label, classes='all'): if classes == 'punc': return LABEL_TO_PUNC_LABEL[INVERSE_LABELS[label]] if classes == 'case': return LABEL_TO_CASE_LABEL[INVERSE_LABELS[label]] else: return INVERSE_LABELS[label] MODEL_REPO = "kontur-ai/sbert-punc-case-ru" class SbertPuncCase(nn.Module): def __init__(self): super().__init__() self.tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO, revision="sbert", use_auth_token=True, strip_accents=False) self.model = AutoModelForTokenClassification.from_pretrained(MODEL_REPO, revision="sbert", use_auth_token=True ) self.model.eval() def forward(self, input_ids, attention_mask): return self.model(input_ids=input_ids, attention_mask=attention_mask) def punctuate(self, text): text = text.strip().lower() # Разобъем предложение на слова words = text.split() tokenizer_output = self.tokenizer(words, is_split_into_words=True) if len(tokenizer_output.input_ids) > 512: return ' '.join([self.punctuate(' '.join(text_part)) for text_part in np.array_split(words, 2)]) predictions = self(torch.tensor([tokenizer_output.input_ids], device=self.model.device), torch.tensor([tokenizer_output.attention_mask], device=self.model.device)).logits.cpu().data.numpy() predictions = np.argmax(predictions, axis=2) # decode punctuation and casing splitted_text = [] word_ids = tokenizer_output.word_ids() for i, word in enumerate(words): label_pos = word_ids.index(i) label_id = predictions[0][label_pos] label = decode_label(label_id) splitted_text.append(token_to_label(word, label)) capitalized_text = ' '.join(splitted_text) return capitalized_text if __name__ == '__main__': parser = argparse.ArgumentParser("Punctuation and case restoration model sbert-punc-case-ru") parser.add_argument("-i", "--input", type=str, help="text to restore", default='sbert punc case расставляет точки запятые и знаки вопроса вам нравится') parser.add_argument("-d", "--device", type=str, help="run model on cpu or gpu", choices=['cpu', 'cuda'], default='cpu') args = parser.parse_args() print(f"Source text: {args.input}\n") sbertpunc = SbertPuncCase().to(args.device) punctuated_text = sbertpunc.punctuate(args.input) print(f"Restored text: {punctuated_text}")