| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Literal |
|
|
| from transformers.models.modernbert.configuration_modernbert import ModernBertConfig |
|
|
|
|
class ModChemBertConfig(ModernBertConfig):
    """
    Configuration class for ModChemBert models.

    This configuration class extends ModernBertConfig with additional parameters specific to
    chemical molecule modeling and custom pooling strategies for classification/regression tasks.
    It accepts all arguments and keyword arguments from ModernBertConfig.

    Args:
        classifier_pooling (str, optional): Pooling strategy for sequence classification.
            Available options:
            - "cls": Use CLS token representation
            - "mean": Attention-weighted average pooling
            - "sum_mean": Sum all hidden states across layers, then mean pool over sequence (ChemLM approach)
            - "sum_sum": Sum all hidden states across layers, then sum pool over sequence
            - "mean_mean": Mean all hidden states across layers, then mean pool over sequence
            - "mean_sum": Mean all hidden states across layers, then sum pool over sequence
            - "max_cls": Element-wise max pooling over last k hidden states, then take CLS token
            - "cls_mha": Multi-head attention with CLS token as query and full sequence as keys/values
            - "max_seq_mha": Max pooling over last k states + multi-head attention with CLS as query
            - "mean_seq_mha": Mean pooling over last k states + multi-head attention with CLS as query
            - "max_seq_mean": Max pooling over last k hidden states, then mean pooling over sequence
            Defaults to "max_seq_mha".
        classifier_pooling_num_attention_heads (int, optional): Number of attention heads for multi-head attention
            pooling strategies (cls_mha, max_seq_mha, mean_seq_mha). Defaults to 4.
        classifier_pooling_attention_dropout (float, optional): Dropout probability for multi-head attention
            pooling strategies (cls_mha, max_seq_mha, mean_seq_mha). Defaults to 0.0.
        classifier_pooling_last_k (int, optional): Number of last hidden layers to use for max/mean pooling
            strategies (max_cls, max_seq_mha, mean_seq_mha, max_seq_mean). Defaults to 8.
        *args: Variable length argument list passed to ModernBertConfig.
        **kwargs: Arbitrary keyword arguments passed to ModernBertConfig.

    Raises:
        ValueError: If `classifier_pooling` is not one of the options listed above.

    Note:
        This class inherits all configuration parameters from ModernBertConfig including
        hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, etc.
        The parent ModernBertConfig is always initialised with classifier_pooling="cls";
        the requested ModChemBert strategy is assigned to `self.classifier_pooling`
        afterwards, so the stored value is the one passed to this constructor.
    """

    model_type = "modchembert"

    def __init__(
        self,
        *args,
        classifier_pooling: Literal[
            "cls",
            "mean",
            "sum_mean",
            "sum_sum",
            "mean_mean",
            "mean_sum",
            "max_cls",
            "cls_mha",
            "max_seq_mha",
            "mean_seq_mha",
            "max_seq_mean",
        ] = "max_seq_mha",
        classifier_pooling_num_attention_heads: int = 4,
        classifier_pooling_attention_dropout: float = 0.0,
        classifier_pooling_last_k: int = 8,
        **kwargs,
    ) -> None:
        # Runtime mirror of the Literal above; keep the two lists in sync.
        valid_classifier_pooling_options = [
            "cls",
            "mean",
            "sum_mean",
            "sum_sum",
            "mean_mean",
            "mean_sum",
            "max_cls",
            "cls_mha",
            "max_seq_mha",
            "mean_seq_mha",
            "max_seq_mean",
        ]
        if classifier_pooling not in valid_classifier_pooling_options:
            raise ValueError(
                f"Invalid value for `classifier_pooling`, should be one of {valid_classifier_pooling_options}, "
                f"but is {classifier_pooling}."
            )

        # Hand the parent a pooling value it recognises; the ModChemBert-specific
        # strategy is written over it just below.
        # NOTE(review): presumably needed because ModernBertConfig validates
        # classifier_pooling against its own smaller set of options — confirm.
        super().__init__(*args, classifier_pooling="cls", **kwargs)

        self.classifier_pooling = classifier_pooling
        self.classifier_pooling_num_attention_heads = classifier_pooling_num_attention_heads
        self.classifier_pooling_attention_dropout = classifier_pooling_attention_dropout
        self.classifier_pooling_last_k = classifier_pooling_last_k
        # Point the transformers Auto* factories at this project's custom
        # configuration/modeling classes ("module.ClassName" strings).
        # This unconditionally replaces any auto_map value set by the parent.
        self.auto_map = {
            "AutoConfig": "configuration_modchembert.ModChemBertConfig",
            "AutoModel": "modeling_modchembert.ModChemBertModel",
            "AutoModelForMaskedLM": "modeling_modchembert.ModChemBertForMaskedLM",
            "AutoModelForSequenceClassification": "modeling_modchembert.ModChemBertForSequenceClassification",
        }
|
|