| from torch import nn |
| from torch import LongTensor |
| from transformers import PreTrainedModel |
|
|
| from .configuration import CobaldParserConfig |
| from .encoder import WordTransformerEncoder |
| from .mlp_classifier import MlpClassifier |
| from .dependency_classifier import DependencyClassifier |
| from .utils import ( |
| build_padding_mask, |
| build_null_mask, |
| prepend_cls, |
| remove_nulls, |
| add_nulls |
| ) |
|
|
|
|
class CobaldParser(PreTrainedModel):
    """Morpho-Syntax-Semantic Parser.

    A shared word-level encoder feeds a set of classification heads stored in
    ``self.classifiers``. The "null" head is always created; every other head
    is created only when the corresponding key is present in
    ``config.vocabulary``.
    """

    config_class = CobaldParserConfig

    def __init__(self, config: CobaldParserConfig):
        """Build the encoder and the classification heads enabled by ``config``.

        Args:
            config: Parser configuration. Its ``vocabulary`` mapping decides
                which optional heads ("lemma_rule", "joint_feats", "syntax",
                "misc", "deepslot", "semclass") are instantiated.

        Raises:
            KeyError: If exactly one of "ud_deprel" / "eud_deprel" is present
                in ``config.vocabulary``; the dependency classifier is
                constructed from the sizes of both.
        """
        super().__init__(config)

        self.encoder = WordTransformerEncoder(
            model_name=config.encoder_model_name
        )
        embedding_size = self.encoder.get_embedding_size()

        # NOTE: heads are registered in a fixed order (null, lemma_rule,
        # joint_feats, syntax, misc, deepslot, semclass). Keep it stable: it
        # determines the order of submodules/parameters in the ModuleDict.
        self.classifiers = nn.ModuleDict()
        self.classifiers["null"] = MlpClassifier(
            input_size=embedding_size,
            hidden_size=config.null_classifier_hidden_size,
            # Presumably one class per possible number of nulls
            # (0..consecutive_null_limit) -- TODO confirm.
            n_classes=config.consecutive_null_limit + 1,
            activation=config.activation,
            dropout=config.dropout
        )
        self._add_vocabulary_classifier(
            config, "lemma_rule", "lemma_classifier_hidden_size", embedding_size
        )
        self._add_vocabulary_classifier(
            config, "joint_feats", "morphology_classifier_hidden_size", embedding_size
        )

        has_ud = "ud_deprel" in config.vocabulary
        has_eud = "eud_deprel" in config.vocabulary
        if has_ud or has_eud:
            if not (has_ud and has_eud):
                # The classifier below is sized from both vocabularies, so a
                # config carrying only one of them cannot be honoured. Fail
                # with an explicit message instead of a bare KeyError.
                missing = "eud_deprel" if has_ud else "ud_deprel"
                raise KeyError(
                    f"config.vocabulary is missing '{missing}': the dependency "
                    "classifier requires both 'ud_deprel' and 'eud_deprel'"
                )
            self.classifiers["syntax"] = DependencyClassifier(
                input_size=embedding_size,
                hidden_size=config.dependency_classifier_hidden_size,
                n_rels_ud=len(config.vocabulary["ud_deprel"]),
                n_rels_eud=len(config.vocabulary["eud_deprel"]),
                activation=config.activation,
                dropout=config.dropout
            )

        self._add_vocabulary_classifier(
            config, "misc", "misc_classifier_hidden_size", embedding_size
        )
        self._add_vocabulary_classifier(
            config, "deepslot", "deepslot_classifier_hidden_size", embedding_size
        )
        self._add_vocabulary_classifier(
            config, "semclass", "semclass_classifier_hidden_size", embedding_size
        )

    def _add_vocabulary_classifier(
        self,
        config: CobaldParserConfig,
        vocabulary_key: str,
        hidden_size_attr: str,
        input_size: int
    ) -> None:
        """Register an ``MlpClassifier`` under ``vocabulary_key`` if the vocabulary has that key."""
        if vocabulary_key not in config.vocabulary:
            return
        self.classifiers[vocabulary_key] = MlpClassifier(
            input_size=input_size,
            # Read lazily (only for enabled heads), so the config attribute is
            # never touched when the head is not requested.
            hidden_size=getattr(config, hidden_size_attr),
            n_classes=len(config.vocabulary[vocabulary_key]),
            activation=config.activation,
            dropout=config.dropout
        )

    def forward(
        self,
        words: list[list[str]],
        counting_masks: LongTensor = None,
        lemma_rules: LongTensor = None,
        joint_feats: LongTensor = None,
        deps_ud: LongTensor = None,
        deps_eud: LongTensor = None,
        miscs: LongTensor = None,
        deepslots: LongTensor = None,
        semclasses: LongTensor = None,
        sent_ids: list[str] = None,
        texts: list[str] = None,
        inference_mode: bool = False
    ) -> dict:
        """Run the null head, then every enabled head, and sum their losses.

        Args:
            words: Batch of sentences, each a list of word strings.
            counting_masks: Labels passed to the "null" classifier.
            lemma_rules: Labels passed to the "lemma_rule" classifier.
            joint_feats: Labels passed to the "joint_feats" classifier.
            deps_ud: First label argument passed to the "syntax" classifier.
            deps_eud: Second label argument passed to the "syntax" classifier.
            miscs: Labels passed to the "misc" classifier.
            deepslots: Labels passed to the "deepslot" classifier.
            semclasses: Labels passed to the "semclass" classifier.
            sent_ids: Accepted but not used by this method.
            texts: Accepted but not used by this method.
            inference_mode: When True, the sentences given to the remaining
                heads are rebuilt with ``add_nulls`` from the null head's
                predictions; otherwise ``words`` is used unchanged.

        Returns:
            A dict that always contains "counting_mask", "loss" and "words",
            plus, for each enabled head, its predictions under "lemma_rules",
            "joint_feats", "deps_ud" / "deps_eud", "miscs", "deepslots" and
            "semclasses". "loss" is the null head's loss plus the losses of
            all enabled heads.
        """
        output = {}

        # Stage 1: the null head sees the sentences after prepend_cls and
        # remove_nulls (helper semantics inferred from their names -- TODO
        # confirm against .utils).
        words_with_cls = prepend_cls(words)
        words_without_nulls = remove_nulls(words_with_cls)
        embeddings_without_nulls = self.encoder(words_without_nulls)
        null_output = self.classifiers["null"](embeddings_without_nulls, counting_masks)
        output["counting_mask"] = null_output['preds']
        output["loss"] = null_output["loss"]

        # Stage 2: choose the sentences the remaining heads operate on.
        # Outside inference mode the input words are used as given (presumably
        # they already contain gold nulls -- verify against the data pipeline).
        if inference_mode:
            output["words"] = add_nulls(words, null_output["preds"])
        else:
            output["words"] = words

        # The encoder is run a second time because the word sequences may
        # differ from the ones used for null prediction.
        embeddings = self.encoder(output["words"])

        # Stage 3: enabled heads, in a fixed order so that the loss terms are
        # always accumulated in the same sequence.
        self._run_mlp_head("lemma_rule", "lemma_rules", embeddings, lemma_rules, output)
        self._run_mlp_head("joint_feats", "joint_feats", embeddings, joint_feats, output)

        if "syntax" in self.classifiers:
            padding_mask = build_padding_mask(output["words"], self.device)
            null_mask = build_null_mask(output["words"], self.device)
            deps_output = self.classifiers["syntax"](
                embeddings,
                deps_ud,
                deps_eud,
                null_mask,
                padding_mask
            )
            output["deps_ud"] = deps_output['preds_ud']
            output["deps_eud"] = deps_output['preds_eud']
            output["loss"] += deps_output['loss_ud'] + deps_output['loss_eud']

        self._run_mlp_head("misc", "miscs", embeddings, miscs, output)
        self._run_mlp_head("deepslot", "deepslots", embeddings, deepslots, output)
        self._run_mlp_head("semclass", "semclasses", embeddings, semclasses, output)

        return output

    def _run_mlp_head(
        self,
        classifier_name: str,
        output_key: str,
        embeddings,
        labels,
        output: dict
    ) -> None:
        """Apply one optional head, storing its predictions and adding its loss to ``output``."""
        if classifier_name not in self.classifiers:
            return
        head_output = self.classifiers[classifier_name](embeddings, labels)
        output[output_key] = head_output['preds']
        output["loss"] += head_output['loss']