|
|
|
import torch |
|
from transformers import ( |
|
BertTokenizer, |
|
BertForMaskedLM, |
|
AutoModelForMaskedLM, |
|
AutoTokenizer, |
|
BertModel, |
|
) |
|
import numpy as np |
|
import random |
|
from itertools import islice |
|
from torch.utils.data import Dataset, DataLoader |
|
from torch.optim import AdamW |
|
from tqdm.auto import tqdm |
|
import os |
|
|
|
# Hugging Face Hub id of the pretrained encoder (a character-level Japanese
# BERT, judging by its name).
model_name = "tohoku-nlp/bert-base-japanese-char-v3"
# Both are loaded via from_pretrained at import time, since this is module level.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
base_model = BertModel.from_pretrained(pretrained_model_name_or_path=model_name)
|
|
|
|
|
class punctuation_predictor(torch.nn.Module):
    """Per-token comma/period classifier on top of a BERT-style encoder.

    Every position of the encoder's ``last_hidden_state`` goes through dropout
    and a linear layer that outputs two logits. In this file, logit 0 is
    thresholded for inserting "、" and logit 1 for inserting "。" after the
    token.

    Args:
        base_model: Encoder called as
            ``base_model(input_ids=..., attention_mask=...)`` whose result has
            a ``last_hidden_state`` attribute. If it exposes
            ``config.hidden_size`` (as Hugging Face models do), that sets the
            head's input width; otherwise 768 (the BERT-base width) is used.
        dropout: Dropout probability applied before the linear head.
    """

    def __init__(self, base_model, dropout=0.2):
        super().__init__()
        self.base_model = base_model
        self.dropout = torch.nn.Dropout(dropout)
        # Take the width from the encoder instead of hard-coding 768, so the
        # head also fits non-base encoders. Parameter names are unchanged, so
        # existing state_dict checkpoints still load.
        config = getattr(base_model, "config", None)
        hidden_size = getattr(config, "hidden_size", 768)
        self.linear = torch.nn.Linear(hidden_size, 2)

    def forward(self, input_ids, attention_mask):
        """Return two logits per encoder position (the last dimension is 2)."""
        last_hidden_state = self.base_model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        return self.linear(self.dropout(last_hidden_state))
|
|
|
|
|
# Build the punctuation model and load its fine-tuned weights for inference.
model = punctuation_predictor(base_model)
# map_location="cpu": process_long_text calls .numpy() on the model output, so
# the model must live on the CPU. Without this, a checkpoint saved from a GPU
# run fails to deserialize on a CPU-only machine.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
model.load_state_dict(
    torch.load("punctuation_position_model.pth", map_location="cpu")
)
# Switch to eval mode so dropout is disabled during inference.
model.eval()
|
|
|
|
|
def insert_punctuation(input, comma_pos, period_pos):
    """Rebuild text from one encoded chunk, appending predicted punctuation.

    Args:
        input: Sequence of token ids for a single encoded chunk (a 1-D tensor
            from ``tokenizer(...)``, typically ``[CLS] chars... [SEP] [PAD]...``).
        comma_pos: Per-position booleans; True appends "、" after that token.
        period_pos: Per-position booleans; True appends "。" after that token
            and takes precedence over a comma at the same position.

    Returns:
        The decoded tokens with punctuation inserted, as one string.
    """
    # Use the tokenizer's own special-token ids instead of the old hard-coded
    # ``id > 5`` cut-off. That cut-off depended on the vocab layout: it either
    # stopped at the leading [CLS] or emitted it as text, and any [UNK]
    # character truncated the rest of the chunk.
    stop_ids = {tokenizer.sep_token_id, tokenizer.pad_token_id}
    # [UNK] is not skipped: it is rendered as the unk token so that the
    # characters after it (and its punctuation) are kept.
    skip_ids = set(tokenizer.all_special_ids) - {tokenizer.unk_token_id}

    pieces = []
    for token_id, comma, period in zip(input.tolist(), comma_pos, period_pos):
        if token_id in stop_ids:
            # End of the real content: everything after is [SEP]/[PAD].
            break
        if token_id in skip_ids:
            continue
        piece = tokenizer.ids_to_tokens[token_id]
        if period:
            piece += "。"
        elif comma:
            piece += "、"
        pieces.append(piece)
    return "".join(pieces)
|
|
|
|
|
def process_long_text(text, max_length=256, comma_thresh=0.1, period_thresh=0.1):
    """Re-punctuate Japanese text with the punctuation model.

    Existing "、" and "。" are stripped first. The text is then split into
    chunks of at most ``max_length`` characters, and each chunk is run through
    the model independently (no context is shared across chunk boundaries).

    Args:
        text: Input string.
        max_length: Characters per chunk. Must be in 1..510 so that a chunk
            plus [CLS]/[SEP] fits the 512-token encoding used below.
        comma_thresh: Sigmoid probability above which "、" is inserted.
        period_thresh: Sigmoid probability above which "。" is inserted.

    Returns:
        The punctuated text ("" for empty input).

    Raises:
        ValueError: If ``max_length`` is outside 1..510.
    """
    max_tokens = 512
    # Each character is fed as its own space-separated token. This assumes the
    # character-level tokenizer yields at most one token per character. Chunks
    # longer than max_tokens - 2 used to be cut silently by truncation=True,
    # and a non-positive max_length silently returned "" or failed in range().
    if not 1 <= max_length <= max_tokens - 2:
        raise ValueError(
            f"max_length must be between 1 and {max_tokens - 2}, got {max_length}"
        )

    text = text.translate(str.maketrans("", "", "、。"))

    pieces = []
    # Inference only: no_grad avoids building an autograd graph per chunk.
    with torch.no_grad():
        for start in range(0, len(text), max_length):
            chunk = text[start : start + max_length]
            inputs = tokenizer(
                " ".join(chunk),
                max_length=max_tokens,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            logits = model(inputs.input_ids, inputs.attention_mask)
            # Batch of one: column 0 -> comma, column 1 -> period.
            probs = torch.sigmoid(logits)[0].numpy()
            comma_pos = probs[:, 0] > comma_thresh
            period_pos = probs[:, 1] > period_thresh
            pieces.append(
                insert_punctuation(inputs.input_ids[0], comma_pos, period_pos)
            )
    return "".join(pieces)
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo: punctuate a sample sentence written entirely without 、 or 。.
    sample_text = "句読点ありバージョンを書きました句読点があることで僕は逆に読みづらく感じるので句読点無しで書きたいと思います"
    punctuated = process_long_text(sample_text)
    print(punctuated)
|
|