{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"none","dataSources":[{"sourceId":242901,"sourceType":"modelInstanceVersion","modelInstanceId":207463,"modelId":229179}],"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":false}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"markdown","source":"# Required Packages","metadata":{"_uuid":"d6f701cd-8ce4-400a-8851-7cf812774af4","_cell_guid":"e7c343d7-1c57-4701-8be4-57a4d864815b","trusted":true,"collapsed":false,"jupyter":{"outputs_hidden":false}}},{"cell_type":"code","source":"!pip install torch transformers tiktoken datasets","metadata":{"_uuid":"dffff9b7-6a40-449e-8aa2-bd194224a9d4","_cell_guid":"4dfca29c-1a01-4c05-9390-a914f2df253c","trusted":true,"collapsed":false,"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2025-01-27T04:45:05.870880Z","iopub.execute_input":"2025-01-27T04:45:05.871263Z","iopub.status.idle":"2025-01-27T04:45:11.910743Z","shell.execute_reply.started":"2025-01-27T04:45:05.871192Z","shell.execute_reply":"2025-01-27T04:45:11.909357Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"# Imports and Setup","metadata":{"_uuid":"359eb12d-bfa0-4c77-8bf6-f19679a479fa","_cell_guid":"59babb5b-d6d2-4970-8cbf-4d6191a4bb28","trusted":true,"collapsed":false,"jupyter":{"outputs_hidden":false}}},{"cell_type":"code","source":"import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.utils.data import Dataset, DataLoader, RandomSampler\nfrom torch.optim import AdamW\nfrom torch.cuda.amp import GradScaler, autocast\nfrom datasets import load_dataset\nimport numpy as np\nfrom IPython.display import clear_output\nimport json\nimport os\nfrom typing import Dict, List, Optional, Tuple\nimport tiktoken\n\n# Check if GPU is available\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nprint(f\"Using device: {device}\")\n\nshuffle_generator = torch.Generator()\nshuffle_generator.manual_seed(42)\n\nos.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"","metadata":{"_uuid":"f2f9ae8c-1949-4d4f-8a66-d624702411f0","_cell_guid":"1552ddb7-f240-45da-a4bc-cab212cfdd15","trusted":true,"collapsed":false,"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2025-01-27T04:45:11.912717Z","iopub.execute_input":"2025-01-27T04:45:11.913088Z","iopub.status.idle":"2025-01-27T04:45:14.671621Z","shell.execute_reply.started":"2025-01-27T04:45:11.913056Z","shell.execute_reply":"2025-01-27T04:45:14.670335Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"# Configuration Class","metadata":{"_uuid":"c8da3525-99d9-4787-899c-b2cb122f9917","_cell_guid":"d30aecbe-d860-46d6-9000-77c028d87b5c","trusted":true,"collapsed":false,"jupyter":{"outputs_hidden":false}}},{"cell_type":"code","source":"class Config:\n def __init__(self):\n # Model architecture\n self.vocab_size = 100283\n self.max_position_embeddings = 1024\n self.hidden_size = 768\n self.num_layers = 6\n self.num_heads = 12\n self.intermediate_size = 3072\n self.dropout = 0.1\n \n # Training\n self.batch_size = 4\n self.learning_rate = 3e-4\n self.weight_decay = 0.01\n self.warmup_steps = 1000\n self.max_epochs = 3\n self.gradient_accumulation_steps = 8\n 
self.max_grad_norm = 1.0\n \n # Checkpointing\n self.checkpoint_every = 300 # Save every N batches\n self.evaluation_every = 500 # Evaluate every N batches","metadata":{"_uuid":"b666ec5e-bfb0-4452-8754-fb5eaa4c0895","_cell_guid":"33c5c394-cc03-421e-a63e-eb0d0c4e7475","trusted":true,"collapsed":false,"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2025-01-27T04:45:14.673364Z","iopub.execute_input":"2025-01-27T04:45:14.673999Z","iopub.status.idle":"2025-01-27T04:45:14.680072Z","shell.execute_reply.started":"2025-01-27T04:45:14.673964Z","shell.execute_reply":"2025-01-27T04:45:14.678806Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"# Dataset Class","metadata":{"_uuid":"d43efd80-8e6b-4d8a-8943-9569a5513a37","_cell_guid":"fd40d156-ba7f-4b29-a59b-04a96964191a","trusted":true,"collapsed":false,"jupyter":{"outputs_hidden":false}}},{"cell_type":"code","source":"class TextDataset(Dataset):\n def __init__(self, text: str, block_size: int, tokenizer, chunk_size: int = 1024):\n self.tokenizer = tokenizer\n self.block_size = block_size\n self.chunk_size = chunk_size\n\n # Tokenize the entire text first\n tokens = self.tokenizer.encode(text, allowed_special={'<system>', '</system>', '<user>', '</user>', '<assistant>', '</assistant>'})\n \n # Process the tokenized text in chunks\n self.examples = []\n for chunk_start in range(0, len(tokens), self.chunk_size):\n # Get the token chunk\n chunk = tokens[chunk_start:chunk_start + self.chunk_size]\n \n # Create overlapping blocks from the tokenized chunk\n for i in range(0, len(chunk) - block_size + 1):\n self.examples.append(chunk[i:i + block_size])\n\n def __len__(self):\n return len(self.examples)\n \n def __getitem__(self, i):\n return torch.tensor(self.examples[i], dtype=torch.long)","metadata":{"_uuid":"9711c202-367b-4b8e-8f3a-af1e815592b8","_cell_guid":"0189f7c1-bc6a-409f-b818-b7059b3817ee","trusted":true,"collapsed":false,"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2025-01-27T04:45:14.681668Z","iopub.execute_input":"2025-01-27T04:45:14.682077Z","iopub.status.idle":"2025-01-27T04:45:14.704555Z","shell.execute_reply.started":"2025-01-27T04:45:14.682033Z","shell.execute_reply":"2025-01-27T04:45:14.703351Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"# Model Architecture Classes","metadata":{"_uuid":"25bf9134-4c81-4179-8293-b890b1fe43eb","_cell_guid":"336edd03-8e84-44a9-9986-d078baf5ebcc","trusted":true,"collapsed":false,"jupyter":{"outputs_hidden":false}}},{"cell_type":"code","source":"class AttentionHead(nn.Module):\n def __init__(self, config: Config):\n super().__init__()\n self.head_dim = config.hidden_size // config.num_heads\n self.query = nn.Linear(config.hidden_size, self.head_dim)\n self.key = nn.Linear(config.hidden_size, self.head_dim)\n self.value = nn.Linear(config.hidden_size, self.head_dim)\n \n def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:\n Q = self.query(x)\n K = self.key(x)\n V = self.value(x)\n \n # Scaled dot-product attention\n scale = Q.size(-1) ** 0.5\n scores = torch.matmul(Q, K.transpose(-2, -1)) / scale\n \n if mask is not None:\n scores = scores.masked_fill(mask == 0, float('-inf'))\n \n attention = F.softmax(scores, dim=-1)\n return torch.matmul(attention, V)\n\nclass MultiHeadAttention(nn.Module):\n def __init__(self, config: Config):\n super().__init__()\n self.heads = nn.ModuleList([\n AttentionHead(config) for _ in range(config.num_heads)\n ])\n self.linear = nn.Linear(config.hidden_size, config.hidden_size)\n self.dropout = 
nn.Dropout(config.dropout)\n \n def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:\n heads = [head(x, mask) for head in self.heads]\n multihead = torch.cat(heads, dim=-1)\n return self.dropout(self.linear(multihead))\n\nclass TransformerBlock(nn.Module):\n def __init__(self, config: Config):\n super().__init__()\n self.attention = MultiHeadAttention(config)\n self.norm1 = nn.LayerNorm(config.hidden_size)\n self.norm2 = nn.LayerNorm(config.hidden_size)\n self.feed_forward = nn.Sequential(\n nn.Linear(config.hidden_size, config.intermediate_size),\n nn.GELU(),\n nn.Linear(config.intermediate_size, config.hidden_size),\n nn.Dropout(config.dropout)\n )\n \n def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:\n # Attention with residual connection and layer norm\n attended = self.attention(x, mask)\n x = self.norm1(x + attended)\n \n # Feed forward with residual connection and layer norm\n fed_forward = self.feed_forward(x)\n return self.norm2(x + fed_forward)\n\nclass SmallLanguageModel(nn.Module):\n def __init__(self, config: Config):\n super().__init__()\n self.config = config\n \n # Token and position embeddings\n self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)\n self.position_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)\n \n # Transformer blocks\n self.transformer_blocks = nn.ModuleList([\n TransformerBlock(config) for _ in range(config.num_layers)\n ])\n \n self.dropout = nn.Dropout(config.dropout)\n self.ln_f = nn.LayerNorm(config.hidden_size)\n self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n \n # Initialize weights\n self.apply(self._init_weights)\n \n def _init_weights(self, module):\n if isinstance(module, (nn.Linear, nn.Embedding)):\n module.weight.data.normal_(mean=0.0, std=0.02)\n if isinstance(module, nn.Linear) and module.bias is not None:\n module.bias.data.zero_()\n elif isinstance(module, nn.LayerNorm):\n module.bias.data.zero_()\n module.weight.data.fill_(1.0)\n \n def get_causal_mask(self, size: int) -> torch.Tensor:\n mask = torch.triu(torch.ones(size, size), diagonal=1).bool()\n return ~mask\n \n def forward(self, input_ids: torch.Tensor) -> torch.Tensor:\n b, t = input_ids.size()\n \n # Create position indices and causal mask\n positions = torch.arange(0, t, dtype=torch.long, device=input_ids.device)\n mask = self.get_causal_mask(t).to(input_ids.device)\n \n # Get token and position embeddings\n token_embeddings = self.token_embedding(input_ids)\n position_embeddings = self.position_embedding(positions)\n \n # Combine embeddings\n x = self.dropout(token_embeddings + position_embeddings)\n \n # Apply transformer blocks\n for block in self.transformer_blocks:\n x = block(x, mask)\n \n x = self.ln_f(x)\n logits = self.head(x)\n \n return logits","metadata":{"_uuid":"6c7ce1ea-e86f-41c0-93a5-ab1d3372adba","_cell_guid":"cf7cdc13-912a-43a5-815f-7a9415bbd86b","trusted":true,"collapsed":false,"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2025-01-27T04:45:14.705751Z","iopub.execute_input":"2025-01-27T04:45:14.706154Z","iopub.status.idle":"2025-01-27T04:45:14.730069Z","shell.execute_reply.started":"2025-01-27T04:45:14.706111Z","shell.execute_reply":"2025-01-27T04:45:14.728832Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"# Trainer 
Class","metadata":{"_uuid":"815ebb1d-b356-4fe5-a175-fa98acd717fc","_cell_guid":"06a19f00-fe43-4767-af8e-f9021cd8c994","trusted":true,"collapsed":false,"jupyter":{"outputs_hidden":false}}},{"cell_type":"code","source":"class TrainingMode:\n PRETRAIN = \"pretrain\"\n RESUME = \"resume\"\n NEW_DATASET = \"new_dataset\"\n FINETUNE = \"finetune\"\n\nclass TrainerConfig:\n @staticmethod\n def get_pretrain_config():\n config = Config()\n # Default pretraining parameters are already set in Config class\n return config\n\n @staticmethod\n def get_finetune_config():\n config = Config()\n # Modify for finetuning\n config.learning_rate = 2e-5\n config.max_epochs = 3\n config.batch_size = 4\n config.weight_decay = 0.02\n config.warmup_steps = 100\n config.gradient_accumulation_steps = 16\n return config\n\nclass Trainer:\n def __init__(self, model: nn.Module, train_dataset: Dataset,\n val_dataset: Optional[Dataset] = None):\n self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n self.model = model.to(self.device)\n self.train_dataset = train_dataset\n self.val_dataset = val_dataset\n self.scaler = GradScaler()\n self.optimizer = None\n\n def setup_optimizer(self, config: Config, mode: str):\n if mode == TrainingMode.FINETUNE:\n # Layer-wise learning rate decay for finetuning\n params = []\n assigned_params = set()\n\n def add_params_to_group(param_names, lr_scale):\n group_params = []\n for name, param in self.model.named_parameters():\n if any(pn in name for pn in param_names) and param not in assigned_params:\n group_params.append(param)\n assigned_params.add(param)\n if group_params:\n params.append({\n 'params': group_params,\n 'lr': config.learning_rate * lr_scale\n })\n\n # Add embedding parameters (lowest learning rate)\n add_params_to_group(['embedding'], 0.1)\n\n # Add transformer layers with progressive learning rates\n for i in range(config.num_layers):\n lr_scale = 1 - (0.1 * (config.num_layers - i - 1) / config.num_layers)\n add_params_to_group([f'transformer_blocks.{i}.'], lr_scale)\n\n # Add output layer parameters (highest learning rate)\n add_params_to_group(['head', 'ln_f'], 1.0)\n\n # Add remaining parameters\n remaining_params = [p for n, p in self.model.named_parameters()\n if p not in assigned_params]\n if remaining_params:\n params.append({\n 'params': remaining_params,\n 'lr': config.learning_rate\n })\n\n self.optimizer = AdamW(params, weight_decay=config.weight_decay)\n else:\n # Standard optimizer for pretraining\n self.optimizer = AdamW(\n self.model.parameters(),\n lr=config.learning_rate,\n weight_decay=config.weight_decay\n )\n\n # Move optimizer states to the correct device\n for state in self.optimizer.state.values():\n for k, v in state.items():\n if isinstance(v, torch.Tensor):\n state[k] = v.to(self.device)\n\n def save_checkpoint(self, epoch: int, batch_idx: int, loss: float, config: Config, mode: str):\n checkpoint = {\n 'epoch': epoch,\n 'batch_idx': batch_idx,\n 'model_state_dict': self.model.state_dict(),\n 'optimizer_state_dict': self.optimizer.state_dict(),\n 'scaler_state_dict': self.scaler.state_dict(),\n 'loss': loss,\n 'config': vars(config),\n 'original_mode': mode\n }\n path = 'checkpoint.pt'\n torch.save(checkpoint, path)\n print(f\"Saved checkpoint to {path}\")\n\n def load_checkpoint(self, path: str, mode: str) -> Tuple[int, int, str]:\n print(f\"Loading checkpoint from {path}\")\n checkpoint = torch.load(path, map_location=self.device, weights_only=False)\n\n if mode == TrainingMode.NEW_DATASET:\n # Only load model weights 
for new dataset\n self.model.load_state_dict(checkpoint['model_state_dict'])\n return 0, 0, checkpoint.get('original_mode', TrainingMode.PRETRAIN)\n elif mode == TrainingMode.RESUME:\n # Load model state\n self.model.load_state_dict(checkpoint['model_state_dict'])\n \n # Load optimizer and scaler states\n self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n self.scaler.load_state_dict(checkpoint['scaler_state_dict'])\n \n # Move optimizer states to correct device\n for state in self.optimizer.state.values():\n for k, v in state.items():\n if isinstance(v, torch.Tensor):\n state[k] = v.to(self.device)\n \n return checkpoint['epoch'], checkpoint['batch_idx'], checkpoint.get('original_mode', TrainingMode.PRETRAIN)\n else:\n # For direct pretrain/finetune, just load model weights\n self.model.load_state_dict(checkpoint['model_state_dict'])\n return 0, 0, mode\n\n def train(self, mode: str, checkpoint_path: Optional[str] = None):\n # Get appropriate config first\n if mode == TrainingMode.RESUME and checkpoint_path:\n checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)\n original_mode = checkpoint.get('original_mode', TrainingMode.PRETRAIN)\n config = (TrainerConfig.get_finetune_config() \n if original_mode == TrainingMode.FINETUNE \n else TrainerConfig.get_pretrain_config())\n else:\n config = (TrainerConfig.get_finetune_config() \n if mode == TrainingMode.FINETUNE \n else TrainerConfig.get_pretrain_config())\n\n # Setup optimizer before loading checkpoint\n self.setup_optimizer(config, mode)\n\n # Initialize starting points\n start_epoch = 0\n start_batch = 0\n original_mode = mode\n\n # Load checkpoint if provided\n if checkpoint_path:\n start_epoch, start_batch, original_mode = self.load_checkpoint(checkpoint_path, mode)\n\n # Training loop setup\n train_loader = DataLoader(\n self.train_dataset,\n batch_size=config.batch_size,\n sampler=RandomSampler(self.train_dataset, generator=shuffle_generator),\n pin_memory=True,\n num_workers=4\n )\n\n print(f\"Training started in {mode} mode\")\n print(f\"Total batches per epoch: {len(train_loader)}\")\n print(f\"Total epochs: {config.max_epochs}\")\n print(f\"Device: {self.device}\")\n\n # Training loop\n for epoch in range(start_epoch, config.max_epochs):\n self.model.train()\n total_loss = 0\n\n for batch_idx, batch in enumerate(train_loader):\n # Skip batches if resuming from checkpoint\n if epoch == start_epoch and batch_idx < start_batch:\n continue\n\n # Move batch to device\n batch = batch.to(self.device)\n\n # Forward pass with mixed precision\n with autocast():\n logits = self.model(batch)\n targets = batch[:, 1:]\n logits = logits[:, :-1, :]\n loss = F.cross_entropy(\n logits.reshape(-1, logits.size(-1)),\n targets.reshape(-1)\n )\n\n # Scale loss and backward pass\n self.scaler.scale(loss).backward()\n\n # Gradient accumulation\n if (batch_idx + 1) % config.gradient_accumulation_steps == 0:\n self.scaler.unscale_(self.optimizer)\n torch.nn.utils.clip_grad_norm_(\n self.model.parameters(),\n config.max_grad_norm\n )\n\n self.scaler.step(self.optimizer)\n self.scaler.update()\n self.optimizer.zero_grad()\n\n total_loss += loss.item()\n\n # Save checkpoint and display progress\n if batch_idx % config.checkpoint_every == 0:\n self.save_checkpoint(epoch, batch_idx, loss.item(), config, original_mode)\n clear_output(wait=True)\n print(f\"Mode: {mode}\")\n print(f\"Epoch {epoch+1}/{config.max_epochs}\")\n print(f\"Batch {batch_idx+1}/{len(train_loader)}\")\n print(f\"Loss {total_loss / 
(batch_idx + 1):.4f}\")\n\n # Save checkpoint at end of epoch\n self.save_checkpoint(epoch, len(train_loader)-1, loss.item(), config, original_mode)\n\n print(f'{mode} training complete.')","metadata":{"_uuid":"3935f5aa-bafb-4ff8-8b30-1e741b7dfeb9","_cell_guid":"2a13e1d7-2842-4bcf-9cb1-54b1f9788d5a","trusted":true,"collapsed":false,"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2025-01-27T04:45:14.731189Z","iopub.execute_input":"2025-01-27T04:45:14.731521Z","iopub.status.idle":"2025-01-27T04:45:14.760459Z","shell.execute_reply.started":"2025-01-27T04:45:14.731493Z","shell.execute_reply":"2025-01-27T04:45:14.759413Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"# Initialize Tokenizer","metadata":{"_uuid":"4d65a12d-84b7-453f-a968-784ab5e318f9","_cell_guid":"2b4c1cae-5989-408c-a5bb-2f345e7b5c64","trusted":true,"collapsed":false,"jupyter":{"outputs_hidden":false}}},{"cell_type":"code","source":"cl100k_base = tiktoken.get_encoding(\"cl100k_base\")\n\n# Extend cl100k_base with chat-format tags as special tokens\ntokenizer = tiktoken.Encoding(\n name=\"cl100k_xml\",\n pat_str=cl100k_base._pat_str,\n mergeable_ranks=cl100k_base._mergeable_ranks,\n special_tokens={\n **cl100k_base._special_tokens,\n \"<system>\": 100277,\n \"</system>\": 100278,\n \"<user>\": 100279,\n \"</user>\": 100280,\n \"<assistant>\": 100281,\n \"</assistant>\": 100282\n }\n)\n\n# Create config\nconfig = Config()\n# Update vocab size to match tokenizer\nconfig.vocab_size = tokenizer.n_vocab","metadata":{"_uuid":"803c4a00-432c-4ea7-bb27-dc8ed44bbd3e","_cell_guid":"262ad9d0-e5bc-4f6e-a58f-cd16c8341806","trusted":true,"collapsed":false,"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2025-01-27T04:45:16.047144Z","iopub.execute_input":"2025-01-27T04:45:16.047602Z","iopub.status.idle":"2025-01-27T04:45:17.235690Z","shell.execute_reply.started":"2025-01-27T04:45:16.047569Z","shell.execute_reply":"2025-01-27T04:45:17.234557Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"# Load Data","metadata":{"_uuid":"a4c36482-6844-4c68-ae52-bf55fce9abba","_cell_guid":"139dbfcd-6293-4611-9e31-8b4b2d83657a","trusted":true,"collapsed":false,"jupyter":{"outputs_hidden":false}}},{"cell_type":"markdown","source":"## Pretrain 1st Session","metadata":{}},{"cell_type":"code","source":"from datasets import load_dataset\n\n# Load Dataset\nds = ''.join(load_dataset(\"nampdn-ai/mini-en\", split='train', token=\"[INSERT HF TOKEN HERE]\")['text'])\n\n# Create dataset\ntrain_dataset = TextDataset(ds, config.max_position_embeddings, tokenizer)","metadata":{"_uuid":"f621cde8-cb01-470d-9d58-a01177922a23","_cell_guid":"7b6ac734-dc85-4a18-854e-82c31cf2b440","trusted":true,"collapsed":false,"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2025-01-23T00:16:50.285800Z","iopub.execute_input":"2025-01-23T00:16:50.286089Z","iopub.status.idle":"2025-01-23T00:18:52.365126Z","shell.execute_reply.started":"2025-01-23T00:16:50.286067Z","shell.execute_reply":"2025-01-23T00:18:52.364189Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Pretrain 2nd Session","metadata":{}},{"cell_type":"code","source":"from datasets import load_dataset\n\n# Load Dataset\nds = ''.join(load_dataset(\"HuggingFaceTB/cosmopedia-100k\", split='train', token=\"[INSERT HF TOKEN HERE]\")['text'])\n\n# Create dataset\ntrain_dataset = TextDataset(ds, config.max_position_embeddings, 
tokenizer)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-01-26T05:42:43.999673Z","iopub.execute_input":"2025-01-26T05:42:43.999984Z","iopub.status.idle":"2025-01-26T05:43:37.233353Z","shell.execute_reply.started":"2025-01-26T05:42:43.999960Z","shell.execute_reply":"2025-01-26T05:43:37.232558Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Finetune Session","metadata":{}},{"cell_type":"code","source":"fds1 = load_dataset(\"HuggingFaceH4/SystemChat\", split=\"train_sft\")\nfds2 = load_dataset(\"HuggingFaceH4/no_robots\", split=\"train\")\nfds3 = load_dataset('b-ai/deepseek_synthetic_conversation_dialogue', split='train')\nfinetune_datasets = [fds1, fds2, fds3]\n\nfinetune_parts = []\n\nfor ds in finetune_datasets:\n dataset_conversations = []\n \n for messages in ds['messages']:\n # Format all messages first\n conversation_chunks = [\n f\"<{message['role']}>{message['content']}{message['role']}>\"\n for message in messages\n ]\n \n # Add two newlines at conversation start + join messages\n full_conversation = '\\n'.join(conversation_chunks)\n dataset_conversations.append(full_conversation)\n \n # Join conversations within dataset\n finetune_parts.append('\\n\\n'.join(dataset_conversations))\n\n# Final dataset assembly\nfinetune_text = '\\n\\n'.join(finetune_parts)\ntrain_dataset = TextDataset(finetune_text, config.max_position_embeddings, tokenizer)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-01-26T23:19:09.690508Z","iopub.execute_input":"2025-01-26T23:19:09.690753Z","iopub.status.idle":"2025-01-26T23:19:22.554546Z","shell.execute_reply.started":"2025-01-26T23:19:09.690733Z","shell.execute_reply":"2025-01-26T23:19:22.553845Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"# Initialize Model","metadata":{"_uuid":"cb096bdf-a986-459b-b706-1ea266d82525","_cell_guid":"55f8815c-311e-42c2-b89f-42c04f07a8cf","trusted":true,"collapsed":false,"jupyter":{"outputs_hidden":false}}},{"cell_type":"code","source":"# Initialize model\nmodel = SmallLanguageModel(config)","metadata":{"_uuid":"9d712e53-3d6f-42ce-bd82-6229c6cf186a","_cell_guid":"e1ce7e2a-0710-4201-88c7-54f648f526db","trusted":true,"collapsed":false,"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2025-01-27T04:45:19.919342Z","iopub.execute_input":"2025-01-27T04:45:19.919778Z","iopub.status.idle":"2025-01-27T04:45:24.113995Z","shell.execute_reply.started":"2025-01-27T04:45:19.919748Z","shell.execute_reply":"2025-01-27T04:45:24.112813Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"# Start or Resume Training","metadata":{"_uuid":"3cb570e8-1c74-4590-8704-125e79761bf4","_cell_guid":"edcca091-7d2b-43b8-afeb-26d5d7aff4ad","trusted":true,"collapsed":false,"jupyter":{"outputs_hidden":false}}},{"cell_type":"markdown","source":"## Initialize Trainer","metadata":{"_uuid":"601aec0a-4f86-4907-acab-ce21141e47f2","_cell_guid":"b89d17d8-543a-49ed-a4e9-77d0557f7e51","trusted":true,"collapsed":false,"jupyter":{"outputs_hidden":false}}},{"cell_type":"code","source":"# Initialize trainer\ntrainer = Trainer(model, 
train_dataset)","metadata":{"_uuid":"09f4cb01-15db-4cd2-beff-c03a0f2f1946","_cell_guid":"58db7d96-4aab-4f5a-9ad4-aee2c23802f1","trusted":true,"collapsed":false,"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2025-01-27T04:45:24.115481Z","iopub.execute_input":"2025-01-27T04:45:24.115825Z","iopub.status.idle":"2025-01-27T04:45:24.195446Z","shell.execute_reply.started":"2025-01-27T04:45:24.115799Z","shell.execute_reply":"2025-01-27T04:45:24.194072Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### Start Pretrain Mode","metadata":{}},{"cell_type":"code","source":"trainer.train(mode=TrainingMode.PRETRAIN)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-01-21T06:34:06.102355Z","iopub.execute_input":"2025-01-21T06:34:06.102690Z","execution_failed":"2025-01-21T18:30:34.429Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### Resume Mode","metadata":{}},{"cell_type":"code","source":"trainer.train(mode=TrainingMode.RESUME, checkpoint_path='checkpoint.pt')","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-01-26T05:44:56.079383Z","iopub.execute_input":"2025-01-26T05:44:56.079666Z","iopub.status.idle":"2025-01-26T10:30:42.601029Z","shell.execute_reply.started":"2025-01-26T05:44:56.079645Z","shell.execute_reply":"2025-01-26T10:30:42.600179Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### New Dataset Pretrain Mode","metadata":{}},{"cell_type":"code","source":"trainer.train(mode=TrainingMode.NEW_DATASET, checkpoint_path='checkpoint.pt')","metadata":{"_uuid":"ab8f6116-fc92-4cc0-9e49-77f43bc37774","_cell_guid":"0071cb6f-36d8-47c9-a48f-2f4b6f0b6231","trusted":true,"collapsed":false,"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2025-01-25T00:25:22.649464Z","iopub.execute_input":"2025-01-25T00:25:22.649772Z","execution_failed":"2025-01-25T05:09:22.939Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### Finetune Mode","metadata":{"_uuid":"2cc85254-567a-4acd-83dd-231a13658ffa","_cell_guid":"cdc3a4d5-6046-48d4-879f-e291cb327c8b","trusted":true,"collapsed":false,"jupyter":{"outputs_hidden":false}}},{"cell_type":"code","source":"trainer.train(mode=TrainingMode.FINETUNE, checkpoint_path='checkpoint.pt')","metadata":{"_uuid":"02595175-3327-4007-8f49-87fdce872c02","_cell_guid":"fe2d0ba2-d259-4b17-b704-f1074378bb09","trusted":true,"collapsed":false,"execution":{"iopub.status.busy":"2025-01-26T23:21:27.867223Z","iopub.execute_input":"2025-01-26T23:21:27.867524Z","iopub.status.idle":"2025-01-27T00:15:07.046427Z","shell.execute_reply.started":"2025-01-26T23:21:27.867502Z","shell.execute_reply":"2025-01-27T00:15:07.044840Z"},"jupyter":{"outputs_hidden":false}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"# Inference","metadata":{"_uuid":"d3f751f4-c616-432b-b5a6-9c1a5828970c","_cell_guid":"f8cc2a71-7ae3-4b77-809c-0467bc4408fa","trusted":true,"collapsed":false,"jupyter":{"outputs_hidden":false}}},{"cell_type":"code","source":"class TextGenerator:\n def __init__(self, model, tokenizer):\n self.model = model\n self.model.eval() # Set to evaluation mode\n self.tokenizer = tokenizer\n \n @torch.no_grad() # Disable gradient calculation for inference\n def generate(\n self, \n prompt: str,\n max_length: int = 100,\n temperature: float = 0.7,\n top_k: int = 50,\n top_p: float = 0.9\n ):\n try:\n # Encode the prompt\n input_ids = torch.tensor(self.tokenizer.encode(prompt, allowed_special={'', '', '', '', '', 
'</assistant>'})).unsqueeze(0).to(device)\n \n # Generate tokens\n for _ in range(max_length):\n # Get model predictions\n if input_ids.size(1) > config.max_position_embeddings:\n input_ids = input_ids[:, -config.max_position_embeddings:]\n \n logits = self.model(input_ids)\n next_token_logits = logits[:, -1, :] / temperature\n \n # Apply top-k filtering\n if top_k > 0:\n values, _ = torch.topk(next_token_logits, top_k)\n min_value = values[:, -1].unsqueeze(-1).expand_as(next_token_logits)\n next_token_logits = torch.where(\n next_token_logits < min_value,\n torch.ones_like(next_token_logits) * float('-inf'),\n next_token_logits\n )\n \n # Apply top-p (nucleus) filtering\n if top_p < 1.0:\n # Sort logits in descending order\n sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)\n cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n\n # Remove tokens with cumulative probability above the threshold\n sorted_indices_to_remove = cumulative_probs > top_p\n # Shift the indices to the right to keep also the first token above the threshold\n sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()\n sorted_indices_to_remove[..., 0] = 0\n\n # Scatter sorted tensors to original indexing\n indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)\n next_token_logits = next_token_logits.masked_fill(indices_to_remove, float('-inf'))\n \n # Sample next token\n probs = F.softmax(next_token_logits, dim=-1)\n next_token = torch.multinomial(probs, num_samples=1)\n \n # Append next token to input_ids\n input_ids = torch.cat((input_ids, next_token), dim=1)\n \n # Check for end of text token (if your tokenizer has one)\n if hasattr(self.tokenizer, 'eot_token') and next_token.item() == self.tokenizer.eot_token:\n break\n \n # Decode the generated tokens\n return self.tokenizer.decode(input_ids[0].tolist())\n \n except Exception as e:\n print(f\"Error during generation: {str(e)}\")\n return prompt # Return original prompt if generation fails\n\n# Load the checkpoint\ncheckpoint = torch.load('checkpoint.pt', map_location=torch.device(device), weights_only=False)\nmodel.load_state_dict(checkpoint['model_state_dict'])\nmodel.to(device)\n\n# Initialize the generator\ngenerator = TextGenerator(model, tokenizer)\n\n# Example usage\nprompt = \"Hello\"\ngenerated_text = generator.generate(\n prompt=prompt,\n max_length=100, # Maximum tokens to generate\n temperature=0.7, # Higher = more random, lower = more focused\n top_k=50, # Consider only top k tokens\n top_p=0.9 # Nucleus sampling threshold\n)\n\nprint(\"Generated text:\")\nprint(generated_text)","metadata":{"_uuid":"ff3202b8-37dc-4fec-92f2-d22fe70aa9ca","_cell_guid":"fccd1d1f-c907-410a-a84e-2ecbd65ca607","trusted":true,"collapsed":false,"execution":{"iopub.status.busy":"2025-01-27T04:50:53.914779Z","iopub.execute_input":"2025-01-27T04:50:53.915211Z","iopub.status.idle":"2025-01-27T04:51:15.670139Z","shell.execute_reply.started":"2025-01-27T04:50:53.915166Z","shell.execute_reply":"2025-01-27T04:51:15.669099Z"},"jupyter":{"outputs_hidden":false}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"# Inspect Checkpoint","metadata":{}},{"cell_type":"code","source":"import torch\nfrom typing import Any\n\nclass CheckpointManager:\n def __init__(self, checkpoint_path: str):\n self.checkpoint_path = checkpoint_path\n self.checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)\n \n def display_contents(self):\n \"\"\"Display all 
contents of the checkpoint\"\"\"\n print(\"\\n=== Checkpoint Contents ===\")\n print(f\"Training Mode: {self.checkpoint.get('original_mode', 'Not specified')}\")\n print(f\"Epoch: {self.checkpoint.get('epoch', 'Not specified')}\")\n print(f\"Batch Index: {self.checkpoint.get('batch_idx', 'Not specified')}\")\n print(f\"Loss: {self.checkpoint.get('loss', 'Not specified')}\")\n \n print(\"\\n=== Configuration ===\")\n if 'config' in self.checkpoint:\n for key, value in self.checkpoint['config'].items():\n print(f\"{key}: {value}\")\n else:\n print(\"No configuration found in checkpoint\")\n \n print(\"\\n=== State Dictionaries Present ===\")\n print(\"Model state dict:\", 'model_state_dict' in self.checkpoint)\n print(\"Optimizer state dict:\", 'optimizer_state_dict' in self.checkpoint)\n print(\"Scaler state dict:\", 'scaler_state_dict' in self.checkpoint)\n\n def expand_embeddings(self, new_vocab_size: int, init_method: str = \"mean\"):\n state_dict = self.checkpoint['model_state_dict']\n \n # Get original sizes from BOTH layers\n old_token_vocab_size = state_dict['token_embedding.weight'].size(0)\n old_head_vocab_size = state_dict['head.weight'].size(0)\n \n # ===== 1. Expand Token Embeddings =====\n old_emb = state_dict['token_embedding.weight']\n new_emb = torch.zeros((new_vocab_size, old_emb.size(1)))\n new_emb[:old_token_vocab_size] = old_emb\n \n # Initialize new token embeddings\n if init_method == \"mean\":\n new_emb[old_token_vocab_size:] = old_emb.mean(dim=0)\n elif init_method == \"normal\":\n new_emb[old_token_vocab_size:] = torch.randn_like(new_emb[old_token_vocab_size:]) * 0.02\n \n state_dict['token_embedding.weight'] = new_emb\n \n # ===== 2. Expand Output Layer =====\n old_head = state_dict['head.weight']\n new_head = torch.zeros((new_vocab_size, old_head.size(1)))\n new_head[:old_head_vocab_size] = old_head # Use HEAD's original size\n \n # Initialize new output weights\n if init_method == \"mean\":\n new_head[old_head_vocab_size:] = old_head.mean(dim=0)\n elif init_method == \"normal\":\n new_head[old_head_vocab_size:] = torch.randn_like(new_head[old_head_vocab_size:]) * 0.02\n \n state_dict['head.weight'] = new_head\n \n # Update config\n self.checkpoint['config']['vocab_size'] = new_vocab_size\n print(f\"Expanded: Tokens {old_token_vocab_size}→{new_vocab_size}, Head {old_head_vocab_size}→{new_vocab_size}\")\n \n def modify_value(self, key_path: str, new_value: Any):\n \"\"\"\n Modify a value in the checkpoint using a dot-notation path\n Example: 'config.learning_rate' or 'original_mode'\n \"\"\"\n keys = key_path.split('.')\n current = self.checkpoint\n \n # Navigate to the nested location\n for key in keys[:-1]:\n if key not in current:\n print(f\"Error: Key '{key}' not found in checkpoint\")\n return\n current = current[key]\n \n final_key = keys[-1]\n if final_key not in current:\n print(f\"Error: Final key '{final_key}' not found\")\n return\n \n # Convert value to the same type as the existing value\n try:\n old_value = current[final_key]\n if isinstance(old_value, bool):\n new_value = bool(new_value)\n elif isinstance(old_value, int):\n new_value = int(new_value)\n elif isinstance(old_value, float):\n new_value = float(new_value)\n except ValueError:\n print(f\"Error: Could not convert new value to type {type(old_value)}\")\n return\n \n # Update the value\n current[final_key] = new_value\n print(f\"Updated {key_path} from {old_value} to {new_value}\")\n \n def save(self, output_path: str = None):\n \"\"\"Save the modified checkpoint\"\"\"\n save_path = 
output_path or self.checkpoint_path\n torch.save(self.checkpoint, save_path)\n print(f\"Saved checkpoint to {save_path}\")\n\nmanager = CheckpointManager('checkpoint.pt')\n\nmanager.display_contents()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-01-27T04:54:31.037973Z","iopub.execute_input":"2025-01-27T04:54:31.038262Z","iopub.status.idle":"2025-01-27T04:54:37.100896Z","shell.execute_reply.started":"2025-01-27T04:54:31.038233Z","shell.execute_reply":"2025-01-27T04:54:37.099804Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"from IPython.display import FileLink\n\nFileLink(r'checkpoint.pt')","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-01-27T04:54:37.101878Z","iopub.execute_input":"2025-01-27T04:54:37.102326Z","iopub.status.idle":"2025-01-27T04:54:37.110404Z","shell.execute_reply.started":"2025-01-27T04:54:37.102294Z","shell.execute_reply":"2025-01-27T04:54:37.109374Z"}},"outputs":[],"execution_count":null}]}