# bert_japanese_punctuation / insert_punctuation.py
# Uploaded by bobfromjapan ("Upload 2 files", verified commit 3084243).
# %%
import torch
from transformers import (
BertTokenizer,
BertForMaskedLM,
AutoModelForMaskedLM,
AutoTokenizer,
BertModel,
)
import numpy as np
import random
from itertools import islice
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from tqdm.auto import tqdm
import os
# Pretrained checkpoint shared by the encoder and the tokenizer;
# process_long_text feeds it one character per token.
model_name = "tohoku-nlp/bert-base-japanese-char-v3"
base_model = BertModel.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)
class punctuation_predictor(torch.nn.Module):
    """Per-token punctuation classifier on top of a BERT-style encoder.

    ``forward`` returns raw logits (no sigmoid) with one row per input token
    and ``num_labels`` columns. With the default two labels,
    ``process_long_text`` applies a sigmoid and reads column 0 as "insert a
    comma after this token" and column 1 as "insert a period after this
    token".

    Args:
        base_model: encoder whose forward accepts ``input_ids`` and
            ``attention_mask`` keyword arguments and returns an object with a
            ``last_hidden_state`` attribute.
        num_labels: width of the output head (default 2: comma, period).
        dropout_prob: dropout applied to the encoder states before the head.
    """

    def __init__(self, base_model, num_labels=2, dropout_prob=0.2):
        super().__init__()
        self.base_model = base_model
        self.dropout = torch.nn.Dropout(dropout_prob)
        # Size the head from the encoder's config rather than hard-coding 768.
        # Encoders without a config keep the original BERT-base width.
        hidden_size = getattr(getattr(base_model, "config", None), "hidden_size", 768)
        self.linear = torch.nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return per-token logits for ``input_ids``; no sigmoid is applied.

        Presumably ``input_ids`` / ``attention_mask`` are (batch, seq_len)
        tensors as produced by the tokenizer -- TODO confirm for other callers.
        """
        hidden = self.base_model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        # The same dropout + linear head is applied independently at every position.
        return self.linear(self.dropout(hidden))
# Build the classifier around the pretrained encoder and load the fine-tuned weights.
model = punctuation_predictor(base_model)
# map_location="cpu": this script runs inference on CPU (outputs are converted
# with .numpy()), and it lets a checkpoint saved from a GPU load on a
# CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (consider weights_only=True on torch versions that support it).
model.load_state_dict(
    torch.load("weight/punctuation_position_model.pth", map_location="cpu")
)
model.eval()  # inference mode: disables the dropout layer in the head
def insert_punctuation(input, comma_pos, period_pos):
    """Rebuild text from token ids, appending predicted punctuation.

    Walks ``input`` in step with the two boolean masks. Token ids <= 5 are
    skipped (presumably the special tokens such as [PAD]/[UNK]/[CLS]/[SEP];
    TODO confirm against the vocab). Every other token is emitted as its
    vocabulary string. It is followed by "。" when ``period_pos`` is set at
    that position, otherwise by "、" when ``comma_pos`` is set. The loop stops
    at the last position of ``input``, which is never emitted.

    Args:
        input: 1-D tensor of token ids (one element per position).
        comma_pos: per-position truthy flags for "comma after this token".
        period_pos: per-position truthy flags for "period after this token".

    Returns:
        The concatenated, punctuated string.
    """
    pieces = []
    last_index = len(input) - 1
    for idx, (want_comma, want_period) in enumerate(zip(comma_pos, period_pos)):
        tok_id = input[idx].item()
        if tok_id <= 5:
            continue
        if idx >= last_index:
            break
        surface = tokenizer.ids_to_tokens[tok_id]
        if want_period:
            suffix = "。"
        elif want_comma:
            suffix = "、"
        else:
            suffix = ""
        pieces.append(surface + suffix)
    return "".join(pieces)
def process_long_text(text, max_length=256, comma_thresh=0.1, period_thresh=0.1):
    """Insert Japanese commas (、) and periods (。) into ``text``.

    Any existing 、 and 。 are removed first. The text is then processed in
    consecutive windows of ``max_length`` characters. Each window is
    tokenized one character per token, scored by the module-level ``model``,
    and re-assembled by ``insert_punctuation``. Whitespace is dropped by the
    space-joined tokenization. Out-of-vocabulary characters are presumably
    mapped to [UNK] and then dropped by insert_punctuation's id filter
    (TODO confirm).

    Args:
        text: input string.
        max_length: characters per window; must be in 1..510 so that a window
            plus [CLS]/[SEP] fits in the 512-token encoder input.
        comma_thresh: sigmoid probability above which a comma is inserted.
        period_thresh: sigmoid probability above which a period is inserted
            (a period takes precedence when both fire).

    Returns:
        The punctuated text ("" for empty input).

    Raises:
        ValueError: if ``max_length`` is outside 1..510. Previously, larger
            windows were silently truncated, dropping characters.
    """
    max_tokens = 512  # tokenizer/encoder length, including [CLS] and [SEP]
    if not 0 < max_length <= max_tokens - 2:
        raise ValueError(
            f"max_length must be between 1 and {max_tokens - 2}, got {max_length}"
        )
    text = text.replace("、", "").replace("。", "")
    pieces = []
    # Inference only: no_grad avoids building an autograd graph for every window.
    with torch.no_grad():
        for start in range(0, len(text), max_length):
            window = text[start : start + max_length]
            inputs = tokenizer(
                " ".join(window),  # space-separate so each character is its own token
                max_length=max_tokens,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            logits = model(inputs.input_ids, inputs.attention_mask)
            # First (only) batch item: one row per token, columns [comma, period].
            probs = torch.sigmoid(logits)[0].cpu().numpy()
            comma_pos = probs[:, 0] > comma_thresh
            period_pos = probs[:, 1] > period_thresh
            pieces.append(
                insert_punctuation(inputs.input_ids[0], comma_pos, period_pos)
            )
    return "".join(pieces)
# %%
if __name__ == "__main__":
    # Demo: restore punctuation in an unpunctuated Japanese passage.
    sample = "女は昨夕艶めかしい姿をして彼の浴室の戸を開けた人に違なかった風呂場で彼を驚ろかした大きな髷をいつの間にか崩して尋常の束髪に結い更えたので彼はつい同じ人と気がつかずにいた彼はさらに声を聴いただけで顔を知らなかった伴の男の方をよそながらの初対面といった風に女と眺め比べた"
    punctuated = process_long_text(sample)
    print(punctuated)