from typing import List, Iterator, cast
import copy

import numpy as np
import torch as T
from torch import nn
from torch.nn import functional as F

from transformers import BertConfig, BertModel
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions


class Diacritizer(nn.Module):
    def __init__(
        self,
        config,
        device=None,
        load_pretrained=True
    ) -> None:
        super().__init__()
        self._dummy = nn.Parameter(T.ones(1))
        if 'modeling' in config:
            config = config['modeling']
        self.config = config

        model_name = config.get('base_model', "CAMeL-Lab/bert-base-arabic-camelbert-mix-ner")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if load_pretrained:
            self.token_model: BertModel = AutoModel.from_pretrained(model_name)
        else:
            marbert_config = AutoConfig.from_pretrained(model_name)
            self.token_model = AutoModel.from_config(marbert_config)

        self.num_classes = 15
        self.diac_model_config = BertConfig(**config['diac_model_config'])
        self.token_model_config: BertConfig = self.token_model.config

        self.char_embs = nn.Embedding(config["num-chars"], embedding_dim=config["char-embed-dim"])
        self.diac_emb_model = self.build_diac_model(self.token_model)

        self.down_project_token_embeds_deep = None
        self.down_project_token_embeds = None
        if 'token_hidden_size' in config:
            if config['token_hidden_size'] == 'auto':
                down_proj_size = self.diac_emb_model.config.hidden_size
            else:
                down_proj_size = config['token_hidden_size']
            if config.get('deep-down-proj', False):
                self.down_project_token_embeds_deep = nn.Sequential(
                    nn.Linear(
                        self.token_model_config.hidden_size + config["char-embed-dim"],
                        down_proj_size * 4,
                        bias=False,
                    ),
                    nn.Tanh(),
                    nn.Linear(
                        down_proj_size * 4,
                        down_proj_size,
                        bias=False,
                    )
                )
            # else:
            self.down_project_token_embeds = nn.Linear(
                self.token_model_config.hidden_size + config["char-embed-dim"],
                down_proj_size,
                bias=False,
            )
            # assert self.down_project_token_embeds_deep is None or self.down_project_token_embeds is None

        classifier_feature_size = self.diac_model_config.hidden_size
        if config.get('deep-cls', False):
            # classifier_feature_size = 512
            self.final_feature_transform = nn.Linear(
                self.diac_model_config.hidden_size
                + self.token_model_config.hidden_size,
                #^ diac_features + [residual from token_model]
                out_features=classifier_feature_size,
                bias=False
            )
        else:
            self.final_feature_transform = None

        self.feature_layer_norm = nn.LayerNorm(classifier_feature_size)
        self.classifier = nn.Linear(classifier_feature_size, self.num_classes, bias=True)

        self.trim_model_(config)

        self.dropout = nn.Dropout(config['dropout'])
        self.sent_dropout_p = config['sentence_dropout']
        self.closs = F.cross_entropy

    def build_diac_model(self, token_model=None):
        if self.config.get('pre-init-diac-model', False):
            #^ Seed the diacritization encoder from the token model itself. This path
            #  expects 'keep-token-model-layers' to be set: the token model keeps its
            #  first `num_layers` layers (see trim_model_), and the diacritization
            #  model reuses the following `num_layers` layers.
            model = copy.deepcopy(self.token_model)
            model.pooler = None
            model.embeddings.word_embeddings = None
            num_layers = self.config.get('keep-token-model-layers', None)
            model.encoder.layer = nn.ModuleList(
                list(model.encoder.layer[num_layers:num_layers*2])
            )
            model.encoder.config.num_hidden_layers = num_layers
        else:
            model = BertModel(self.diac_model_config)
        return model

    def trim_model_(self, config):
        self.token_model.pooler = None
        self.diac_emb_model.pooler = None
        # self.diac_emb_model.embeddings = None
        self.diac_emb_model.embeddings.word_embeddings = None

        num_token_model_kept_layers = config.get('keep-token-model-layers', None)
        if num_token_model_kept_layers is not None:
            self.token_model.encoder.layer = nn.ModuleList(
                list(self.token_model.encoder.layer[:num_token_model_kept_layers])
            )
            self.token_model.encoder.config.num_hidden_layers = num_token_model_kept_layers

        if not config.get('full-finetune', False):
            for param in self.token_model.parameters():
                param.requires_grad = False
            finetune_last_layers = config.get('num-finetune-last-layers', 4)
            if finetune_last_layers > 0:
                unfrozen_layers = self.token_model.encoder.layer[-finetune_last_layers:]
                for layer in unfrozen_layers:
                    for param in layer.parameters():
                        param.requires_grad = True
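
    # Example of (hypothetical) config values consumed by trim_model_ above:
    #   keep-token-model-layers: 4     -> keep only the first 4 encoder layers
    #   full-finetune: False           -> freeze the token model ...
    #   num-finetune-last-layers: 2    -> ... except its last 2 kept layers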

    def get_grouped_params(self):
        downstream_params: Iterator[nn.Parameter] = cast(
            Iterator,
            (param
             for module in (self.diac_emb_model, self.classifier, self.char_embs)
             for param in module.parameters())
        )
        pg = {
            'pretrained': self.token_model.parameters(),
            'downstream': downstream_params,
        }
        return pg
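
    # Sketch of how these groups could feed an optimizer with separate learning
    # rates (hypothetical values, not part of this module):
    #   pg = model.get_grouped_params()
    #   optimizer = T.optim.AdamW([
    #       {'params': pg['pretrained'], 'lr': 1e-5},
    #       {'params': pg['downstream'], 'lr': 1e-4},
    #   ])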

    @property
    def device(self):
        #^ Exposed as an attribute (`self.device`), as used in step/predict.
        return self._dummy.device

    def step(self, xt, yt, mask=None, subword_lengths: T.Tensor = None):
        # ^ word_x, char_x, diac_x are Indices
        # ^ xt : self.preprocess((word_x, char_x, diac_x)),
        # ^ yt : T.tensor(diac_y, dtype=T.long),
        # ^ subword_lengths: T.tensor(subword_lengths, dtype=T.long)
        #< Move char_x, diac_x to device because they're small and trainable
        xt[0], xt[1], yt, subword_lengths = self._slim_batch_size(xt[0], xt[1], yt, subword_lengths)
        xt[0] = xt[0].to(self.device)
        xt[1] = xt[1].to(self.device)
        # xt[2] = xt[2].to(self.device)
        yt = yt.to(self.device)
        #^ yt: [b tw tc]

        Nb, Tword, Tchar = xt[1].shape
        if Tword * Tchar < 500:
            diac = self(*xt, subword_lengths)
            loss = self.closs(diac.view(-1, self.num_classes), yt.view(-1), reduction='sum')
        else:
            #< Chunk the batch to bound memory on very long samples.
            num_chunks = Tword * Tchar / 300
            loss = 0
            for i in range(round(num_chunks + 0.5)):
                _slice = slice(i*300, (i+1)*300)
                chunk = self._slice_batch(xt, _slice)
                diac = self(*chunk, subword_lengths[_slice])
                #< Labels must be sliced the same way as the inputs.
                chunk_loss = self.closs(diac.view(-1, self.num_classes), yt[_slice].view(-1), reduction='sum')
                loss = loss + chunk_loss
        return loss

    def _slice_batch(self, xt: List[T.Tensor], _slice):
        return [xt[0][_slice], xt[1][_slice], xt[2][_slice]]

    def _slim_batch_size(
        self,
        tx: T.Tensor,
        cx: T.Tensor,
        yt: T.Tensor,
        subword_lengths: T.Tensor
    ):
        #^ tx : [b tt]
        #^ cx : [b tw tc]
        #^ yt : [b tw tc]
        token_nonpad_mask = tx.ne(self.tokenizer.pad_token_id)
        Ttoken = token_nonpad_mask.sum(1).max()
        tx = tx[:, :Ttoken]

        char_nonpad_mask = cx.ne(0)
        Tword = char_nonpad_mask.any(2).sum(1).max()
        Tchar = char_nonpad_mask.sum(2).max()
        cx = cx[:, :Tword, :Tchar]
        yt = yt[:, :Tword, :Tchar]
        subword_lengths = subword_lengths[:, :Tword]

        return tx, cx, yt, subword_lengths
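
    # Illustration (hypothetical sizes): for a batch padded to tx [4, 128] and
    # cx/yt [4, 60, 16], where the longest sentence has 53 non-pad subword tokens
    # and 41 words, and the longest word has 9 characters, the batch is trimmed
    # to tx [4, 53], cx/yt [4, 41, 9], and subword_lengths [4, 41].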

    def token_dropout(self, toke_x):
        #^ toke_x : [b tw]
        if self.training:
            q = 1.0 - self.sent_dropout_p
            #< Draw the keep-mask on the same device as the token ids.
            sdo = T.bernoulli(T.full(toke_x.shape, q, device=toke_x.device))
            toke_x[sdo == 0] = self.tokenizer.pad_token_id
        return toke_x

    def sentence_dropout(self, word_embs: T.Tensor):
        #^ word_embs : [b tw dwe]
        if self.training:
            q = 1.0 - self.sent_dropout_p
            sdo = T.bernoulli(T.full(word_embs.shape[:2], q))
            sdo = sdo.detach().unsqueeze(-1).to(word_embs)
            word_embs = word_embs * sdo
            # toke_x[sdo == 0] = self.tokenizer.pad_token_id
        return word_embs

    def embed_tokens(self, input_ids: T.Tensor, attention_mask: T.Tensor):
        y: BaseModelOutputWithPoolingAndCrossAttentions
        y = self.token_model(input_ids, attention_mask=attention_mask)
        z = y.last_hidden_state
        return z

    def forward(
        self,
        toke_x: T.Tensor,
        char_x: T.Tensor,
        diac_x: T.Tensor,
        subword_lengths: T.Tensor,
    ):
        #^ toke_x : [b tt]
        #^ char_x : [b tw tc]
        #^ diac_x/labels : [b tw tc]
        #^ subword_lengths : [b, tw]
        # !TODO Use `subword_lengths` to aggregate subword embeddings first before ...
        # ... passing concatenated contextual embedding to chars in diac_model
        token_nonpad_mask = toke_x.ne(self.tokenizer.pad_token_id)
        char_nonpad_mask = char_x.ne(0)

        Nb, Tw, Tc = char_x.shape
        # assert Tw == Tw_0 and Tc == Tc_0, f"{Tw=} {Tw_0=}, {Tc=} {Tc_0=}"

        # toke_x = self.token_dropout(toke_x)
        token_embs = self.embed_tokens(toke_x, attention_mask=token_nonpad_mask)
        # token_embs = self.sentence_dropout(token_embs)

        #? Strip BOS,EOS
        token_embs = token_embs[:, 1:-1, ...]

        sent_word_strides = subword_lengths.cumsum(1)
        sent_enc: T.Tensor = T.zeros(Nb, Tw, token_embs.shape[-1]).to(token_embs)
        for i_b in range(Nb):
            token_embs_ib = token_embs[i_b]
            start_iw = 0
            for i_word, end_iw in enumerate(sent_word_strides[i_b]):
                if end_iw == start_iw:
                    break
                word_emb = token_embs_ib[start_iw : end_iw].sum(0) / (end_iw - start_iw)
                sent_enc[i_b, i_word] = word_emb
                start_iw = end_iw
        #^ sent_enc: [b tw dwe]

        char_x_flat = char_x.reshape(Nb*Tw, Tc)
        char_nonpad_mask = char_x_flat.gt(0)
        # ^ char_nonpad_mask [b*tw tc]
        char_x_flat = char_x_flat * char_nonpad_mask

        cembs = self.char_embs(char_x_flat)
        #^ cembs: [b*tw tc dce]

        wembs = sent_enc.unsqueeze(-2).expand(Nb, Tw, Tc, -1).view(Nb*Tw, Tc, -1)
        #^ wembs: [b tw dwe] => [b tw _ dwe] => [b*tw tc dwe]

        cw_embs = T.cat([cembs, wembs], dim=-1)
        #^ char_embs : [b*tw tc dcw] ; dcw = dc + dwe

        cw_embs = self.dropout(cw_embs)

        cw_embs_ = cw_embs
        if self.down_project_token_embeds is not None:
            cw_embs_ = self.down_project_token_embeds(cw_embs)
        if self.down_project_token_embeds_deep is not None:
            cw_embs_ = cw_embs_ + self.down_project_token_embeds_deep(cw_embs)
        cw_embs = cw_embs_

        diac_enc: BaseModelOutputWithPoolingAndCrossAttentions
        diac_enc = self.diac_emb_model(inputs_embeds=cw_embs, attention_mask=char_nonpad_mask)
        diac_emb = diac_enc.last_hidden_state
        diac_emb = self.dropout(diac_emb)
        #^ diac_emb: [b*tw tc dce]
        diac_emb = diac_emb.view(Nb, Tw, Tc, -1)

        sent_residual = sent_enc.unsqueeze(2).expand(-1, -1, Tc, -1)
        final_feature = T.cat([sent_residual, diac_emb], dim=-1)
        if self.final_feature_transform is not None:
            final_feature = self.final_feature_transform(final_feature)
            final_feature = T.tanh(final_feature)
            final_feature = self.dropout(final_feature)
        else:
            final_feature = diac_emb

        # final_feature = self.feature_layer_norm(final_feature)

        diac_out = self.classifier(final_feature)
        # if T.isnan(diac_out).any():
        #     breakpoint()
        return diac_out

    def predict(self, dataloader):
        from tqdm import tqdm
        import diac_utils as du
        training = self.training
        self.eval()

        preds = {'haraka': [], 'shadda': [], 'tanween': []}
        print("> Predicting...")
        for inputs, _, subword_lengths in tqdm(dataloader, total=len(dataloader)):
            inputs[0] = inputs[0].to(self.device)
            inputs[1] = inputs[1].to(self.device)
            output = self(*inputs, subword_lengths).detach()
            marks = np.argmax(T.softmax(output, dim=-1).cpu().numpy(), axis=-1)
            #^ [b ts tw]
            haraka, tanween, shadda = du.flat_2_3head(marks)
            preds['haraka'].extend(haraka)
            preds['tanween'].extend(tanween)
            preds['shadda'].extend(shadda)

        self.train(training)
        return (
            np.array(preds['haraka']),
            np.array(preds["tanween"]),
            np.array(preds["shadda"]),
        )


if __name__ == "__main__":
    model = Diacritizer({
        "num-chars": 36,
        "hidden_size": 768,
        "char-embed-dim": 32,
        "dropout": 0.25,
        "sentence_dropout": 0.2,
        "diac_model_config": {
            #^ Keys follow transformers.BertConfig naming; num_attention_heads is
            #  chosen here so it divides hidden_size (800) evenly, as BERT requires.
            "num_hidden_layers": 4,
            "num_attention_heads": 8,
            "hidden_size": 768 + 32,
            "intermediate_size": (768 + 32) * 4,
        },
    }, load_pretrained=False)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(model)
    print(f"{trainable_params:,}/{total_params:,} Trainable Parameters")