from typing import NamedTuple, Optional

import yaml
from tqdm import tqdm
import numpy as np
import torch as T
from torch import nn
from torch.nn import functional as F

from diac_utils import flat_2_3head
from model_dd import DiacritizerD2
from model_plm import Diacritizer

class Readout(nn.Module):
    def __init__(
        self,
        in_size: int,
        out_size: int,
    ):
        super().__init__()
        self.W1 = nn.Linear(in_size, in_size)
        self.W2 = nn.Linear(in_size, out_size)

    def forward(self, x: T.Tensor):
        z = self.W1(x)
        z = T.tanh(z)
        # Project the tanh-activated features (not the raw input) to the output size.
        z = self.W2(z)
        return z
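
# Usage sketch for Readout (shapes are assumptions for illustration, not taken
# from the callers below):
#   readout = Readout(in_size=256, out_size=13)
#   logits = readout(T.randn(2, 10, 256))  # -> [2, 10, 13] per-character class logits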


class WordDD_LSTM(nn.Module):
    def __init__(
        self,
        feature_size: int,
        num_classes: int = 13,
        return_logits: bool = True,
    ):
        super().__init__()
        self.feature_size = feature_size
        self.num_classes = num_classes
        self.return_logits = return_logits
        # nn.LSTM needs both input and hidden sizes; keep the hidden size equal
        # to the feature size so the readout input matches.
        self.cell = nn.LSTM(feature_size, feature_size, batch_first=True)
        self.head = Readout(feature_size, num_classes)

    def forward(self, x: T.Tensor):
        #^ x: [b tc dc]
        # nn.LSTM returns (output, (h_n, c_n)); only the per-step outputs are needed.
        z, _ = self.cell(x)
        #^ z: [b tc @dc]
        y = self.head(z)
        #^ y: [b tc Classes]
        yhat = y
        if not self.return_logits:
            # Normalize over the class dimension (last axis).
            yhat = F.softmax(yhat, dim=-1)
        #^ yhat: [b tc @Classes]
        return yhat
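
# Usage sketch for WordDD_LSTM (assumed feature size; this module mirrors the
# word-level baseline and is not wired into PartialDD below):
#   word_dd = WordDD_LSTM(feature_size=256, num_classes=13, return_logits=False)
#   probs = word_dd(T.randn(4, 12, 256))  # -> [4, 12, 13] per-character class probabilities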


class PartialDiacOutput(NamedTuple):
    preds_hard: T.Tensor
    preds_ctxt_logit: T.Tensor
    preds_base_logit: T.Tensor


class PartialDD(nn.Module):
    def __init__(
        self,
        config: dict,
        **kwargs
    ):
        super().__init__()
        self._built = False
        self.no_diac_id = 0
        self._dummy = nn.Parameter(T.ones(1, 1))

        # with open('./configs/dd/config_d2.yaml', 'r', encoding='utf-8') as fin:
        #     self.config_d2 = yaml.safe_load(fin)
        # self.device = T.device('cuda' if T.cuda.is_available() else 'cpu')

        self.config = config
        self._use_d2 = config["model-name"] == "D2"
        if self._use_d2:
            self.sentence_diac = DiacritizerD2(self.config)
        else:
            self.sentence_diac = Diacritizer(self.config, load_pretrained=False)
        # self.sentence_diac.to(self.device)

        # self.build()
        # self.word_diac = WordDD_LSTM(feature_size, num_classes=13, return_logits=False)
        self.eval()

    @property
    def device(self):
        return self._dummy.device

    @property
    def tokenizer(self):
        return self.sentence_diac.tokenizer

    def load_state_dict(
        self,
        state_dict: dict,
        strict: bool = True,
    ):
        self.sentence_diac.load_state_dict(state_dict, strict=strict)

    def _slim_batch(
        self,
        toke_ids: T.Tensor,
        char_ids: T.Tensor,
        diac_ids: T.Tensor,
        subword_lengths: T.Tensor,
    ):
        #^ toke_ids: [b tt]
        #^ char_ids: [b tw tc]
        #^ diac_ids: [b tw tc "13"]
        #^ subword_lengths: [b tw]
        # Trim each padded dimension down to the longest real entry in the batch.
        token_nonpad_mask = toke_ids.ne(self.tokenizer.pad_token_id)
        Ttoken = token_nonpad_mask.sum(1).max()
        toke_ids = toke_ids[:, :Ttoken]

        char_nonpad_mask = char_ids.ne(0)
        Tword = char_nonpad_mask.any(2).sum(1).max()
        Tchar = char_nonpad_mask.sum(2).max()
        char_ids = char_ids[:, :Tword, :Tchar]
        diac_ids = diac_ids[:, :Tword, :Tchar]
        subword_lengths = subword_lengths[:, :Tword]

        return toke_ids, char_ids, diac_ids, subword_lengths
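
    # Example of the trimming above (illustrative sizes, not from the dataset):
    # a batch padded to 128 tokens / 50 words / 12 chars whose longest real
    # sentence has 37 tokens, 20 words, and 9 chars per word comes back as
    #   toke_ids: [b, 37], char_ids: [b, 20, 9], diac_ids: [b, 20, 9, 13],
    #   subword_lengths: [b, 20]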

    @T.jit.export
    def word_diac(
        self,
        toke_ids: T.Tensor,
        char_ids: T.Tensor,
        diac_ids: T.Tensor,
        subword_lengths: T.Tensor,
        *,
        shape: Optional[tuple] = None,
    ):
        if shape is None:
            toke_ids, char_ids, diac_ids, subword_lengths = self._slim_batch(
                toke_ids, char_ids, diac_ids, subword_lengths
            )
        else:
            Nb, Tw, Tc = shape
            toke_ids = toke_ids[:, :]
            char_ids = char_ids[:, :Tw, :Tc]
            diac_ids = diac_ids[:, :Tw, :Tc, :]
            subword_lengths = subword_lengths[:, :Tw]
        Nb, Tw, Tc = char_ids.shape
        # Tw = min(Tw, word_ids.shape[1])
        #^ word_ids: [b tt]
        #^ char_ids: [b tw tc]
        # wids_flat = word_ids[:, Tw].reshape(Nb * Tw, 1)
        # cids_flat = char_ids[:, Tw].reshape(Nb * Tw, 1, Tc)
        # z = self.sentence_diac(wids_flat, cids_flat)

        sent_word_strides = subword_lengths.cumsum(1)
        assert tuple(subword_lengths.shape) == (Nb, Tw), f"{subword_lengths.shape} != {(Nb, Tw)=}"
        max_tokens_per_word: int = subword_lengths.max().int().item()
        # Scatter each word's subword tokens into its own row so every word can be
        # diacritized as an isolated "sentence".
        word_x = T.zeros(Nb, Tw, max_tokens_per_word).to(toke_ids)
        for i_b in range(toke_ids.shape[0]):
            sent_i = toke_ids[i_b]
            start_iw = 0
            for i_word, end_iw in enumerate(sent_word_strides[i_b]):
                if end_iw == start_iw:
                    break
                word = sent_i[start_iw:end_iw]
                word_x[i_b, i_word, 0 : end_iw - start_iw] = word
                start_iw = end_iw
        #^ word_x: [b tw tt]

        word_x = word_x.reshape(Nb * Tw, max_tokens_per_word)
        cids_flat = char_ids.reshape(Nb * Tw, 1, Tc)
        word_lengths = subword_lengths.reshape(Nb * Tw, 1)
        z = self.sentence_diac(
            word_x,
            cids_flat,
            diac_ids.reshape(Nb * Tw, Tc, -1),
            subword_lengths=word_lengths,
        )
        # Nc = z.shape[-1]
        #^ z: [b*tw, 1, tc, "13"]
        z = z.reshape(Nb, Tw, Tc, -1)
        return z
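
    # Word-level baseline in shapes (illustrative batch of Nb=2 sentences with
    # Tw=20 words and Tc=9 chars): the [2, 20, ...] batch is flattened into 40
    # single-word "sentences", passed through self.sentence_diac, and reshaped
    # back to [2, 20, 9, 13] so it can be compared element-wise with the
    # contextual predictions in forward().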

    @T.jit.ignore
    def forward(
        self,
        word_ids: T.Tensor,
        char_ids: T.Tensor,
        _labels: T.Tensor,
        # ground_truth: T.Tensor,
        # padding_mask: T.BoolTensor,
        *,
        eval_only: Optional[str] = None,
        subword_lengths: T.Tensor,
        return_extra: bool = False,
        do_partial: bool = False,
    ):
        # assert self._built and not self.training
        assert not self.training
        #^ word_ids: [b tw]
        #^ char_ids: [b tw tc]
        #^ ground_truth: [b tw tc]
        padding_mask = char_ids.eq(0)
        #^ padding_mask: [b tw tc]

        # NOTE: the contextual pass is currently always computed ("True or ...").
        if True or eval_only != 'base':
            y_ctxt = self.sentence_diac(
                word_ids,
                char_ids,
                _labels,
                subword_lengths=subword_lengths,
            )
            out_shape = y_ctxt.shape[:-1]
        else:
            out_shape = self.sentence_diac._slim_batch_size(
                word_ids,
                char_ids,
                _labels,
                subword_lengths,
            )[1].shape
        #^ y_ctxt: [b tw tc "13"]
        if eval_only == 'ctxt':
            return y_ctxt.argmax(-1)

        y_base = self.word_diac(
            word_ids,
            char_ids,
            _labels,
            subword_lengths,
            shape=out_shape,
        )
        #^ y_base: [b tw tc "13"]
        if eval_only == 'base':
            return y_base.argmax(-1)

        #! TODO: Return the logits.
        ypred_ctxt = y_ctxt.argmax(-1)
        ypred_base = y_base.argmax(-1)
        #^ ypred: [b tw tc _]

        # Maybe for eval
        # ypred_ctxt[~((ypred_base == ground_truth) & (~padding_mask))] = self.no_diac_id
        # return ypred_ctxt

        if do_partial:
            # Keep a contextual prediction only where it disagrees with the
            # word-level baseline; padded positions and agreements get no diacritic.
            ypred_ctxt[(padding_mask) | (ypred_base == ypred_ctxt)] = self.no_diac_id

        if not return_extra:
            return ypred_ctxt
        else:
            return PartialDiacOutput(ypred_ctxt, y_ctxt, y_base)
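
    # Partial-output semantics, sketched on a toy row (values are illustrative):
    #   ypred_base = [[3, 1, 0]]   # word-level baseline classes
    #   ypred_ctxt = [[3, 4, 0]]   # contextual classes
    #   after masking with no_diac_id=0 -> [[0, 4, 0]]
    # i.e. only characters where context changes the decision keep a diacritic.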

    def step(self, xt, yt, mask=None):
        raise NotImplementedError
        xt[1] = xt[1].to(self.device)
        xt[2] = xt[2].to(self.device)
        yt = yt.to(self.device)
        #^ yt: [b ts tw]
        diac, _ = self(*xt)  # xt: (word_ids, char_ids, _labels)
        loss = self.closs(diac.view(-1, self.num_classes), yt.view(-1))
        return loss

    def predict_partial(
        self,
        dataloader,
        return_extra=False,
        eval_only: Optional[str] = None,
        do_partial=True,
    ):
        training = self.training
        self.eval()

        preds = {
            'haraka': [],
            'shadda': [],
            'tanween': [],
            'diacs': [],
            'y_ctxt': [],
            'y_base': [],
            'subword_lengths': [],
        }
        print("> Predicting...")
        # breakpoint()
        for i_batch, (inputs, _, subword_lengths) in enumerate(tqdm(dataloader)):
            # if i_batch > 10:
            #     break
            #^ inputs: [toke_ids, char_ids, diac_ids]
            inputs[0] = inputs[0].to(self.device)  #< toke_ids
            inputs[1] = inputs[1].to(self.device)  #< char_ids
            # inputs[2] = inputs[2].to(self.device)  #< diac_ids
            if self._use_d2:
                subword_lengths = T.ones_like(inputs[0])
                subword_lengths[inputs[0] == 0] = 0
            with T.no_grad():
                output = self(
                    *inputs,
                    subword_lengths=subword_lengths,
                    return_extra=return_extra,
                    eval_only=eval_only,
                    do_partial=do_partial,
                )
            # output = np.argmax(T.softmax(output.detach(), dim=-1).cpu().numpy(), axis=-1)

            if return_extra:
                assert isinstance(output, PartialDiacOutput)
                marks = output.preds_hard
                if eval_only == 'recalibrated':
                    marks = (output.preds_ctxt_logit + output.preds_base_logit).argmax(-1)
                preds['diacs'].extend(list(marks.detach().cpu().numpy()))
                preds['y_ctxt'].extend(list(output.preds_ctxt_logit.detach().cpu().numpy()))
                preds['y_base'].extend(list(output.preds_base_logit.detach().cpu().numpy()))
                preds['subword_lengths'].extend(list(subword_lengths.detach().cpu().numpy()))
            else:
                assert isinstance(output, T.Tensor)
                marks = output
                preds['diacs'].extend(list(marks.detach().cpu().numpy()))
            #^ [b ts tw]

            haraka, tanween, shadda = flat_2_3head(marks)
            preds['haraka'].extend(haraka)
            preds['tanween'].extend(tanween)
            preds['shadda'].extend(shadda)

        self.train(training)
        return {
            'diacritics': (
                #! FIXME! Due to batch slimming, output diacritics may need padding.
                np.array(preds['haraka']),
                np.array(preds['tanween']),
                np.array(preds['shadda']),
            ),
            'other': (  # Empty (except 'diacs') when return_extra is False.
                np.array(preds['y_ctxt']),
                np.array(preds['y_base']),
                np.array(preds['diacs']),
                np.array(preds['subword_lengths']),
            ),
        }

    def predict(self, dataloader):
        training = self.training
        self.eval()

        preds = {'haraka': [], 'shadda': [], 'tanween': []}
        print("> Predicting...")
        for inputs, _ in tqdm(dataloader, total=len(dataloader)):
            inputs[0] = inputs[0].to(self.device)
            inputs[1] = inputs[1].to(self.device)

            output = self(*inputs)
            # output = np.argmax(T.softmax(output.detach(), dim=-1).cpu().numpy(), axis=-1)
            marks = output
            #^ [b ts tw]

            haraka, tanween, shadda = flat_2_3head(marks)
            preds['haraka'].extend(haraka)
            preds['tanween'].extend(tanween)
            preds['shadda'].extend(shadda)

        self.train(training)
        return (
            np.array(preds['haraka']),
            np.array(preds['tanween']),
            np.array(preds['shadda']),
        )
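

# End-to-end usage sketch. The config key ("model-name") follows what __init__
# reads; the config path mirrors the commented-out loading above, while the
# checkpoint file and dataloader are assumptions for illustration only.
#   with open('./configs/dd/config_d2.yaml', 'r', encoding='utf-8') as fin:
#       config = yaml.safe_load(fin)
#   model = PartialDD(config)
#   model.load_state_dict(T.load('checkpoint.pt', map_location=model.device))
#   out = model.predict_partial(eval_dataloader, return_extra=True, do_partial=True)
#   haraka, tanween, shadda = out['diacritics']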