import tensorflow as tf
from tensorflow.keras import layers, activations, initializers, regularizers
import numpy as np


class RMSNorm(tf.keras.layers.Layer):
    """Root-mean-square normalization over the last axis.

    Divides the input by ``sqrt(mean(x**2)) + epsilon`` computed along the
    last axis. There is no learnable scale or bias.

    Args:
        epsilon: Small constant added to the RMS to avoid division by zero.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def call(self, inputs):
        """Return ``inputs`` scaled by the inverse of its last-axis RMS."""
        rms = tf.sqrt(tf.reduce_mean(tf.square(inputs), axis=-1, keepdims=True))
        return inputs / (rms + self.epsilon)

    def get_config(self):
        """Return the layer config including ``epsilon``."""
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config


class MiniSunConfig:
    """Plain container for model and training hyper-parameters."""

    def __init__(self, vocab_size=30522, max_position_embeddings=1024,
                 hidden_size=512, num_attention_heads=8,
                 intermediate_size=2048, num_hidden_layers=8,
                 dropout_rate=0.1, weight_decay=0.01, learning_rate=1e-4,
                 total_steps=2500, warmup_ratio=0.5, restart_period=500,
                 l1_reg=0.0, l2_reg=0.01):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.dropout_rate = dropout_rate
        self.weight_decay = weight_decay
        self.learning_rate = learning_rate
        self.total_steps = total_steps
        self.warmup_ratio = warmup_ratio
        self.restart_period = restart_period
        self.l1_reg = l1_reg  # L1 regularization strength
        self.l2_reg = l2_reg  # L2 regularization strength


@tf.keras.utils.register_keras_serializable()
class MiniSunModel(tf.keras.Model):
    """Transformer language model built from attention + feed-forward blocks.

    Args:
        config: A ``MiniSunConfig`` holding the hyper-parameters.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Token embeddings plus *learned absolute* position embeddings.
        # (No rotary embedding is applied anywhere in this model.)
        self.token_embedding = layers.Embedding(
            config.vocab_size, config.hidden_size)
        self.position_embedding = layers.Embedding(
            config.max_position_embeddings, config.hidden_size)

        # Sub-layers are created here rather than in build() so that they
        # exist as soon as the model is constructed.
        self.decoder_blocks = [
            self._build_decoder_block()
            for _ in range(config.num_hidden_layers)
        ]

        # Final normalization and LM head.
        self.layer_norm = RMSNorm(epsilon=1e-6)
        self.lm_head = layers.Dense(
            config.vocab_size,
            kernel_initializer=initializers.he_normal(),
            kernel_regularizer=regularizers.l2(config.l2_reg))

        # Plain dropout layer; it is not used by call(). Kept so existing
        # code referencing the attribute keeps working.
        self.layer_dropout = tf.keras.layers.Dropout(config.dropout_rate)

        # Metrics owned by the model and updated by train_step/test_step.
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.accuracy_metric = tf.keras.metrics.Accuracy(name="accuracy")

    @property
    def metrics(self):
        """Metrics reported (and reset each epoch) by ``fit``/``evaluate``."""
        return [self.loss_tracker, self.accuracy_metric]

    def _build_decoder_block(self):
        """Create the layers of one block.

        Returns:
            A list ``[attention, norm, ffn_in, ffn_out, dropout]``. The norm
            and dropout layers are each applied twice per block in call().
        """
        cfg = self.config
        ffn_regularizer_args = (cfg.l1_reg, cfg.l2_reg)
        return [
            # NOTE(review): key_dim is the per-head size; hidden_size here
            # makes every head as wide as the model (a common choice is
            # hidden_size // num_attention_heads). Left unchanged to keep
            # weights compatible -- confirm this is intended.
            layers.MultiHeadAttention(
                num_heads=cfg.num_attention_heads,
                key_dim=cfg.hidden_size,
                kernel_initializer=initializers.he_normal(),
                kernel_regularizer=regularizers.l2(cfg.l2_reg)),
            RMSNorm(epsilon=1e-6),  # Use RMSNorm instead of LayerNormalization
            layers.Dense(
                cfg.intermediate_size,
                activation=activations.elu,
                kernel_initializer=initializers.he_normal(),
                kernel_regularizer=regularizers.l1_l2(*ffn_regularizer_args)),
            layers.Dense(
                cfg.hidden_size,
                kernel_initializer=initializers.he_normal(),
                kernel_regularizer=regularizers.l1_l2(*ffn_regularizer_args)),
            layers.Dropout(cfg.dropout_rate),
        ]

    def call(self, inputs, attention_mask=None, training=False):
        """Run the model.

        Args:
            inputs: Mapping with key ``'input_ids'`` and, optionally,
                ``'attention_mask'``.
            attention_mask: Optional padding mask indexed as
                ``[batch, seq_len]``. When omitted, ``inputs['attention_mask']``
                is used if present.
            training: Whether dropout is active.

        Returns:
            Tuple ``(logits, softmax_output)`` where ``softmax_output`` is the
            softmax of ``logits`` over the last axis.
        """
        input_ids = inputs['input_ids']
        if attention_mask is None:
            # Previously a mask supplied inside `inputs` was silently ignored.
            attention_mask = inputs.get('attention_mask')

        position_ids = tf.range(
            start=0, limit=tf.shape(input_ids)[-1], delta=1)
        embeddings = (self.token_embedding(input_ids)
                      + self.position_embedding(position_ids))

        # Reshape the mask to [batch, 1, 1, seq_len] so it broadcasts over
        # heads and query positions.
        if attention_mask is not None:
            attention_mask = tf.cast(
                attention_mask[:, tf.newaxis, tf.newaxis, :], dtype=tf.float32)

        # NOTE(review): no causal mask is applied, so every position attends
        # to every other unmasked position -- confirm this matches the
        # training objective.
        hidden_states = embeddings
        for mha, norm, ffn1, ffn2, dropout in self.decoder_blocks:
            attn_output = mha(hidden_states, hidden_states,
                              attention_mask=attention_mask,
                              training=training)
            attn_output = dropout(attn_output, training=training)
            hidden_states = norm(attn_output + hidden_states)  # Add & RMSNorm

            ffn_output = ffn2(ffn1(hidden_states))
            ffn_output = dropout(ffn_output, training=training)
            hidden_states = norm(ffn_output + hidden_states)  # Add & RMSNorm

        hidden_states = self.layer_norm(hidden_states)
        logits = self.lm_head(hidden_states)
        softmax_output = tf.nn.softmax(logits, axis=-1)
        return logits, softmax_output

    def get_config(self):
        """Return a serializable config (a copy, not the live ``__dict__``)."""
        return {'config': dict(self.config.__dict__)}

    @classmethod
    def from_config(cls, config):
        """Rebuild the model from the dict produced by ``get_config``."""
        return cls(MiniSunConfig(**config['config']))

    def compute_loss(self, labels, logits, label_smoothing=0.1):
        """Per-position label-smoothed cross-entropy from logits.

        ``tf.keras.losses.sparse_categorical_crossentropy`` has no
        ``label_smoothing`` argument, so the smoothed loss is computed
        directly. Smoothing a one-hot target to
        ``(1 - s) * one_hot + s / num_classes`` gives
        ``(1 - s) * nll + s * mean_over_classes(-log_softmax)``.

        Args:
            labels: Integer class ids; assumed to have the shape of
                ``logits`` without its last axis.
            logits: Unnormalized scores whose last axis is the class axis.
            label_smoothing: Smoothing factor ``s`` in ``[0, 1]``.

        Returns:
            A float32 tensor of unreduced losses, one per label.

        Raises:
            ValueError: If ``labels`` or ``logits`` is ``None``.
        """
        if labels is None or logits is None:
            raise ValueError("Labels and logits cannot be None.")
        logits = tf.cast(logits, tf.float32)
        labels = tf.cast(labels, tf.int64)
        nll = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits)
        uniform_term = -tf.reduce_mean(
            tf.nn.log_softmax(logits, axis=-1), axis=-1)
        return (1.0 - label_smoothing) * nll + label_smoothing * uniform_term

    def _loss_and_logits(self, inputs, labels, training):
        """Forward pass returning ``(scalar_total_loss, logits)``."""
        logits, _ = self(inputs,
                         attention_mask=inputs.get('attention_mask'),
                         training=training)
        loss = tf.reduce_mean(self.compute_loss(labels, logits))
        # Include the kernel regularizers configured on the layers; a custom
        # train_step does not add them automatically.
        if self.losses:
            loss += tf.cast(tf.add_n(self.losses), loss.dtype)
        return loss, logits

    def _record_metrics(self, loss, labels, logits):
        """Update the model-owned metrics and return their current values."""
        predictions = tf.reshape(tf.argmax(logits, axis=-1), [-1])
        flat_labels = tf.reshape(labels, [-1])
        self.loss_tracker.update_state(loss)
        self.accuracy_metric.update_state(flat_labels, predictions)
        return {m.name: m.result() for m in self.metrics}

    def train_step(self, data):
        """Run one optimization step on ``(inputs, labels)``.

        Returns:
            Dict mapping metric names (``'loss'``, ``'accuracy'``) to values.
        """
        inputs, labels = data
        optimizer = self.optimizer
        # LossScaleOptimizer requires the loss to be scaled before
        # differentiation and the gradients to be unscaled afterwards.
        uses_loss_scaling = hasattr(optimizer, "get_scaled_loss")

        with tf.GradientTape() as tape:
            loss, logits = self._loss_and_logits(inputs, labels, training=True)
            target = optimizer.get_scaled_loss(loss) if uses_loss_scaling else loss

        variables = self.trainable_variables
        gradients = tape.gradient(target, variables)
        if uses_loss_scaling:
            gradients = optimizer.get_unscaled_gradients(gradients)

        # Gradient clipping for stability; variables without a gradient are
        # passed through as None.
        gradients = [
            g if g is None else tf.clip_by_value(g, -1.0, 1.0)
            for g in gradients
        ]
        optimizer.apply_gradients(zip(gradients, variables))

        return self._record_metrics(loss, labels, logits)

    def test_step(self, data):
        """Evaluate one batch with the same loss and metrics as training.

        Needed because ``compute_loss`` here does not follow the signature
        the default Keras ``test_step`` expects.
        """
        inputs, labels = data
        loss, logits = self._loss_and_logits(inputs, labels, training=False)
        return self._record_metrics(loss, labels, logits)


