Spaces:
Runtime error
Runtime error
import numpy as np | |
import torch as T | |
from tqdm import tqdm | |
from torch import nn | |
from torch.nn import functional as F | |
from components.k_lstm import K_LSTM | |
from components.attention import Attention | |
from data_utils import DatasetUtils | |
from diac_utils import flat2_3head, flat_2_3head | |
class DiacritizerD2(nn.Module):
    """Hierarchical diacritizer: sentence-level BiLSTM + character-level BiLSTM + attention.

    The model is configured in ``__init__`` from ``config["train"]`` and its
    layers are created later by :meth:`build`, which must be called before
    :meth:`forward`, :meth:`step` or :meth:`predict`.

    Args:
        config: Mapping with a ``"train"`` sub-mapping holding the
            hyper-parameters read below.
        device: Device string used for tensors created by the model. When
            ``None`` (the default) CUDA is used if available, otherwise CPU.
    """

    def __init__(self, config, device=None):
        super(DiacritizerD2, self).__init__()
        self.max_word_len = config["train"]["max-word-len"]
        self.max_sent_len = config["train"]["max-sent-len"]
        self.char_embed_dim = config["train"]["char-embed-dim"]

        self.final_dropout_p = config["train"]["final-dropout"]
        self.sent_dropout_p = config["train"]["sent-dropout"]
        self.diac_dropout_p = config["train"]["diac-dropout"]
        self.vertical_dropout = config['train']['vertical-dropout']
        self.recurrent_dropout = config['train']['recurrent-dropout']
        self.recurrent_dropout_mode = config['train'].get('recurrent-dropout-mode', 'gal_tied')
        self.recurrent_activation = config['train'].get('recurrent-activation', 'sigmoid')

        self.sent_lstm_units = config["train"]["sent-lstm-units"]
        self.word_lstm_units = config["train"]["word-lstm-units"]
        self.decoder_units = config["train"]["decoder-units"]

        self.sent_lstm_layers = config["train"]["sent-lstm-layers"]
        self.word_lstm_layers = config["train"]["word-lstm-layers"]

        self.cell = config['train'].get('rnn-cell', 'lstm')
        self.num_layers = config["train"].get("num-layers", 2)
        self.RNN_Layer = K_LSTM

        self.batch_first = config['train'].get('batch-first', True)
        # An explicit device wins; otherwise keep the original auto-detection.
        if device is None:
            device = 'cuda' if T.cuda.is_available() else 'cpu'
        self.device = device
        # Size of the flat diacritic label space produced by the classifier.
        self.num_classes = 15

    def build(self, wembs: T.Tensor, abjad_size: int):
        """Create all layers.

        Args:
            wembs: Pre-trained word-embedding table; it is converted to a
                float32 tensor and indexed by word id in :meth:`forward`.
                The sentence LSTM is built with ``input_size=300``, so the
                embedding width is presumably 300 -- verify against the data.
            abjad_size: Number of rows of the character embedding table.
        """
        self.closs = F.cross_entropy
        self.bloss = F.binary_cross_entropy_with_logits

        rnn_kargs = dict(
            recurrent_dropout_mode=self.recurrent_dropout_mode,
            recurrent_activation=self.recurrent_activation,
        )

        self.sent_lstm = self.RNN_Layer(
            input_size=300,
            hidden_size=self.sent_lstm_units,
            num_layers=self.sent_lstm_layers,
            bidirectional=True,
            vertical_dropout=self.vertical_dropout,
            recurrent_dropout=self.recurrent_dropout,
            batch_first=self.batch_first,
            **rnn_kargs,
        )

        self.word_lstm = self.RNN_Layer(
            input_size=self.sent_lstm_units * 2 + self.char_embed_dim,
            hidden_size=self.word_lstm_units,
            num_layers=self.word_lstm_layers,
            bidirectional=True,
            vertical_dropout=self.vertical_dropout,
            recurrent_dropout=self.recurrent_dropout,
            batch_first=self.batch_first,
            return_states=True,
            **rnn_kargs,
        )

        self.char_embs = nn.Embedding(
            abjad_size,
            self.char_embed_dim,
            padding_idx=0,
        )

        self.attention = Attention(
            kind="dot",
            query_dim=self.word_lstm_units * 2,
            input_dim=self.sent_lstm_units * 2,
        )

        # Kept as a plain attribute (not a parameter/buffer), so it is not
        # part of the state dict and is not moved by ``Module.to``.
        self.word_embs = T.tensor(wembs).clone().to(dtype=T.float32)
        self.word_embs = self.word_embs.to(self.device)

        self.classifier = nn.Linear(self.attention.Dout + self.word_lstm_units * 2, self.num_classes)
        self.dropout = nn.Dropout(self.final_dropout_p)

    def forward(self, sents, words, labels=None, subword_lengths=None):
        """Compute per-character diacritic logits.

        Args:
            sents: Word ids, ``[b ts]``.
            words: Character ids, ``[b ts tw]``; id 0 is treated as padding.
            labels: Unused here; accepted so the same input tuple can be
                passed as for the other diacritizer.
            subword_lengths: Unused.

        Returns:
            Logits of shape ``[b ts tw num_classes]`` (first two axes swapped
            when ``batch_first`` is false).
        """
        #^ sents : [b ts]
        #^ words : [b ts tw]
        #^ labels: [b ts tw]
        max_words = min(self.max_sent_len, sents.shape[1])

        word_mask = words.ne(0.).float()
        #^ word_mask: [b ts tw]

        if self.training:
            # Word-level dropout: zero out whole word ids with prob. sent_dropout_p.
            q = 1.0 - self.sent_dropout_p
            # Build the mask on the same device as `sents` so the product
            # below cannot hit a CPU/CUDA mismatch.
            sdo = T.bernoulli(T.full(sents.shape, q, device=sents.device))
            sents_do = sents * sdo.long()
            #^ sents_do : [b ts] ; DO(ts)
            wembs = self.word_embs[sents_do]
            #^ wembs : [b ts dw] ; DO(ts)
        else:
            wembs = self.word_embs[sents]
            #^ wembs : [b ts dw]

        sent_enc = self.sent_lstm(wembs.to(self.device))
        #^ sent_enc : [b ts dwe]

        sentword_do = sent_enc.unsqueeze(2)
        #^ sentword_do : [b ts _ dwe]

        # Broadcast the word encoding over its characters, zeroing padding.
        sentword_do = self.dropout(sentword_do * word_mask.unsqueeze(-1))
        #^ sentword_do : [b ts tw dwe]

        word_index = words.view(-1, self.max_word_len)
        #^ word_index: [b*ts tw]

        cembs = self.char_embs(word_index)
        #^ cembs : [b*ts tw dc]

        sentword_do = sentword_do.view(-1, self.max_word_len, self.sent_lstm_units * 2)
        #^ sentword_do : [b*ts tw dwe]

        char_embs = T.cat([cembs, sentword_do], dim=-1)
        #^ char_embs : [b*ts tw dcw] ; dcw = dc + dwe

        char_enc, _ = self.word_lstm(char_embs)
        #^ char_enc: [b*ts tw dce]

        char_enc_reshaped = char_enc.view(-1, max_words, self.max_word_len, self.word_lstm_units * 2)
        #^ char_enc_reshaped: [b ts tw dce]

        # Zero diagonal, ones elsewhere; passed to the attention layer as
        # `prejudice_mask` (by its name, meant to stop a word attending to itself).
        omit_self_mask = (1.0 - T.eye(max_words)).unsqueeze(0).to(self.device)
        attn_enc, attn_map = self.attention(char_enc_reshaped, sent_enc, word_mask.bool(), prejudice_mask=omit_self_mask)
        #^ attn_enc: [b ts tw dae]

        attn_enc = attn_enc.reshape(-1, self.max_word_len, self.attention.Dout)
        #^ attn_enc: [b*ts tw dae]

        final_vec = T.cat([attn_enc, char_enc], dim=-1)

        diac_out = self.classifier(self.dropout(final_vec))
        #^ diac_out: [b*ts tw num_classes]

        diac_out = diac_out.view(-1, max_words, self.max_word_len, self.num_classes)
        #^ diac_out: [b ts tw num_classes]

        if not self.batch_first:
            diac_out = diac_out.swapaxes(1, 0)

        return diac_out

    def step(self, xt, yt, mask=None):
        """Run one forward pass and return the cross-entropy loss.

        Args:
            xt: Mutable sequence of model inputs; items 1 and 2 are moved to
                ``self.device`` in place.
            yt: Flat class targets, ``[b ts tw]``.
            mask: Unused.
        """
        xt[1] = xt[1].to(self.device)
        xt[2] = xt[2].to(self.device)
        yt = yt.to(self.device)
        #^ yt: [b ts tw]

        # forward() returns a single tensor; unpacking it as a pair would
        # iterate over the batch dimension instead.
        diac = self(*xt)
        loss = self.closs(diac.view(-1, self.num_classes), yt.view(-1))
        return loss

    def predict(self, dataloader):
        """Predict diacritics for every batch of ``dataloader``.

        The model is switched to eval mode for the duration of the call and
        its previous train/eval mode is restored afterwards, even on error.

        Returns:
            Tuple ``(haraka, tanween, shadda)`` of numpy arrays.
        """
        training = self.training
        self.eval()

        preds = {'haraka': [], 'shadda': [], 'tanween': []}
        print("> Predicting...")
        try:
            # Inference only: no autograd graph is needed.
            with T.no_grad():
                for inputs, _ in tqdm(dataloader, total=len(dataloader)):
                    inputs[0] = inputs[0].to(self.device)
                    inputs[1] = inputs[1].to(self.device)
                    diac = self(*inputs)

                    output = np.argmax(T.softmax(diac.detach(), dim=-1).cpu().numpy(), axis=-1)
                    #^ [b ts tw]

                    haraka, tanween, shadda = flat_2_3head(output)

                    preds['haraka'].extend(haraka)
                    preds['tanween'].extend(tanween)
                    preds['shadda'].extend(shadda)
        finally:
            self.train(training)

        return (
            np.array(preds['haraka']),
            np.array(preds["tanween"]),
            np.array(preds["shadda"]),
        )
