|
|
|
import torch |
|
from transformers import ( |
|
BertTokenizer, |
|
BertForMaskedLM, |
|
AutoModelForMaskedLM, |
|
AutoTokenizer, |
|
BertModel, |
|
) |
|
import numpy as np |
|
import random |
|
from itertools import islice |
|
from torch.utils.data import Dataset, DataLoader |
|
from torch.optim import AdamW, SGD |
|
from tqdm import tqdm |
|
import os |
|
|
|
|
|
def index_to_onehot(l, length):
    """Build a 0/1 list of ``length`` entries marking the indices found in ``l``.

    Args:
        l: Collection of integer positions to mark. Positions outside
            ``range(length)`` (including negative ones) are ignored.
        length: Number of entries in the returned list.

    Returns:
        A list of ``length`` ints whose entry ``i`` is 1 when ``i`` is a
        member of ``l`` and 0 otherwise.
    """
    try:
        # Hash the positions once so every membership test below is O(1)
        # instead of a linear scan of ``l`` per output slot.
        marked = frozenset(l)
    except TypeError:
        # Unhashable members: fall back to the container's own membership test.
        marked = l
    return [1 if i in marked else 0 for i in range(length)]
|
|
|
|
|
def get_punctuation_position(tokenized_text, tokenizer, max_length=512):
    """Strip "、" and "。" tokens and record where each one was attached.

    Every comma/period token is dropped from the sequence. Its position is
    recorded as the index, within the punctuation-free sequence, of the token
    that immediately preceded it. A mark with no preceding token (index -1)
    or whose index is ``>= max_length`` is not recorded.

    Args:
        tokenized_text: 1-D sequence of token ids (a list or a tensor).
        tokenizer: Object providing ``convert_tokens_to_ids`` and
            ``pad_token_id``.
        max_length: Length the id sequence is padded to, and the length of
            each label row. Defaults to 512. Longer inputs are not truncated.

    Returns:
        A tuple ``(input_ids, labels)``. ``input_ids`` is a tensor of the
        remaining ids, right-padded with ``tokenizer.pad_token_id`` up to
        ``max_length``. ``labels`` is ``[comma_row, period_row]``, two 0/1
        lists of ``max_length`` entries each.
    """
    comma_id = tokenizer.convert_tokens_to_ids("、")
    period_id = tokenizer.convert_tokens_to_ids("。")

    # Work on plain Python values: iterating a tensor yields 0-d tensors,
    # which are slow to compare and slow to turn back into a tensor.
    if hasattr(tokenized_text, "tolist"):
        token_ids = tokenized_text.tolist()
    else:
        token_ids = list(tokenized_text)

    comma_row = [0] * max_length
    period_row = [0] * max_length
    kept_ids = []

    for token_id in token_ids:
        if token_id == comma_id:
            row = comma_row
        elif token_id == period_id:
            row = period_row
        else:
            kept_ids.append(token_id)
            continue

        # The mark belongs to the last token kept so far.
        previous = len(kept_ids) - 1
        if 0 <= previous < max_length:
            row[previous] = 1

    # A non-positive repeat count yields an empty list, so sequences that are
    # already long enough are left untouched.
    kept_ids += [tokenizer.pad_token_id] * (max_length - len(kept_ids))

    return torch.tensor(kept_ids), [comma_row, period_row]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PunctuationPositionDataset(torch.utils.data.Dataset):
    """Dataset yielding punctuation-free token ids with per-token labels.

    Each item is built from one raw text: the text is tokenized, the "、" and
    "。" tokens are removed by ``get_punctuation_position``, and their
    positions become the training labels.

    Args:
        data: Indexable collection of raw text strings.
        tokenizer: Callable tokenizer that also exposes ``pad_token_id`` and
            ``convert_tokens_to_ids``.
    """

    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        """Return the number of raw texts."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, label, text)`` for one sample.

        ``input_ids`` and ``attention_mask`` are 1-D tensors, ``label`` is a
        float tensor with one ``[comma, period]`` row per token position, and
        ``text`` is the space-separated input string.
        """
        # Put a space between every character before tokenizing (presumably
        # so the tokenizer treats each character as its own word).
        text = " ".join(self.data[idx])
        inputs = self.tokenizer(
            text,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        input_ids, label = get_punctuation_position(
            inputs["input_ids"][0], self.tokenizer
        )

        # Two rows (comma, period) -> one [comma, period] pair per position.
        label = torch.tensor(label, dtype=torch.float32).transpose(0, 1)

        # The tokenizer's mask describes the sequence *before* punctuation was
        # removed, so it would flag trailing pad tokens as real. Rebuild it
        # from the punctuation-free ids instead.
        attention_mask = (input_ids != self.tokenizer.pad_token_id).long()

        return (input_ids, attention_mask, label, text)
|
|
|
|
|
|
|
# Checkpoint identifier shared by the tokenizer and the encoder below.
model_name = "tohoku-nlp/bert-base-japanese-char-v3"

# Load the pretrained encoder and its matching vocabulary.
base_model = BertModel.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)
|
|
|
|
|
|
|
class punctuation_predictor(torch.nn.Module):
    """Encoder with a two-logit classification head applied to every token.

    Args:
        base_model: Encoder called as
            ``base_model(input_ids=..., attention_mask=...)`` whose result
            exposes ``last_hidden_state``. When it has a
            ``config.hidden_size`` attribute, that value sets the input width
            of the head; otherwise 768 is used.
    """

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        self.dropout = torch.nn.Dropout(0.2)

        # Size the head from the encoder's own configuration so encoders with
        # a different hidden width work too; 768 stays the fallback.
        config = getattr(base_model, "config", None)
        hidden_size = getattr(config, "hidden_size", 768)
        self.linear = torch.nn.Linear(hidden_size, 2)

    def forward(self, input_ids, attention_mask):
        """Return raw logits whose last dimension has size 2.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.

        Returns:
            The encoder's ``last_hidden_state`` passed through dropout and the
            linear head. No sigmoid is applied.
        """
        encoded = self.base_model(
            input_ids=input_ids, attention_mask=attention_mask
        )
        hidden = self.dropout(encoded.last_hidden_state)
        return self.linear(hidden)
