from typing import List, Iterator, cast
import copy

import numpy as np
import torch as T
from torch import nn
from torch.nn import functional as F

from transformers import BertConfig, BertModel
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
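
# Illustrative sketch (an addition, not from the original repo) of the
# `config['modeling']` dict that Diacritizer consumes below. Every key is one
# that __init__, build_diac_model, or trim_model_ looks up; the values are
# examples only.
EXAMPLE_MODELING_CONFIG = {
    "base_model": "CAMeL-Lab/bert-base-arabic-camelbert-mix-ner",
    "num-chars": 36,                  # size of the character vocabulary
    "char-embed-dim": 32,             # character embedding width
    "dropout": 0.25,
    "sentence_dropout": 0.2,
    "token_hidden_size": "auto",      # optional; enables down-projection of [char ; token] features
    "deep-down-proj": False,          # optional; adds a 2-layer Tanh projection as well
    "deep-cls": False,                # optional; adds a residual feature transform before the classifier
    "pre-init-diac-model": False,     # optional; clone token-model layers instead of a fresh BertModel
    "keep-token-model-layers": None,  # optional; truncate the token encoder to this many layers
    "full-finetune": False,           # if False, freeze the token model except the last few layers
    "num-finetune-last-layers": 4,
    "diac_model_config": {            # kwargs forwarded to transformers.BertConfig
        "num_hidden_layers": 4,
        "hidden_size": 768 + 32,
        "intermediate_size": (768 + 32) * 4,
    },
}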


class Diacritizer(nn.Module):
    def __init__(self, config, device=None, load_pretrained=True) -> None:
        super().__init__()
        self._dummy = nn.Parameter(T.ones(1))
        if 'modeling' in config:
            config = config['modeling']
        self.config = config
        model_name = config.get('base_model', "CAMeL-Lab/bert-base-arabic-camelbert-mix-ner")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if load_pretrained:
            self.token_model: BertModel = AutoModel.from_pretrained(model_name)
        else:
            marbert_config = AutoConfig.from_pretrained(model_name)
            self.token_model = AutoModel.from_config(marbert_config)

        self.num_classes = 15
        self.diac_model_config = BertConfig(**config['diac_model_config'])
        self.token_model_config: BertConfig = self.token_model.config

        self.char_embs = nn.Embedding(config["num-chars"], embedding_dim=config["char-embed-dim"])
        self.diac_emb_model = self.build_diac_model(self.token_model)

        self.down_project_token_embeds_deep = None
        self.down_project_token_embeds = None
        if 'token_hidden_size' in config:
            if config['token_hidden_size'] == 'auto':
                down_proj_size = self.diac_emb_model.config.hidden_size
            else:
                down_proj_size = config['token_hidden_size']
            if config.get('deep-down-proj', False):
                self.down_project_token_embeds_deep = nn.Sequential(
                    nn.Linear(
                        self.token_model_config.hidden_size + config["char-embed-dim"],
                        down_proj_size * 4,
                        bias=False,
                    ),
                    nn.Tanh(),
                    nn.Linear(
                        down_proj_size * 4,
                        down_proj_size,
                        bias=False,
                    ),
                )
            # else:
            self.down_project_token_embeds = nn.Linear(
                self.token_model_config.hidden_size + config["char-embed-dim"],
                down_proj_size,
                bias=False,
            )
        # assert self.down_project_token_embeds_deep is None or self.down_project_token_embeds is None

        classifier_feature_size = self.diac_model_config.hidden_size
        if config.get('deep-cls', False):
            # classifier_feature_size = 512
            self.final_feature_transform = nn.Linear(
                self.diac_model_config.hidden_size + self.token_model_config.hidden_size,
                #^ diac_features + [residual from token_model]
                out_features=classifier_feature_size,
                bias=False,
            )
        else:
            self.final_feature_transform = None

        self.feature_layer_norm = nn.LayerNorm(classifier_feature_size)
        self.classifier = nn.Linear(classifier_feature_size, self.num_classes, bias=True)

        self.trim_model_(config)

        self.dropout = nn.Dropout(config['dropout'])
        self.sent_dropout_p = config['sentence_dropout']
        self.closs = F.cross_entropy

    def build_diac_model(self, token_model=None):
        if self.config.get('pre-init-diac-model', False):
            # Assumes 'keep-token-model-layers' is set: the diac model reuses the
            # block of pretrained layers right after the ones the token model keeps.
            model = copy.deepcopy(self.token_model)
            model.pooler = None
            model.embeddings.word_embeddings = None
            num_layers = self.config.get('keep-token-model-layers', None)
            model.encoder.layer = nn.ModuleList(
                list(model.encoder.layer[num_layers:num_layers*2])
            )
            model.encoder.config.num_hidden_layers = num_layers
        else:
            model = BertModel(self.diac_model_config)
        return model

    def trim_model_(self, config):
        self.token_model.pooler = None
        self.diac_emb_model.pooler = None
        # self.diac_emb_model.embeddings = None
        self.diac_emb_model.embeddings.word_embeddings = None

        num_token_model_kept_layers = config.get('keep-token-model-layers', None)
        if num_token_model_kept_layers is not None:
            self.token_model.encoder.layer = nn.ModuleList(
                list(self.token_model.encoder.layer[:num_token_model_kept_layers])
            )
            self.token_model.encoder.config.num_hidden_layers = num_token_model_kept_layers

        if not config.get('full-finetune', False):
            for param in self.token_model.parameters():
                param.requires_grad = False
            finetune_last_layers = config.get('num-finetune-last-layers', 4)
            if finetune_last_layers > 0:
                unfrozen_layers = self.token_model.encoder.layer[-finetune_last_layers:]
                for layer in unfrozen_layers:
                    for param in layer.parameters():
                        param.requires_grad = True

    def get_grouped_params(self):
        downstream_params: Iterator[nn.Parameter] = cast(
            Iterator,
            (
                param
                for module in (self.diac_emb_model, self.classifier, self.char_embs)
                for param in module.parameters()
            ),
        )
        pg = {
            'pretrained': self.token_model.parameters(),
            'downstream': downstream_params,
        }
        return pg

    @property
    def device(self):
        return self._dummy.device

    def step(self, xt, yt, mask=None, subword_lengths: T.Tensor = None):
        # ^ word_x, char_x, diac_x are Indices
        # ^ xt : self.preprocess((word_x, char_x, diac_x)),
        # ^ yt : T.tensor(diac_y, dtype=T.long),
        # ^ subword_lengths: T.tensor(subword_lengths, dtype=T.long)
        #< Move char_x, diac_x to device because they're small and trainable
        xt[0], xt[1], yt, subword_lengths = self._slim_batch_size(xt[0], xt[1], yt, subword_lengths)
        xt[0] = xt[0].to(self.device)
        xt[1] = xt[1].to(self.device)
        # xt[2] = xt[2].to(self.device)
        yt = yt.to(self.device)
        #^ yt: [b tw tc]

        Nb, Tword, Tchar = xt[1].shape
        if Tword * Tchar < 500:
            diac = self(*xt, subword_lengths)
            loss = self.closs(diac.view(-1, self.num_classes), yt.view(-1), reduction='sum')
        else:
            num_chunks = Tword * Tchar / 300
            loss = 0
            for i in range(round(num_chunks + 0.5)):
                _slice = slice(i*300, (i+1)*300)
                chunk = self._slice_batch(xt, _slice)
                diac = self(*chunk, subword_lengths[_slice])
                # Compare each chunk against the matching slice of the labels.
                chunk_loss = self.closs(diac.view(-1, self.num_classes), yt[_slice].view(-1), reduction='sum')
                loss = loss + chunk_loss
        return loss

    def _slice_batch(self, xt: List[T.Tensor], _slice):
        return [xt[0][_slice], xt[1][_slice], xt[2][_slice]]

    def _slim_batch_size(
        self,
        tx: T.Tensor,
        cx: T.Tensor,
        yt: T.Tensor,
        subword_lengths: T.Tensor,
    ):
        #^ tx : [b tt]
        #^ cx : [b tw tc]
        #^ yt : [b tw tc]
        token_nonpad_mask = tx.ne(self.tokenizer.pad_token_id)
        Ttoken = token_nonpad_mask.sum(1).max()
        tx = tx[:, :Ttoken]

        char_nonpad_mask = cx.ne(0)
        Tword = char_nonpad_mask.any(2).sum(1).max()
        Tchar = char_nonpad_mask.sum(2).max()
        cx = cx[:, :Tword, :Tchar]
        yt = yt[:, :Tword, :Tchar]
        subword_lengths = subword_lengths[:, :Tword]

        return tx, cx, yt, subword_lengths

    def token_dropout(self, toke_x):
        #^ toke_x : [b tw]
        if self.training:
            q = 1.0 - self.sent_dropout_p
            sdo = T.bernoulli(T.full(toke_x.shape, q, device=toke_x.device))
            toke_x[sdo == 0] = self.tokenizer.pad_token_id
        return toke_x

    def sentence_dropout(self, word_embs: T.Tensor):
        #^ word_embs : [b tw dwe]
        if self.training:
            q = 1.0 - self.sent_dropout_p
            sdo = T.bernoulli(T.full(word_embs.shape[:2], q))
            sdo = sdo.detach().unsqueeze(-1).to(word_embs)
            word_embs = word_embs * sdo
            # toke_x[sdo == 0] = self.tokenizer.pad_token_id
        return word_embs

    def embed_tokens(self, input_ids: T.Tensor, attention_mask: T.Tensor):
        y: BaseModelOutputWithPoolingAndCrossAttentions
        y = self.token_model(input_ids, attention_mask=attention_mask)
        z = y.last_hidden_state
        return z

    def forward(
        self,
        toke_x: T.Tensor,
        char_x: T.Tensor,
        diac_x: T.Tensor,
        subword_lengths: T.Tensor,
    ):
        #^ toke_x : [b tt]
        #^ char_x : [b tw tc]
        #^ diac_x/labels : [b tw tc]
        #^ subword_lengths : [b, tw]
        # !TODO Use `subword_lengths` to aggregate subword embeddings first before ...
        # ... passing concatenated contextual embedding to chars in diac_model
        token_nonpad_mask = toke_x.ne(self.tokenizer.pad_token_id)
        char_nonpad_mask = char_x.ne(0)
        Nb, Tw, Tc = char_x.shape
        # assert Tw == Tw_0 and Tc == Tc_0, f"{Tw=} {Tw_0=}, {Tc=} {Tc_0=}"

        # toke_x = self.token_dropout(toke_x)
        token_embs = self.embed_tokens(toke_x, attention_mask=token_nonpad_mask)
        # token_embs = self.sentence_dropout(token_embs)

        #? Strip BOS,EOS
        token_embs = token_embs[:, 1:-1, ...]

        sent_word_strides = subword_lengths.cumsum(1)
        sent_enc: T.Tensor = T.zeros(Nb, Tw, token_embs.shape[-1]).to(token_embs)
        for i_b in range(Nb):
            token_embs_ib = token_embs[i_b]
            start_iw = 0
            for i_word, end_iw in enumerate(sent_word_strides[i_b]):
                if end_iw == start_iw:
                    break
                word_emb = token_embs_ib[start_iw:end_iw].sum(0) / (end_iw - start_iw)
                sent_enc[i_b, i_word] = word_emb
                start_iw = end_iw
        #^ sent_enc: [b tw dwe]

        char_x_flat = char_x.reshape(Nb*Tw, Tc)
        char_nonpad_mask = char_x_flat.gt(0)
        # ^ char_nonpad_mask [b*tw tc]
        char_x_flat = char_x_flat * char_nonpad_mask

        cembs = self.char_embs(char_x_flat)
        #^ cembs: [b*tw tc dce]
        wembs = sent_enc.unsqueeze(-2).expand(Nb, Tw, Tc, -1).view(Nb*Tw, Tc, -1)
        #^ wembs: [b tw dwe] => [b tw _ dwe] => [b*tw tc dwe]

        cw_embs = T.cat([cembs, wembs], dim=-1)
        #^ cw_embs : [b*tw tc dcw] ; dcw = dce + dwe
        cw_embs = self.dropout(cw_embs)

        cw_embs_ = cw_embs
        if self.down_project_token_embeds is not None:
            cw_embs_ = self.down_project_token_embeds(cw_embs)
        if self.down_project_token_embeds_deep is not None:
            cw_embs_ = cw_embs_ + self.down_project_token_embeds_deep(cw_embs)
        cw_embs = cw_embs_

        diac_enc: BaseModelOutputWithPoolingAndCrossAttentions
        diac_enc = self.diac_emb_model(inputs_embeds=cw_embs, attention_mask=char_nonpad_mask)
        diac_emb = diac_enc.last_hidden_state
        diac_emb = self.dropout(diac_emb)
        #^ diac_emb: [b*tw tc dce]
        diac_emb = diac_emb.view(Nb, Tw, Tc, -1)

        sent_residual = sent_enc.unsqueeze(2).expand(-1, -1, Tc, -1)
        final_feature = T.cat([sent_residual, diac_emb], dim=-1)
        if self.final_feature_transform is not None:
            final_feature = self.final_feature_transform(final_feature)
            final_feature = T.tanh(final_feature)
            final_feature = self.dropout(final_feature)
        else:
            final_feature = diac_emb

        # final_feature = self.feature_layer_norm(final_feature)
        diac_out = self.classifier(final_feature)
        # if T.isnan(diac_out).any():
        #     breakpoint()
        return diac_out

    def predict(self, dataloader):
        from tqdm import tqdm
        import diac_utils as du

        training = self.training
        self.eval()

        preds = {'haraka': [], 'shadda': [], 'tanween': []}
        print("> Predicting...")
        for inputs, _, subword_lengths in tqdm(dataloader, total=len(dataloader)):
            inputs[0] = inputs[0].to(self.device)
            inputs[1] = inputs[1].to(self.device)
            output = self(*inputs, subword_lengths).detach()
            marks = np.argmax(T.softmax(output, dim=-1).cpu().numpy(), axis=-1)
            #^ marks: [b tw tc]
            haraka, tanween, shadda = du.flat_2_3head(marks)
            preds['haraka'].extend(haraka)
            preds['tanween'].extend(tanween)
            preds['shadda'].extend(shadda)

        self.train(training)
        return (
            np.array(preds['haraka']),
            np.array(preds["tanween"]),
            np.array(preds["shadda"]),
        )
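

# Hedged sketch (an addition, not from the original file): one way to wire the two
# parameter groups returned by get_grouped_params() into an optimizer with separate
# learning rates. The function name and the learning-rate values are illustrative
# assumptions, not the repo's training setup.
def build_optimizer(model: Diacritizer, pretrained_lr: float = 1e-5, downstream_lr: float = 1e-4):
    groups = model.get_grouped_params()
    return T.optim.AdamW([
        # Pretrained token-model weights: small LR; frozen parameters are filtered out.
        {"params": [p for p in groups['pretrained'] if p.requires_grad], "lr": pretrained_lr},
        # Freshly initialised diac model, classifier, and char embeddings: larger LR.
        {"params": list(groups['downstream']), "lr": downstream_lr},
    ])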


if __name__ == "__main__":
    model = Diacritizer({
        "num-chars": 36,
        "hidden_size": 768,
        "char-embed-dim": 32,
        "dropout": 0.25,
        "sentence_dropout": 0.2,
        "diac_model_config": {
            "num_hidden_layers": 4,
            "hidden_size": 768 + 32,
            "intermediate_size": (768 + 32) * 4,
            "num_attention_heads": 8,  # hidden_size (800) must divide evenly across the heads
        },
    }, load_pretrained=False)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(model)
    print(f"{trainable_params:,}/{total_params:,} Trainable Parameters")
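
    # Hedged smoke test (an addition, not from the original pipeline): run a single
    # forward pass on random tensors whose shapes follow the comments in `forward`
    # (toke_x [b tt], char_x [b tw tc], diac_x [b tw tc], subword_lengths [b tw]).
    Nb, Tt, Tw, Tc = 2, 8, 3, 5
    toke_x = T.randint(5, 1000, (Nb, Tt))                # arbitrary non-pad subword ids
    char_x = T.randint(1, 36, (Nb, Tw, Tc))              # char ids in [1, num-chars)
    diac_x = T.zeros(Nb, Tw, Tc, dtype=T.long)           # labels; unused by forward()
    subword_lengths = T.full((Nb, Tw), 2, dtype=T.long)  # 2 subwords per word; sums to Tt - 2
    model.eval()
    with T.no_grad():
        logits = model(toke_x, char_x, diac_x, subword_lengths)
    print(f"Smoke-test logits shape: {tuple(logits.shape)}")  # expected (Nb, Tw, Tc, 15)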