|
|
|
|
|
import argparse |
|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
|
|
from transformers import AutoTokenizer, AutoModelForTokenClassification |
|
|
|
|
|
# Punctuation marks the model can restore, mapped to their label names.
PUNK_MAPPING = {".": "PERIOD", ",": "COMMA", "?": "QUESTION"}

# Capitalization classes: keep as-is / first letter upper / fully upper.
LABELS_CASE = ["LOWER", "UPPER", "UPPER_TOTAL"]

# Punctuation classes; "O" means "no punctuation after the word".
LABELS_PUNC = ["O"] + list(PUNK_MAPPING.values())

# Cartesian product of case x punctuation, e.g. "UPPER_PERIOD".
LABELS_list = [f"{case}_{punc}" for case in LABELS_CASE for punc in LABELS_PUNC]
# Label ids start at 1; "O" gets -100, the conventional ignore-index for
# token-classification losses in PyTorch/transformers.
LABELS = {label: i + 1 for i, label in enumerate(LABELS_list)}
LABELS["O"] = -100
INVERSE_LABELS = {i: label for label, i in LABELS.items()}

# "UPPER_TOTAL_PERIOD" -> "PERIOD"
LABEL_TO_PUNC_LABEL = {
    label: label.split("_")[-1] for label in LABELS.keys() if label != "O"
}
# "UPPER_TOTAL_PERIOD" -> "UPPER_TOTAL"
LABEL_TO_CASE_LABEL = {
    label: "_".join(label.split("_")[:-1]) for label in LABELS.keys() if label != "O"
}
|
|
|
|
|
def token_to_label(token, label):
    """Apply a combined case+punctuation label to a bare word.

    Args:
        token: the word to transform (expected lowercase).
        label: a string label such as "UPPER_PERIOD", or an integer label id
            that is first decoded through INVERSE_LABELS.

    Returns:
        The word with its case transform applied and the punctuation mark
        appended, or None for an unrecognized label (preserving the original
        implementation's implicit fall-through).
    """
    # Case part -> text transform; "LOWER" keeps the word as-is (the input is
    # already lowercased upstream).
    case_transform = {
        "LOWER": lambda t: t,
        "UPPER": str.capitalize,
        "UPPER_TOTAL": str.upper,
    }
    # Punctuation part -> suffix appended after the word; "O" adds nothing.
    punc_suffix = {"O": "", "PERIOD": ".", "COMMA": ",", "QUESTION": "?"}

    # isinstance replaces the non-idiomatic `type(label) == int` check.
    if isinstance(label, int):
        label = INVERSE_LABELS[label]
    if label == "O":
        return token
    # "UPPER_TOTAL_PERIOD" -> ("UPPER_TOTAL", "PERIOD"); "LOWER_O" -> ("LOWER", "O").
    case, _, punc = label.rpartition("_")
    if case not in case_transform or punc not in punc_suffix:
        return None  # unknown label: same result as the original if-chain
    return case_transform[case](token) + punc_suffix[punc]
|
|
|
|
|
def decode_label(label, classes="all"):
    """Translate an integer label id into its string form.

    With classes="punc" only the punctuation part is returned, with
    classes="case" only the case part; any other value returns the full
    combined label (e.g. "UPPER_PERIOD").
    """
    name = INVERSE_LABELS[label]
    if classes == "punc":
        return LABEL_TO_PUNC_LABEL[name]
    if classes == "case":
        return LABEL_TO_CASE_LABEL[name]
    return name
|
|
|
|
|
# Hugging Face Hub repository id of the pretrained punctuation/case model.
MODEL_REPO = "kontur-ai/sbert_punc_case_ru"
|
|
|
|
|
class SbertPuncCase(nn.Module):
    """Restores punctuation and capitalization in lowercase Russian text.

    Wraps the pretrained ``kontur-ai/sbert_punc_case_ru`` token-classification
    model: each word receives a combined case+punctuation label which is then
    applied via ``token_to_label``.
    """

    def __init__(self):
        super().__init__()

        # strip_accents=False keeps diacritics (e.g. Russian "ё") intact.
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO, strip_accents=False)
        self.model = AutoModelForTokenClassification.from_pretrained(MODEL_REPO)
        self.model.eval()

    def forward(self, input_ids, attention_mask):
        # Thin pass-through to the underlying HF token-classification model.
        return self.model(input_ids=input_ids, attention_mask=attention_mask)

    def punctuate(self, text: str) -> str:
        """Return *text* with punctuation and capitalization restored.

        The input is lowercased first, so any punctuation/case already present
        is discarded before prediction. Inputs that tokenize to more than 512
        tokens are split into two halves and processed recursively.
        """
        text = text.strip().lower()

        # Word-level split; labels are predicted per word, not per subword.
        words = text.split()

        tokenizer_output = self.tokenizer(words, is_split_into_words=True)

        # BERT-style encoders accept at most 512 tokens — recurse on halves.
        if len(tokenizer_output.input_ids) > 512:
            return " ".join(
                [
                    self.punctuate(" ".join(text_part))
                    for text_part in np.array_split(words, 2)
                ]
            )

        # NOTE(review): consider wrapping this in torch.no_grad() — gradients
        # are not needed at inference time.
        predictions = (
            self(
                torch.tensor([tokenizer_output.input_ids], device=self.model.device),
                torch.tensor(
                    [tokenizer_output.attention_mask], device=self.model.device
                ),
            )
            .logits.cpu()
            .data.numpy()
        )
        # Most likely label id per token position.
        predictions = np.argmax(predictions, axis=2)

        # Map predictions back onto words: use the label of each word's FIRST
        # subword token (word_ids maps token positions to word indices).
        splitted_text = []
        word_ids = tokenizer_output.word_ids()
        for i, word in enumerate(words):
            label_pos = word_ids.index(i)
            label_id = predictions[0][label_pos]
            label = decode_label(label_id)
            splitted_text.append(token_to_label(word, label))
        capitalized_text = " ".join(splitted_text)
        return capitalized_text
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: restore punctuation/case for a single input string.
    arg_parser = argparse.ArgumentParser(
        "Punctuation and case restoration model sbert_punc_case_ru"
    )
    arg_parser.add_argument(
        "-i",
        "--input",
        type=str,
        default="sbert punc case расставляет точки запятые и знаки вопроса вам нравится",
        help="text to restore",
    )
    arg_parser.add_argument(
        "-d",
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="run model on cpu or gpu",
    )
    cli_args = arg_parser.parse_args()

    print(f"Source text: {cli_args.input}\n")
    sbertpunc = SbertPuncCase().to(cli_args.device)
    punctuated_text = sbertpunc.punctuate(cli_args.input)
    print(f"Restored text: {punctuated_text}")
|
|