|
|
|
|
|
model = punctuation_predictor(base_model)

# Read the training corpus, one sample per line (line endings are kept).
# The encoding is explicit so decoding does not depend on the platform's
# locale default. NOTE(review): assumes the corpus is UTF-8 -- confirm.
with open("data/train.txt", "r", encoding="utf-8") as f:
    texts = f.readlines()

dataset = PunctuationPositionDataset(texts, tokenizer)

# Shuffled mini-batches of 16 samples, prepared by 8 worker processes.
data_loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=8,
)
|
|
|
|
|
|
|
# Separate learning rates: a small one for the pretrained encoder and a
# larger one for the freshly initialised linear head.
optimizer = AdamW(
    [
        {"params": model.base_model.parameters(), "lr": 5e-5},
        {"params": model.linear.parameters(), "lr": 1e-3},
    ],
)

# Independent binary decision per output logit.
criteria = torch.nn.BCEWithLogitsLoss()

# Train on the GPU when one is present, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.train()
model.to(device)
for epoch in range(10):
    epoch_loss = 0.0
    progress_bar = tqdm(data_loader, desc=f"Epoch {epoch+1}")
    for step, batch in enumerate(progress_bar, start=1):
        # The fourth batch element is the raw text, which training ignores.
        input_ids, attention_masks, labels, _ = batch
        input_ids = input_ids.to(device)
        attention_masks = attention_masks.to(device)
        labels = labels.to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_masks)
        loss = criteria(outputs, labels)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        epoch_loss += loss.item()
        # Average over the batches processed so far, so the displayed value
        # is a true running mean rather than one diluted by the batches
        # still to come.
        progress_bar.set_postfix({"loss": epoch_loss / step})

    # Overwrite the checkpoint after every epoch so an interrupted run still
    # leaves the most recent weights on disk.
    torch.save(model.state_dict(), "punctuation_position_model.pth")
|
|