def create_model(config):
    """Build and compile a ``MiniSunModel`` under the current strategy.

    Args:
        config: A ``MiniSunConfig``.

    Returns:
        The compiled model. Its optimizer is AdamW wrapped in a
        ``LossScaleOptimizer``.
    """
    strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model = MiniSunModel(config)
        optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
            tf.keras.optimizers.AdamW(learning_rate=config.learning_rate,
                                      weight_decay=config.weight_decay)
        )
        # The `loss` and `metrics` arguments are not consulted by the custom
        # train_step/test_step; they are kept for compile() compatibility.
        model.compile(
            optimizer=optimizer,
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=['accuracy'])
    return model


def cosine_annealing_with_warmup(step, config):
    """Learning rate schedule with warm-up and cosine annealing.

    Linear warm-up over ``total_steps * warmup_ratio`` steps, then cosine
    decay to zero at ``total_steps``. Steps past ``total_steps`` stay at
    zero instead of following the cosine back up.

    Returns:
        The learning rate as a Python float.
    """
    warmup_steps = int(config.total_steps * config.warmup_ratio)
    if step < warmup_steps:
        return config.learning_rate * (step / warmup_steps)
    # max(..., 1) guards warmup_ratio == 1, where there is no decay phase.
    total_cos_steps = max(config.total_steps - warmup_steps, 1)
    cos_step = min(step - warmup_steps, total_cos_steps)
    return float(0.5 * config.learning_rate
                 * (1 + np.cos(np.pi * cos_step / total_cos_steps)))


