# -*- coding: utf-8 -*-
"""Punctuation and case restoration for Russian text with sbert_punc_case_ru.

For every word the model predicts a combined label: the word's casing
(LOWER / UPPER / UPPER_TOTAL) and the punctuation mark that should follow it
(none, period, comma, or question mark).
"""
import argparse

import numpy as np
import torch
import torch.nn as nn
from transformers import AutoModelForTokenClassification, AutoTokenizer

# Predicted punctuation marks
PUNK_MAPPING = {".": "PERIOD", ",": "COMMA", "?": "QUESTION"}

# Predicted casing: LOWER - lowercase, UPPER - uppercase first character,
# UPPER_TOTAL - uppercase for all characters
LABELS_CASE = ["LOWER", "UPPER", "UPPER_TOTAL"]

# Add the label O to the punctuation labels, meaning "no punctuation"
LABELS_PUNC = ["O"] + list(PUNK_MAPPING.values())

# Build the label set from all combinations of casing and punctuation
LABELS_list = []
for case in LABELS_CASE:
    for punc in LABELS_PUNC:
        LABELS_list.append(f"{case}_{punc}")
LABELS = {label: i + 1 for i, label in enumerate(LABELS_list)}
LABELS["O"] = -100
INVERSE_LABELS = {i: label for label, i in LABELS.items()}

LABEL_TO_PUNC_LABEL = {
    label: label.split("_")[-1] for label in LABELS.keys() if label != "O"
}
LABEL_TO_CASE_LABEL = {
    label: "_".join(label.split("_")[:-1]) for label in LABELS.keys() if label != "O"
}


def token_to_label(token, label):
    """Apply a combined casing/punctuation label to a single word."""
    if isinstance(label, int):
        label = INVERSE_LABELS[label]
    if label == "LOWER_O":
        return token
    if label == "LOWER_PERIOD":
        return token + "."
    if label == "LOWER_COMMA":
        return token + ","
    if label == "LOWER_QUESTION":
        return token + "?"
    if label == "UPPER_O":
        return token.capitalize()
    if label == "UPPER_PERIOD":
        return token.capitalize() + "."
    if label == "UPPER_COMMA":
        return token.capitalize() + ","
    if label == "UPPER_QUESTION":
        return token.capitalize() + "?"
    if label == "UPPER_TOTAL_O":
        return token.upper()
    if label == "UPPER_TOTAL_PERIOD":
        return token.upper() + "."
    if label == "UPPER_TOTAL_COMMA":
        return token.upper() + ","
    if label == "UPPER_TOTAL_QUESTION":
        return token.upper() + "?"
    if label == "O":
        return token


def decode_label(label, classes="all"):
    """Map a predicted label id back to its punctuation, casing, or full label."""
    if classes == "punc":
        return LABEL_TO_PUNC_LABEL[INVERSE_LABELS[label]]
    if classes == "case":
        return LABEL_TO_CASE_LABEL[INVERSE_LABELS[label]]
    return INVERSE_LABELS[label]


MODEL_REPO = "kontur-ai/sbert_punc_case_ru"


class SbertPuncCase(nn.Module):
    def __init__(self):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO, strip_accents=False)
        self.model = AutoModelForTokenClassification.from_pretrained(MODEL_REPO)
        self.model.eval()

    def forward(self, input_ids, attention_mask):
        return self.model(input_ids=input_ids, attention_mask=attention_mask)

    def punctuate(self, text):
        text = text.strip().lower()

        # Split the sentence into words
        words = text.split()

        tokenizer_output = self.tokenizer(words, is_split_into_words=True)

        # The encoder accepts at most 512 tokens: split the word list in half
        # and restore each half recursively
        if len(tokenizer_output.input_ids) > 512:
            return " ".join(
                self.punctuate(" ".join(text_part))
                for text_part in np.array_split(words, 2)
            )

        predictions = (
            self(
                torch.tensor([tokenizer_output.input_ids], device=self.model.device),
                torch.tensor(
                    [tokenizer_output.attention_mask], device=self.model.device
                ),
            )
            .logits.cpu()
            .data.numpy()
        )
        predictions = np.argmax(predictions, axis=2)

        # Decode punctuation and casing: take the prediction for the first
        # sub-token of each word and apply it to the whole word
        restored_words = []
        word_ids = tokenizer_output.word_ids()
        for i, word in enumerate(words):
            label_pos = word_ids.index(i)
            label_id = predictions[0][label_pos]
            label = decode_label(label_id)
            restored_words.append(token_to_label(word, label))
        return " ".join(restored_words)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Punctuation and case restoration model sbert_punc_case_ru"
    )
    parser.add_argument(
        "-i",
        "--input",
        type=str,
        help="text to restore",
        # Russian demo sentence: "sbert punc case places periods, commas and
        # question marks; do you like it"
        default="sbert punc case расставляет точки запятые и знаки вопроса вам нравится",
    )
    parser.add_argument(
        "-d",
        "--device",
        type=str,
        help="run model on cpu or gpu",
        choices=["cpu", "cuda"],
        default="cpu",
    )
    args = parser.parse_args()
    print(f"Source text: {args.input}\n")
    sbertpunc = SbertPuncCase().to(args.device)
    punctuated_text = sbertpunc.punctuate(args.input)
    print(f"Restored text: {punctuated_text}")
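
# Usage sketch: calling the model from Python instead of the CLI. A minimal
# illustration, assuming this script is importable as `sbert_punc_case_ru`
# (the module name may differ in your layout) and that the
# kontur-ai/sbert_punc_case_ru weights can be fetched from the Hugging Face
# hub; the exact restored text depends on the model's predictions.
#
#     from sbert_punc_case_ru import SbertPuncCase
#
#     model = SbertPuncCase().to("cpu")
#     restored = model.punctuate("привет как дела")  # roughly "Привет, как дела?"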