# Partial-Arabic-Diacritization / model_partial.py
from typing import NamedTuple
import yaml
from tqdm import tqdm
import numpy as np
import torch as T
from torch import nn
from torch.nn import functional as F
from diac_utils import flat_2_3head
from model_dd import DiacritizerD2
from model_plm import Diacritizer

class Readout(nn.Module):
    def __init__(
        self,
        in_size: int,
        out_size: int,
    ):
        super().__init__()
        self.W1 = nn.Linear(in_size, in_size)
        self.W2 = nn.Linear(in_size, out_size)

    def forward(self, x: T.Tensor):
        # Two-layer readout: linear -> tanh -> linear projection to the output size.
        z = self.W1(x)
        z = T.tanh(z)
        z = self.W2(z)
        return z

class WordDD_LSTM(nn.Module):
    def __init__(
        self,
        feature_size: int,
        num_classes: int = 13,
        return_logits: bool = True,
    ):
        super().__init__()
        self.feature_size = feature_size
        self.num_classes = num_classes
        self.return_logits = return_logits
        # Character-level LSTM over per-word features (batch-first to match the [b tc dc] input).
        self.cell = nn.LSTM(feature_size, feature_size, batch_first=True)
        self.head = Readout(feature_size, num_classes)

    def forward(self, x: T.Tensor):
        #^ x: [b tc dc]
        z, _ = self.cell(x)
        #^ z: [b tc @dc]
        y = self.head(z)
        #^ y: [b tc Classes]
        yhat = y
        if not self.return_logits:
            yhat = F.softmax(yhat, dim=-1)
        #^ yhat: [b tc @Classes]
        return yhat
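
# Hedged, self-contained sketch (not part of the original pipeline): how the
# word-level LSTM baseline above could be exercised on random features. The
# sizes and the helper name `_demo_word_dd_lstm` are made up for illustration.
def _demo_word_dd_lstm():
    model = WordDD_LSTM(feature_size=32, num_classes=13, return_logits=False)
    x = T.randn(2, 10, 32)   # [b=2 words, tc=10 chars, dc=32 features]
    probs = model(x)         # [2, 10, 13] class probabilities per character
    return probs.shape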

class PartialDiacOutput(NamedTuple):
    preds_hard: T.Tensor
    preds_ctxt_logit: T.Tensor
    preds_base_logit: T.Tensor

class PartialDD(nn.Module):
    def __init__(
        self,
        config: dict,
        **kwargs
    ):
        super().__init__()
        self._built = False
        self.no_diac_id = 0
        self._dummy = nn.Parameter(T.ones(1, 1))
        # with open('./configs/dd/config_d2.yaml', 'r', encoding='utf-8') as fin:
        #     self.config_d2 = yaml.safe_load(fin)
        # self.device = T.device('cuda' if T.cuda.is_available() else 'cpu')
        self.config = config
        self._use_d2 = config["model-name"] == "D2"
        if self._use_d2:
            self.sentence_diac = DiacritizerD2(self.config)
        else:
            self.sentence_diac = Diacritizer(self.config, load_pretrained=False)
        # self.sentence_diac.to(self.device)
        # self.build()
        # self.word_diac = WordDD_LSTM(feature_size, num_classes=13, return_logits=False)
        self.eval()

    @property
    def device(self):
        return self._dummy.device

    @property
    def tokenizer(self):
        return self.sentence_diac.tokenizer

    def load_state_dict(
        self,
        state_dict: dict,
        strict: bool = True,
    ):
        self.sentence_diac.load_state_dict(state_dict, strict=strict)

    def _slim_batch(
        self,
        toke_ids: T.Tensor,
        char_ids: T.Tensor,
        diac_ids: T.Tensor,
        subword_lengths: T.Tensor,
    ):
        #^ toke_ids: [b tt]
        #^ char_ids: [b tw tc]
        #^ diac_ids: [b tw tc "13"]
        #^ subword_lengths: [b tw]
        token_nonpad_mask = toke_ids.ne(self.tokenizer.pad_token_id)
        Ttoken = token_nonpad_mask.sum(1).max()
        toke_ids = toke_ids[:, :Ttoken]
        char_nonpad_mask = char_ids.ne(0)
        Tword = char_nonpad_mask.any(2).sum(1).max()
        Tchar = char_nonpad_mask.sum(2).max()
        char_ids = char_ids[:, :Tword, :Tchar]
        diac_ids = diac_ids[:, :Tword, :Tchar]
        subword_lengths = subword_lengths[:, :Tword]
        return toke_ids, char_ids, diac_ids, subword_lengths
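
    # Illustrative note (hypothetical numbers, not from the original source): if the
    # longest sentence in a batch uses 7 of 512 token slots, 4 of 128 word slots and
    # 9 of 32 character slots, _slim_batch trims the tensors above to roughly
    # [b, 7], [b, 4, 9], [b, 4, 9, 13] and [b, 4] before the heavy forward pass.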

    @T.jit.export
    def word_diac(
        self,
        toke_ids: T.Tensor,
        char_ids: T.Tensor,
        diac_ids: T.Tensor,
        subword_lengths: T.Tensor,
        *,
        shape: tuple = None,
    ):
        if shape is None:
            toke_ids, char_ids, diac_ids, subword_lengths = self._slim_batch(
                toke_ids, char_ids, diac_ids, subword_lengths
            )
        else:
            Nb, Tw, Tc = shape
            toke_ids = toke_ids[:, :]
            char_ids = char_ids[:, :Tw, :Tc]
            diac_ids = diac_ids[:, :Tw, :Tc, :]
            subword_lengths = subword_lengths[:, :Tw]
        Nb, Tw, Tc = char_ids.shape
        # Tw = min(Tw, word_ids.shape[1])
        #^ word_ids: [b tt]
        #^ char_ids: [b tw tc]
        # wids_flat = word_ids[:, Tw].reshape(Nb * Tw, 1)
        # cids_flat = char_ids[:, Tw].reshape(Nb * Tw, 1, Tc)
        # z = self.sentence_diac(wids_flat, cids_flat)
        sent_word_strides = subword_lengths.cumsum(1)
        assert tuple(subword_lengths.shape) == (Nb, Tw), f"{subword_lengths.shape} != {(Nb, Tw)=}"
        max_tokens_per_word: int = subword_lengths.max().int().item()
        # Repack each word's subword tokens into its own row so that every word is
        # diacritized in isolation (context-free baseline).
        word_x = T.zeros(Nb, Tw, max_tokens_per_word).to(toke_ids)
        for i_b in range(toke_ids.shape[0]):
            sent_i = toke_ids[i_b]
            start_iw = 0
            for i_word, end_iw in enumerate(sent_word_strides[i_b]):
                if end_iw == start_iw:
                    break
                word = sent_i[start_iw:end_iw]
                word_x[i_b, i_word, 0 : end_iw - start_iw] = word
                start_iw = end_iw
        #^ word_x: [b tw tt]
        word_x = word_x.reshape(Nb * Tw, max_tokens_per_word)
        cids_flat = char_ids.reshape(Nb * Tw, 1, Tc)
        word_lengths = subword_lengths.reshape(Nb * Tw, 1)
        z = self.sentence_diac(
            word_x,
            cids_flat,
            diac_ids.reshape(Nb * Tw, Tc, -1),
            subword_lengths=word_lengths,
        )
        # Nc = z.shape[-1]
        #^ z: [b*tw, 1, tc, "13"]
        z = z.reshape(Nb, Tw, Tc, -1)
        return z
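
    # Illustrative note (hypothetical numbers, not from the original source): with
    # subword_lengths = [[2, 1, 3]], sent_word_strides = [[2, 3, 6]], so tokens 0:2
    # land in word_x[0, 0], token 2:3 in word_x[0, 1] and tokens 3:6 in word_x[0, 2];
    # each row is then passed to sentence_diac as if it were a one-word sentence.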

    @T.jit.ignore
    def forward(
        self,
        word_ids: T.Tensor,
        char_ids: T.Tensor,
        _labels: T.Tensor,
        # ground_truth: T.Tensor,
        # padding_mask: T.BoolTensor,
        *,
        eval_only: str = None,
        subword_lengths: T.Tensor,
        return_extra: bool = False,
        do_partial: bool = False,
    ):
        # assert self._built and not self.training
        assert not self.training
        #^ word_ids: [b tw]
        #^ char_ids: [b tw tc]
        #^ ground_truth: [b tw tc]
        padding_mask = char_ids.eq(0)
        #^ padding_mask: [b tw tc]
        # Note: the `True or` makes the sentence-level branch run unconditionally,
        # even when eval_only == 'base'.
        if True or eval_only != 'base':
            y_ctxt = self.sentence_diac(
                word_ids,
                char_ids,
                _labels,
                subword_lengths=subword_lengths,
            )
            out_shape = y_ctxt.shape[:-1]
        else:
            out_shape = self.sentence_diac._slim_batch_size(
                word_ids,
                char_ids,
                _labels,
                subword_lengths,
            )[1].shape
        #^ y_ctxt: [b tw tc "13"]
        if eval_only == 'ctxt':
            return y_ctxt.argmax(-1)
        y_base = self.word_diac(
            word_ids,
            char_ids,
            _labels,
            subword_lengths,
            shape=out_shape,
        )
        #^ y_base: [b tw tc "13"]
        if eval_only == 'base':
            return y_base.argmax(-1)
        #! TODO: Return the logits.
        ypred_ctxt = y_ctxt.argmax(-1)
        ypred_base = y_base.argmax(-1)
        #^ ypred: [b tw tc _]
        # Maybe for eval
        # ypred_ctxt[~((ypred_base == ground_truth) & (~padding_mask))] = self.no_diac_id
        # return ypred_ctxt
        if do_partial:
            # Keep only the diacritics where the sentence-level model disagrees with
            # the context-free word-level model; everything else is left unmarked.
            ypred_ctxt[(padding_mask) | (ypred_base == ypred_ctxt)] = self.no_diac_id
        if not return_extra:
            return ypred_ctxt
        else:
            return PartialDiacOutput(ypred_ctxt, y_ctxt, y_base)
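
    # Illustrative note (hypothetical values, not from the original source) of the
    # partial-output rule above: where the two models agree, the character is left
    # unmarked (no_diac_id = 0), so only disagreements keep the sentence-level mark:
    #   ypred_base = [[2, 1, 4]]
    #   ypred_ctxt = [[2, 3, 4]]
    #   partial    = [[0, 3, 0]]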

    def step(self, xt, yt, mask=None):
        raise NotImplementedError
        xt[1] = xt[1].to(self.device)
        xt[2] = xt[2].to(self.device)
        yt = yt.to(self.device)
        #^ yt: [b ts tw]
        diac, _ = self(*xt)  # xt: (word_ids, char_ids, _labels)
        loss = self.closs(diac.view(-1, self.num_classes), yt.view(-1))
        return loss

    def predict_partial(
        self,
        dataloader,
        return_extra=False,
        eval_only: str = None,
        do_partial=True,
    ):
        training = self.training
        self.eval()
        preds = {
            'haraka': [],
            'shadda': [],
            'tanween': [],
            'diacs': [],
            'y_ctxt': [],
            'y_base': [],
            'subword_lengths': [],
        }
        print("> Predicting...")
        # breakpoint()
        for i_batch, (inputs, _, subword_lengths) in enumerate(tqdm(dataloader)):
            # if i_batch > 10:
            #     break
            #^ inputs: [toke_ids, char_ids, diac_ids]
            inputs[0] = inputs[0].to(self.device)  #< toke_ids
            inputs[1] = inputs[1].to(self.device)  #< char_ids
            # inputs[2] = inputs[2].to(self.device)  #< diac_ids
            if self._use_d2:
                # For D2, every non-pad token is treated as a single-subword word.
                subword_lengths = T.ones_like(inputs[0])
                subword_lengths[inputs[0] == 0] = 0
            with T.no_grad():
                output = self(
                    *inputs,
                    subword_lengths=subword_lengths,
                    return_extra=return_extra,
                    eval_only=eval_only,
                    do_partial=do_partial,
                )
            # output = np.argmax(T.softmax(output.detach(), dim=-1).cpu().numpy(), axis=-1)
            if return_extra:
                assert isinstance(output, PartialDiacOutput)
                marks = output.preds_hard
                if eval_only == 'recalibrated':
                    marks = (output.preds_ctxt_logit + output.preds_base_logit).argmax(-1)
                preds['diacs'].extend(list(marks.detach().cpu().numpy()))
                preds['y_ctxt'].extend(list(output.preds_ctxt_logit.detach().cpu().numpy()))
                preds['y_base'].extend(list(output.preds_base_logit.detach().cpu().numpy()))
                preds['subword_lengths'].extend(list(subword_lengths.detach().cpu().numpy()))
            else:
                assert isinstance(output, T.Tensor)
                marks = output
                preds['diacs'].extend(list(marks.detach().cpu().numpy()))
            #^ [b ts tw]
            haraka, tanween, shadda = flat_2_3head(marks)
            preds['haraka'].extend(haraka)
            preds['tanween'].extend(tanween)
            preds['shadda'].extend(shadda)
        self.train(training)
        return {
            'diacritics': (
                #! FIXME! Due to batch slimming, output diacritics may need padding.
                np.array(preds['haraka']),
                np.array(preds["tanween"]),
                np.array(preds["shadda"]),
            ),
            'other': (  # Would be empty when !return_extra
                np.array(preds['y_ctxt']),
                np.array(preds['y_base']),
                np.array(preds['diacs']),
                np.array(preds['subword_lengths']),
            ),
        }

    def predict(self, dataloader):
        training = self.training
        self.eval()
        preds = {'haraka': [], 'shadda': [], 'tanween': []}
        print("> Predicting...")
        for inputs, _ in tqdm(dataloader, total=len(dataloader)):
            inputs[0] = inputs[0].to(self.device)
            inputs[1] = inputs[1].to(self.device)
            output = self(*inputs)
            # output = np.argmax(T.softmax(output.detach(), dim=-1).cpu().numpy(), axis=-1)
            marks = output
            #^ [b ts tw]
            haraka, tanween, shadda = flat_2_3head(marks)
            preds['haraka'].extend(haraka)
            preds['tanween'].extend(tanween)
            preds['shadda'].extend(shadda)
        self.train(training)
        return (
            np.array(preds['haraka']),
            np.array(preds["tanween"]),
            np.array(preds["shadda"]),
        )
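
# Hedged usage sketch (not part of the original module): how PartialDD is typically
# driven. The config path below is the one referenced in the comments of
# PartialDD.__init__; the checkpoint path and the dataloader are placeholders that
# depend on the rest of the repo (the loader must yield (inputs, _, subword_lengths)
# with inputs = [toke_ids, char_ids, diac_ids]).
#
#   with open('./configs/dd/config_d2.yaml', 'r', encoding='utf-8') as fin:
#       config = yaml.safe_load(fin)
#   model = PartialDD(config)
#   model.load_state_dict(T.load('<checkpoint path>', map_location='cpu'))
#   outputs = model.predict_partial(dataloader, return_extra=True)
#   haraka, tanween, shadda = outputs['diacritics']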