def cosine_annealing_with_restarts(step, config):
    """Learning rate schedule with warm-up and cosine annealing with restarts.

    Every ``restart_period`` steps the schedule restarts. Within a cycle the
    first ``restart_period * warmup_ratio`` steps warm up linearly and the
    rest follow a cosine decay. (Sizing the warm-up from ``total_steps``
    could exceed the cycle length, leaving the cosine phase unreachable.)

    Returns:
        The learning rate as a Python float.
    """
    period = config.restart_period
    warmup_steps = int(period * config.warmup_ratio)
    effective_step = step % period
    if effective_step < warmup_steps:
        return config.learning_rate * (effective_step / warmup_steps)
    total_cos_steps = max(period - warmup_steps, 1)
    cos_step = effective_step - warmup_steps
    return float(0.5 * config.learning_rate
                 * (1 + np.cos(np.pi * cos_step / total_cos_steps)))


# Configuration
config = MiniSunConfig(l1_reg=1e-5, l2_reg=3e-4)

# Initialize model with improvements
model = create_model(config)

# Create LearningRateScheduler callbacks.
# NOTE(review): LearningRateScheduler invokes its schedule once per epoch with
# the epoch index, so `step` below counts epochs, not optimizer steps --
# confirm total_steps / restart_period are expressed in the intended unit.
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(
    lambda step: cosine_annealing_with_warmup(step, config))
lr_scheduler_with_restarts = tf.keras.callbacks.LearningRateScheduler(
    lambda step: cosine_annealing_with_restarts(step, config))