# simplified_phi2/phi2_configuration.py

import math
from transformers import PretrainedConfig


class Phi2Config(PretrainedConfig):
    model_type = "phi2"  # only necessary if you want to register the model with the auto classes
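    # attribute_map lets code that expects the standard HF config names (hidden_size,
    # num_attention_heads, ...) transparently read and write this config's custom
    # attribute names via PretrainedConfig's alias mechanism.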
    attribute_map = {
        "max_position_embeddings": "initial_cos_sin_cache_len",
        "hidden_size": "d_embedding",
        "num_attention_heads": "n_attn_heads",
        "num_hidden_layers": "n_blocks",
    }

    def __init__(
        self,
        vocab_size: int = 50295,  # includes the extra tokens Phi-2 defines in tokenizer_config.json
        vocab_chunk_for_gpu_efficiency: int = 64,
        initial_cos_sin_cache_len: int = 2048,
        d_embedding: int = 1024,  # 2560 in full Phi-2
        n_blocks: int = 20,  # 32 in full Phi-2
        n_attn_heads: int = 16,  # 32 in full Phi-2
        use_flash_attn: bool = False,
        use_flash_rotary: bool = False,
        use_fused_dense: bool = False,
        attn_pdrop: float = 0.0,
        embd_pdrop: float = 0.0,
        resid_pdrop: float = 0.0,
        layer_norm_epsilon: float = 1e-5,
        weight_initialization_range: float = 0.02,
        tie_word_embeddings: bool = False,  # whether the input embedding and output (LM head) weights are shared
        checkpointing: bool = False,  # whether to use gradient checkpointing to trade compute for memory
        **kwargs
    ) -> None:
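        # Round the vocab size up to a multiple of vocab_chunk_for_gpu_efficiency so the
        # embedding and LM-head matrix dimensions are GPU-friendly (e.g. the default
        # 50295 is padded up to 50304 = 786 * 64).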
        self.vocab_size = (
            math.ceil(vocab_size / vocab_chunk_for_gpu_efficiency)
            * vocab_chunk_for_gpu_efficiency
        )
        self.initial_cos_sin_cache_len = initial_cos_sin_cache_len
        self.d_embedding = d_embedding
        self.n_blocks = n_blocks
        self.n_attn_heads = n_attn_heads
        self.use_flash_attn = use_flash_attn
        self.use_flash_rotary = use_flash_rotary
        self.use_fused_dense = use_fused_dense
        self.attn_pdrop = attn_pdrop
        self.embd_pdrop = embd_pdrop
        self.resid_pdrop = resid_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.weight_initialization_range = weight_initialization_range
        self.checkpointing = checkpointing
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)


if __name__ == "__main__":
    phi2_config = Phi2Config()
    # phi2_config.save_pretrained("phi2_config")
    # phi2_config = Phi2Config.from_pretrained("phi2_config")
    # phi2_config.push_to_hub("phi2_config")
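
    # A quick sanity-check sketch (not part of the original file): these asserts rely
    # only on PretrainedConfig's documented attribute_map aliasing and the vocab-size
    # rounding performed in __init__ above.
    assert phi2_config.hidden_size == phi2_config.d_embedding
    assert phi2_config.num_hidden_layers == phi2_config.n_blocks
    assert phi2_config.vocab_size == 50304  # 50295 rounded up to a multiple of 64
    print(phi2_config)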