# flashppi/configuration_flashppi.py
# Provenance: "Upload FlashPPI model" by andrecornman, revision 2eddd0d (verified).
"""FlashPPI model configuration."""
from typing import Optional

from transformers import PretrainedConfig
class FlashPPIConfig(PretrainedConfig):
    """Configuration for the FlashPPI model.

    Bundles the hyperparameters of the gLM2 protein language model backbone
    (``plm_*``) with those of the FlashPPI interaction head. All extra keyword
    arguments are forwarded to :class:`transformers.PretrainedConfig`.

    Args:
        plm_dim: Hidden dimension of the gLM2 backbone (defaults match gLM2_650M).
        plm_depth: Number of transformer layers in the backbone.
        plm_heads: Number of attention heads in the backbone.
        plm_vocab_size: Size of the backbone's token vocabulary.
        plm_norm_eps: Epsilon used by the backbone's normalization layers.
        plm_swiglu_multiple_of: Round the SwiGLU hidden size up to a multiple
            of this value. NOTE(review): presumably consumed by the backbone's
            feed-forward sizing — confirm against the model implementation.
        plm_ffn_dim_multiplier: Optional multiplier for the backbone
            feed-forward dimension; ``None`` keeps the model's default sizing.
        clip_embed_dim: Dimension of the CLIP-style projection embedding
            (inferred from the name — confirm against the model code).
        contact_embed_dim: Embedding dimension of the contact head.
        contact_num_heads: Number of attention heads in the contact head.
        contact_transformer_depth: Number of transformer layers in the
            contact head.
        max_position_embeddings: Maximum sequence length supported by the
            position embeddings.
        use_flash_attention: Whether to use flash-attention kernels.
        **kwargs: Forwarded to :class:`transformers.PretrainedConfig`.
    """

    model_type = "flashppi"

    def __init__(
        self,
        # gLM2 backbone config (defaults match gLM2_650M)
        plm_dim: int = 1280,
        plm_depth: int = 33,
        plm_heads: int = 20,
        plm_vocab_size: int = 37,
        plm_norm_eps: float = 1e-5,
        plm_swiglu_multiple_of: int = 256,
        # Annotation fixed: default is None, so the type is Optional[float].
        plm_ffn_dim_multiplier: Optional[float] = None,
        # FlashPPI head config
        clip_embed_dim: int = 1024,
        contact_embed_dim: int = 1280,
        contact_num_heads: int = 8,
        contact_transformer_depth: int = 2,
        max_position_embeddings: int = 512,
        use_flash_attention: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # gLM2 backbone config
        self.plm_dim = plm_dim
        self.plm_depth = plm_depth
        self.plm_heads = plm_heads
        self.plm_vocab_size = plm_vocab_size
        self.plm_norm_eps = plm_norm_eps
        self.plm_swiglu_multiple_of = plm_swiglu_multiple_of
        self.plm_ffn_dim_multiplier = plm_ffn_dim_multiplier
        # FlashPPI head config
        self.clip_embed_dim = clip_embed_dim
        self.contact_embed_dim = contact_embed_dim
        self.contact_num_heads = contact_num_heads
        self.contact_transformer_depth = contact_transformer_depth
        self.max_position_embeddings = max_position_embeddings
        self.use_flash_attention = use_flash_attention