class DiacritizerD3(nn.Module):
    """Hierarchical diacritizer with an autoregressive LSTM decoder.

    Same sentence/character encoder and attention layout as the D2 model,
    followed by a unidirectional LSTM decoder that is additionally fed an
    8-wide label vector per character (gold labels in :meth:`forward`, the
    previous prediction in :meth:`predict_sample`).

    Layers are created by :meth:`build`, which must be called before use.

    Args:
        config: Mapping with a ``"train"`` sub-mapping of hyper-parameters.
        device: Device string used for tensors created by the model.
    """

    def __init__(self, config, device='cuda'):
        super(DiacritizerD3, self).__init__()
        self.max_word_len = config["train"]["max-word-len"]
        self.max_sent_len = config["train"]["max-sent-len"]
        self.char_embed_dim = config["train"]["char-embed-dim"]

        self.sent_dropout_p = config["train"]["sent-dropout"]
        self.diac_dropout_p = config["train"]["diac-dropout"]
        self.vertical_dropout = config['train']['vertical-dropout']
        self.recurrent_dropout = config['train']['recurrent-dropout']
        self.recurrent_dropout_mode = config['train'].get('recurrent-dropout-mode', 'gal_tied')
        self.recurrent_activation = config['train'].get('recurrent-activation', 'sigmoid')

        self.sent_lstm_units = config["train"]["sent-lstm-units"]
        self.word_lstm_units = config["train"]["word-lstm-units"]
        self.decoder_units = config["train"]["decoder-units"]

        self.sent_lstm_layers = config["train"]["sent-lstm-layers"]
        self.word_lstm_layers = config["train"]["word-lstm-layers"]

        self.cell = config['train'].get('rnn-cell', 'lstm')
        self.num_layers = config["train"].get("num-layers", 2)
        self.RNN_Layer = K_LSTM

        self.batch_first = config['train'].get('batch-first', True)
        self.baseline = config["train"].get("baseline", False)
        self.device = device
        # Size of the flat diacritic label space produced by the classifier.
        self.num_classes = 15

    def build(self, wembs: T.Tensor, abjad_size: int):
        """Create all layers.

        Args:
            wembs: Pre-trained word-embedding table; copied into a float32
                tensor and indexed by word id. The sentence LSTM is built
                with ``input_size=300``, so the embedding width is
                presumably 300 -- verify against the data.
            abjad_size: Number of rows of the character embedding table.
        """
        self.closs = F.cross_entropy
        self.bloss = F.binary_cross_entropy_with_logits

        rnn_kargs = dict(
            recurrent_dropout_mode=self.recurrent_dropout_mode,
            recurrent_activation=self.recurrent_activation,
        )

        self.sent_lstm = self.RNN_Layer(
            input_size=300,
            hidden_size=self.sent_lstm_units,
            num_layers=self.sent_lstm_layers,
            bidirectional=True,
            vertical_dropout=self.vertical_dropout,
            recurrent_dropout=self.recurrent_dropout,
            batch_first=self.batch_first,
            **rnn_kargs,
        )

        self.word_lstm = self.RNN_Layer(
            input_size=self.sent_lstm_units * 2 + self.char_embed_dim,
            hidden_size=self.word_lstm_units,
            num_layers=self.word_lstm_layers,
            bidirectional=True,
            vertical_dropout=self.vertical_dropout,
            recurrent_dropout=self.recurrent_dropout,
            batch_first=self.batch_first,
            return_states=True,
            **rnn_kargs,
        )

        self.char_embs = nn.Embedding(
            abjad_size,
            self.char_embed_dim,
            padding_idx=0,
        )

        self.attention = Attention(
            kind="dot",
            query_dim=self.word_lstm_units * 2,
            input_dim=self.sent_lstm_units * 2,
        )

        # Decoder input = char encoding + attention output + 8-wide label vector.
        self.lstm_decoder = self.RNN_Layer(
            input_size=self.word_lstm_units * 2 + self.attention.Dout + 8,
            hidden_size=self.word_lstm_units * 2,
            num_layers=1,
            bidirectional=False,
            vertical_dropout=self.vertical_dropout,
            recurrent_dropout=self.recurrent_dropout,
            batch_first=self.batch_first,
            return_states=True,
            **rnn_kargs,
        )

        # Plain attribute: stays where it is created (not moved to
        # self.device); looked-up rows are moved to the device in forward().
        self.word_embs = T.tensor(wembs, dtype=T.float32)

        self.classifier = nn.Linear(self.lstm_decoder.hidden_size, self.num_classes)
        self.dropout = nn.Dropout(0.2)

    def forward(self, sents, words, labels):
        """Teacher-forced forward pass.

        Args:
            sents: Word ids, ``[b ts]``.
            words: Character ids, ``[b ts tw]``; id 0 is treated as padding.
            labels: Label vectors fed to the decoder, ``[b ts tw 8]``.

        Returns:
            ``(diac_out, attn_map)`` where ``diac_out`` has shape
            ``[b ts tw num_classes]`` (first two axes swapped when
            ``batch_first`` is false) and ``attn_map`` is whatever the
            attention layer returns as its second output.
        """
        #^ sents : [b ts]
        #^ words : [b ts tw]
        #^ labels: [b ts tw 8]

        word_mask = words.ne(0.).float()
        #^ word_mask: [b ts tw]

        if self.training:
            # Word-level dropout: zero out whole word ids with prob. sent_dropout_p.
            q = 1.0 - self.sent_dropout_p
            sdo = T.bernoulli(T.full(sents.shape, q))
            sents_do = sents * sdo.long()
            #^ sents_do : [b ts] ; DO(ts)
            wembs = self.word_embs[sents_do]
            #^ wembs : [b ts dw] ; DO(ts)
        else:
            wembs = self.word_embs[sents]
            #^ wembs : [b ts dw]

        sent_enc = self.sent_lstm(wembs.to(self.device))
        #^ sent_enc : [b ts dwe]

        sentword_do = sent_enc.unsqueeze(2)
        #^ sentword_do : [b ts _ dwe]

        sentword_do = self.dropout(sentword_do * word_mask.unsqueeze(-1))
        #^ sentword_do : [b ts tw dwe]

        word_index = words.view(-1, self.max_word_len)
        #^ word_index: [b*ts tw]

        cembs = self.char_embs(word_index)
        #^ cembs : [b*ts tw dc]

        sentword_do = sentword_do.view(-1, self.max_word_len, self.sent_lstm_units * 2)
        #^ sentword_do : [b*ts tw dwe]

        char_embs = T.cat([cembs, sentword_do], dim=-1)
        #^ char_embs : [b*ts tw dcw] ; dcw = dc + dwe

        char_enc, _ = self.word_lstm(char_embs)
        #^ char_enc: [b*ts tw dce]

        char_enc_reshaped = char_enc.view(-1, self.max_sent_len, self.max_word_len, self.word_lstm_units * 2)
        #^ char_enc_reshaped: [b ts tw dce]

        # Zero diagonal, ones elsewhere; passed to the attention layer as
        # `prejudice_mask` (by its name, meant to stop a word attending to itself).
        omit_self_mask = (1.0 - T.eye(self.max_sent_len)).unsqueeze(0).to(self.device)
        attn_enc, attn_map = self.attention(char_enc_reshaped, sent_enc, word_mask.bool(), prejudice_mask=omit_self_mask)
        #^ attn_enc: [b ts tw dae]

        attn_enc = attn_enc.view(-1, self.max_sent_len*self.max_word_len, self.attention.Dout)
        #^ attn_enc: [b ts*tw dae]

        if self.training and self.diac_dropout_p > 0:
            # Label dropout: zero the whole 8-wide label vector of a character.
            q = 1.0 - self.diac_dropout_p
            ddo = T.bernoulli(T.full(labels.shape[:-1], q))
            labels = labels * ddo.unsqueeze(-1).long().to(self.device)
            #^ labels : [b ts tw 8] ; DO(tw)

        labels = labels.view(-1, self.max_sent_len*self.max_word_len, 8).float()
        #^ labels: [b ts*tw 8]

        char_enc = char_enc.view(-1, self.max_sent_len*self.max_word_len, self.word_lstm_units * 2)

        final_vec = T.cat([attn_enc, char_enc, labels], dim=-1)
        #^ final_vec: [b ts*tw dae+dce+8]

        # The decoder runs over the whole sentence as one flat sequence.
        dec_out, _ = self.lstm_decoder(final_vec)
        #^ dec_out: [b ts*tw du]
        dec_out = dec_out.reshape(-1, self.max_word_len, self.lstm_decoder.hidden_size)
        #^ dec_out: [b*ts tw du]

        diac_out = self.classifier(self.dropout(dec_out))
        #^ diac_out: [b*ts tw num_classes]

        diac_out = diac_out.view(-1, self.max_sent_len, self.max_word_len, self.num_classes)
        #^ diac_out: [b ts tw num_classes]

        if not self.batch_first:
            diac_out = diac_out.swapaxes(1, 0)

        return diac_out, attn_map

    def predict_sample(self, sents, words, labels):
        """Greedy autoregressive decoding.

        The encoder part mirrors :meth:`forward`; the decoder is then run one
        character at a time and fed its own previous prediction in place of
        gold labels. ``labels`` is accepted for signature parity and is not
        read.

        Returns:
            Logits of shape ``[b ts tw num_classes]`` (first two axes swapped
            when ``batch_first`` is false).
        """
        word_mask = words.ne(0.).float()
        #^ word_mask: [b ts tw]

        if self.training:
            q = 1.0 - self.sent_dropout_p
            sdo = T.bernoulli(T.full(sents.shape, q))
            sents_do = sents * sdo.long()
            #^ sents_do : [b ts] ; DO(ts)
            wembs = self.word_embs[sents_do]
            #^ wembs : [b ts dw] ; DO(ts)
        else:
            wembs = self.word_embs[sents]
            #^ wembs : [b ts dw]

        sent_enc = self.sent_lstm(wembs.to(self.device))
        #^ sent_enc : [b ts dwe]

        sentword_do = sent_enc.unsqueeze(2)
        #^ sentword_do : [b ts _ dwe]

        sentword_do = self.dropout(sentword_do * word_mask.unsqueeze(-1))
        #^ sentword_do : [b ts tw dwe]

        word_index = words.view(-1, self.max_word_len)
        #^ word_index: [b*ts tw]

        cembs = self.char_embs(word_index)
        #^ cembs : [b*ts tw dc]

        sentword_do = sentword_do.view(-1, self.max_word_len, self.sent_lstm_units * 2)
        #^ sentword_do : [b*ts tw dwe]

        char_embs = T.cat([cembs, sentword_do], dim=-1)
        #^ char_embs : [b*ts tw dcw] ; dcw = dc + dwe

        char_enc, _ = self.word_lstm(char_embs)
        #^ char_enc: [b*ts tw dce]

        char_enc = char_enc.view(-1, self.max_sent_len, self.max_word_len, self.word_lstm_units*2)
        #^ char_enc: [b ts tw dce]

        omit_self_mask = (1.0 - T.eye(self.max_sent_len)).unsqueeze(0).to(self.device)
        attn_enc, _ = self.attention(char_enc, sent_enc, word_mask.bool(), prejudice_mask=omit_self_mask)
        #^ attn_enc: [b ts tw dae]

        all_out = T.zeros(*char_enc.size()[:-1], self.num_classes).to(self.device)
        #^ all_out: [b ts tw num_classes]

        batch_sz = char_enc.size()[0]
        #^ batch_sz: b

        zeros = T.zeros(1, batch_sz, self.lstm_decoder.hidden_size).to(self.device)
        #^ zeros: [1 b du]

        bos_tag = T.tensor([0,0,0,0,0,1,0,0]).unsqueeze(0)
        #^ bos_tag: [1 8]

        prev_label = T.cat([bos_tag]*batch_sz).to(self.device).float()
        #^ prev_label: [b 8]

        # Loop-invariant lookup table for one-hot encoding the haraka head.
        haraka_eye = T.eye(6)

        for ts in range(self.max_sent_len):
            # Decoder state is reset at every word; prev_label is carried over.
            dec_hx = (zeros, zeros)
            #^ dec_hx: [1 b du]
            for tw in range(self.max_word_len):
                final_vec = T.cat([attn_enc[:,ts,tw,:], char_enc[:,ts,tw,:], prev_label], dim=-1).unsqueeze(1)
                #^ final_vec: [b 1 dae+dce+8]
                dec_out, dec_hx = self.lstm_decoder(final_vec, dec_hx)
                #^ dec_out: [b 1 du]
                # NOTE(review): squeeze(0) only has an effect when b == 1;
                # the batch-size-1 path should be confirmed against K_LSTM.
                dec_out = dec_out.squeeze(0)
                dec_out = dec_out.transpose(0,1)

                logits_raw = self.classifier(self.dropout(dec_out))
                #^ logits_raw: [1 b num_classes]

                out_idx = T.max(T.softmax(logits_raw.squeeze(), dim=-1), dim=-1)[1]

                haraka, tanween, shadda = flat2_3head(out_idx.detach().cpu().numpy())

                haraka_onehot = haraka_eye[haraka].float().to(self.device)
                #^ haraka_onehot: [b 6]

                tanween = T.tensor(tanween).float().unsqueeze(-1).to(self.device)
                shadda = T.tensor(shadda).float().unsqueeze(-1).to(self.device)

                # 6 (haraka one-hot) + 1 (tanween) + 1 (shadda) = 8, matching
                # the label width expected by the decoder input.
                prev_label = T.cat([haraka_onehot, tanween, shadda], dim=-1)

                all_out[:,ts,tw,:] = logits_raw.squeeze()

        if not self.batch_first:
            all_out = all_out.swapaxes(1, 0)

        return all_out

    def step(self, xt, yt, mask=None):
        """Return the cross-entropy loss for one batch.

        Uses the teacher-forced :meth:`forward` in training mode and the
        greedy :meth:`predict_sample` otherwise.

        Args:
            xt: Mutable sequence of model inputs; items 1 and 2 are moved to
                ``self.device`` in place.
            yt: Flat class targets, ``[b ts tw]``.
            mask: Unused.
        """
        xt[1] = xt[1].to(self.device)
        xt[2] = xt[2].to(self.device)
        #^ yt: [b ts tw]
        yt = yt.to(self.device)

        if self.training:
            diac, _ = self(*xt)
        else:
            diac = self.predict_sample(*xt)
        #^ diac : [b ts tw num_classes]

        loss = self.closs(diac.view(-1, self.num_classes), yt.view(-1))
        return loss

    def predict(self, dataloader):
        """Predict diacritics for every batch of ``dataloader``.

        The model is switched to eval mode for the duration of the call and
        its previous train/eval mode is restored afterwards, even on error.

        Returns:
            Tuple ``(haraka, tanween, shadda)`` of numpy arrays.
        """
        training = self.training
        self.eval()

        preds = {'haraka': [], 'shadda': [], 'tanween': []}
        print("> Predicting...")
        try:
            # Inference only: no autograd graph is needed.
            with T.no_grad():
                for inputs, _ in tqdm(dataloader, total=len(dataloader)):
                    inputs[1] = inputs[1].to(self.device)
                    inputs[2] = inputs[2].to(self.device)
                    diac = self.predict_sample(*inputs)

                    output = np.argmax(T.softmax(diac.detach(), dim=-1).cpu().numpy(), axis=-1)
                    #^ [b ts tw]

                    haraka, tanween, shadda = flat_2_3head(output)

                    preds['haraka'].extend(haraka)
                    preds['tanween'].extend(tanween)
                    preds['shadda'].extend(shadda)
        finally:
            self.train(training)

        return (
            np.array(preds['haraka']),
            np.array(preds["tanween"]),
            np.array(preds["shadda"]),
        )
if __name__ == "__main__":
    # Smoke test: build the D2 model on CPU and load a trained checkpoint.
    import yaml

    config_path = "configs/dd/config_d2.yaml"
    model_path = "models/tashkeela-d2.pt"

    with open(config_path, 'r', encoding="utf-8") as file:
        # NOTE(review): FullLoader can construct arbitrary Python objects;
        # only use it on trusted config files (yaml.safe_load otherwise).
        config = yaml.load(file, Loader=yaml.FullLoader)

    data_utils = DatasetUtils(config)
    vocab_size = len(data_utils.letter_list)
    word_embeddings = data_utils.embeddings

    # Construct without a `device` keyword, which the constructor may not
    # accept, and pin the model to CPU before build() so the embedding table
    # is created there.
    model = DiacritizerD2(config)
    model.device = 'cpu'
    model.build(word_embeddings, vocab_size)

    # NOTE(review): T.load unpickles the checkpoint; only load trusted files.
    model.load_state_dict(T.load(model_path, map_location=T.device('cpu'))["state_dict"])