# bert_japanese_punctuation / insert_punctuation.py
# bobfromjapan's picture
# Upload 7 files
# ae56469 verified
# raw
# history blame
# 2.83 kB
# %%
import torch
from transformers import (
BertTokenizer,
BertForMaskedLM,
AutoModelForMaskedLM,
AutoTokenizer,
BertModel,
)
import numpy as np
import random
from itertools import islice
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from tqdm.auto import tqdm
import os
# Name of the pretrained checkpoint; the tokenizer and the encoder below are
# both loaded from this same identifier.
model_name = "tohoku-nlp/bert-base-japanese-char-v3"
# Tokenizer used by process_long_text (fed space-separated characters) and by
# insert_punctuation (its ids_to_tokens table maps ids back to text).
# NOTE(review): these loads run at import time, so importing this module needs
# the checkpoint files to be obtainable — confirm that is intended.
tokenizer = BertTokenizer.from_pretrained(model_name)
# Encoder without a task head; it is wrapped by punctuation_predictor, which
# adds the per-token linear layer on top of its last hidden state.
base_model = BertModel.from_pretrained(model_name)
class punctuation_predictor(torch.nn.Module):
    """Per-token two-logit classifier on top of an encoder.

    The wrapped encoder is called with ``input_ids`` and ``attention_mask``
    and must return an object exposing ``last_hidden_state``. Every token's
    hidden vector goes through dropout and a linear layer producing two
    logits. Elsewhere in this file the logits are passed through a sigmoid
    and channel 0 is read as "comma after this token", channel 1 as "period
    after this token".

    Attribute names (``base_model``, ``dropout``, ``linear``) are part of the
    saved state-dict keys and must not be renamed.
    """

    def __init__(self, base_model):
        """Build the head for ``base_model``.

        Args:
            base_model: Encoder module. If it exposes ``config.hidden_size``
                the linear head is sized from it; otherwise the head falls
                back to 768 input features, the previously hard-coded width.
        """
        super().__init__()
        self.base_model = base_model
        # Size the head from the encoder instead of assuming a 768-wide model,
        # so encoders with a different hidden width work without editing
        # this class.
        config = getattr(base_model, "config", None)
        hidden_size = getattr(config, "hidden_size", 768)
        self.dropout = torch.nn.Dropout(0.2)
        self.linear = torch.nn.Linear(hidden_size, 2)

    def forward(self, input_ids, attention_mask):
        """Return raw logits with one pair of values per input token.

        Args:
            input_ids: Token ids forwarded unchanged to the encoder.
            attention_mask: Attention mask forwarded unchanged to the encoder.

        Returns:
            The linear head applied to the (dropout-regularised) last hidden
            state; the last dimension has size 2. No sigmoid is applied here.
        """
        hidden = self.base_model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        # The linear layer acts on the last dimension, i.e. token by token.
        return self.linear(self.dropout(hidden))
# Wrap the pretrained encoder with the per-token head and restore the
# fine-tuned weights from a checkpoint file resolved against the current
# working directory.
model = punctuation_predictor(base_model)
# map_location="cpu": a checkpoint saved from a GPU would otherwise fail to
# load on a machine without CUDA, and the inference code in this file works on
# CPU tensors anyway (it calls .numpy() on the model output).
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source (consider weights_only=True where the installed torch
# version supports it).
model.load_state_dict(
    torch.load("punctuation_position_model.pth", map_location="cpu")
)
# Switch to evaluation mode so the dropout layer is inactive at inference.
model.eval()
def insert_punctuation(input, comma_pos, period_pos):
    """Rebuild text from token ids, appending predicted punctuation marks.

    Args:
        input: Indexable sequence of token ids whose elements expose
            ``.item()`` (the caller in this file passes one row of
            ``input_ids``).
        comma_pos: Per-position truthy flags; a "、" follows the token when
            set.
        period_pos: Per-position truthy flags; a "。" follows the token when
            set. A period wins when both flags are set for the same position.

    Returns:
        The concatenated token strings (looked up in the module-level
        ``tokenizer.ids_to_tokens``) with the punctuation marks inserted.
        Positions whose id is 5 or lower are skipped (presumably special
        tokens such as padding — verify against the vocabulary), and
        iteration stops once a regular token sits at the final index.
    """
    pieces = []
    for idx, (has_comma, has_period) in enumerate(zip(comma_pos, period_pos)):
        token_id = input[idx].item()
        if token_id <= 5:
            # Low ids are not emitted as text.
            continue
        if idx >= len(input) - 1:
            # A regular token in the very last slot ends the reconstruction.
            break
        if has_period:
            mark = "。"
        elif has_comma:
            mark = "、"
        else:
            mark = ""
        pieces.append(tokenizer.ids_to_tokens[token_id] + mark)
    return "".join(pieces)
def process_long_text(text, max_length=256, comma_thresh=0.1, period_thresh=0.1):
    """Strip existing "、"/"。" from ``text`` and re-insert predicted ones.

    The text is cut into fixed-width character chunks, each chunk is fed to
    the module-level ``tokenizer`` and ``model`` one character per token, and
    the sigmoid of the two output logits per token is thresholded to decide
    where ``insert_punctuation`` places a comma (channel 0) or a period
    (channel 1).

    Args:
        text: Input string; any "、" and "。" it contains are removed first.
        max_length: Number of characters per chunk. Values above 510 are
            capped at 510 so a chunk still fits in the 512-token input
            together with the two special tokens the tokenizer adds by
            default; previously the excess characters were silently
            truncated away.
        comma_thresh: A comma is inserted where the comma probability is
            strictly greater than this value.
        period_thresh: A period is inserted where the period probability is
            strictly greater than this value.

    Returns:
        The punctuated text, or ``""`` when nothing remains after stripping.

    Notes:
        Chunks are cut at fixed offsets, so predictions near a chunk boundary
        see no context from the neighbouring chunk. The output is rebuilt
        from token ids, so characters that ``insert_punctuation`` skips are
        not reproduced.
    """
    stripped = text.replace("、", "").replace("。", "")
    # 512-token window minus the two special tokens leaves 510 characters.
    chunk_size = min(max_length, 510)
    pieces = []
    # Pure inference: no autograd graph is needed, which saves memory and
    # lets the result be converted to NumPy directly.
    with torch.no_grad():
        for start in range(0, len(stripped), chunk_size):
            chunk = stripped[start : start + chunk_size]
            inputs = tokenizer(
                # Space-separate the characters so each becomes its own token.
                " ".join(chunk),
                max_length=512,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            probs = torch.sigmoid(model(inputs.input_ids, inputs.attention_mask))
            # Single example per call: take row 0, shape (tokens, 2).
            per_token = probs[0].numpy()
            comma_pos = per_token[:, 0] > comma_thresh
            period_pos = per_token[:, 1] > period_thresh
            pieces.append(
                insert_punctuation(inputs.input_ids[0], comma_pos, period_pos)
            )
    return "".join(pieces)
# %%
if __name__ == "__main__":
    # Demo run: punctuate one unpunctuated Japanese sample with the default
    # chunk size and thresholds, and print the result.
    sample_text = "句読点ありバージョンを書きました句読点があることで僕は逆に読みづらく感じるので句読点無しで書きたいと思います"
    punctuated = process_long_text(sample_text)
    print(punctuated)