# TinyWay-1.0.0 / modeling_tinyway.py
# Author: Shivam Sharma
# Initial release: TinyWay 1.0.0 (52.94M params)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import CausalLMOutput
# =========================
# Config
# =========================
class TinyWayConfig(PretrainedConfig):
model_type = "tinyway"
def __init__(
self,
vocab_size=50257,
n_positions=256,
n_embd=384,
n_layer=8,
n_head=8,
dropout=0.1,
**kwargs
):
super().__init__(**kwargs)
# --- original fields ---
self.vocab_size = vocab_size
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_head = n_head
self.dropout = dropout
# --- HF standard aliases (CRITICAL) ---
self.hidden_size = n_embd
self.num_hidden_layers = n_layer
self.num_attention_heads = n_head
self.max_position_embeddings = n_positions
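# Rough parameter budget with the defaults above (a back-of-the-envelope sketch,
# assuming the untied LM head defined below and no other parameters):
#   token embedding    : 50257 * 384           ~ 19.30M
#   position embedding : 256 * 384             ~  0.10M
#   8 decoder blocks   : 8 * ~1.77M            ~ 14.20M
#   LM head            : 384 * 50257 + 50257   ~ 19.35M
# which lands at ~52.94M, matching the figure quoted in the release header.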
# =========================
# Attention
# =========================
class CausalSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
        assert config.n_embd % config.n_head == 0, "n_embd must be divisible by n_head"
self.n_head = config.n_head
self.head_dim = config.n_embd // config.n_head
self.qkv = nn.Linear(config.n_embd, 3 * config.n_embd)
self.proj = nn.Linear(config.n_embd, config.n_embd)
        # cached lower-triangular causal mask (1 = may attend, 0 = future position)
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.n_positions, config.n_positions))
        )
    def forward(self, x):
        B, T, C = x.shape
        # project once to Q, K, V and split heads: (B, n_head, T, head_dim)
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        # scaled dot-product attention; future positions are masked out before softmax
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        # weighted sum of values, then merge heads back to (B, T, C)
        out = att @ v
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(out)
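# Minimal shape check for the attention module (a sketch, assuming the default
# config): position t can only attend to positions <= t because future scores
# are set to -inf before the softmax, and the output keeps the input shape.
# >>> attn = CausalSelfAttention(TinyWayConfig())
# >>> attn(torch.randn(2, 16, 384)).shape
# torch.Size([2, 16, 384])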
# =========================
# Transformer Block
# =========================
class DecoderBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.attn = CausalSelfAttention(config)
self.ffn = nn.Sequential(
nn.Linear(config.n_embd, 4 * config.n_embd),
nn.GELU(),
nn.Linear(4 * config.n_embd, config.n_embd)
)
self.ln1 = nn.LayerNorm(config.n_embd)
self.ln2 = nn.LayerNorm(config.n_embd)
self.dropout = nn.Dropout(config.dropout)
    def forward(self, x):
        # pre-norm residual connections around attention and the feed-forward MLP
        x = x + self.dropout(self.attn(self.ln1(x)))
        x = x + self.dropout(self.ffn(self.ln2(x)))
        return x
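# Design note: this is the pre-norm ordering (LayerNorm before each sub-layer,
# residual path left untouched), as used in GPT-2-style decoders. A minimal
# shape check (a sketch, assuming the default config):
# >>> block = DecoderBlock(TinyWayConfig())
# >>> block(torch.randn(2, 16, 384)).shape
# torch.Size([2, 16, 384])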
# =========================
# Model
# =========================
class TinyWayForCausalLM(PreTrainedModel, GenerationMixin):
config_class = TinyWayConfig
def __init__(self, config):
super().__init__(config)
self.token_emb = nn.Embedding(config.vocab_size, config.n_embd)
self.pos_emb = nn.Embedding(config.n_positions, config.n_embd)
self.blocks = nn.ModuleList(
[DecoderBlock(config) for _ in range(config.n_layer)]
)
self.ln = nn.LayerNorm(config.n_embd)
        # untied LM head; the layout MUST match the training checkpoint
        self.head = nn.Linear(config.n_embd, config.vocab_size)
self.post_init()
# ---- HF REQUIRED METHODS ----
def get_input_embeddings(self):
return self.token_emb
def set_input_embeddings(self, value):
self.token_emb = value
    # ---- Forward ----
    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        # attention_mask is accepted for HF API compatibility but is not used:
        # attention is always causal over the full (unpadded) sequence.
        B, T = input_ids.shape
        pos = torch.arange(T, device=input_ids.device)  # learned absolute positions
        x = self.token_emb(input_ids) + self.pos_emb(pos)
        for block in self.blocks:
            x = block(x)
        x = self.ln(x)
        logits = self.head(x)
        loss = None
        if labels is not None:
            # standard causal LM objective: predict token t+1 from positions <= t
            loss = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                labels[:, 1:].reshape(-1),
            )
        return CausalLMOutput(loss=loss, logits=logits)
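# =========================
# Usage sketches (not part of the model definition)
# =========================
# Loading from the Hub is the intended path, assuming the checkpoint is published
# with this file registered via trust_remote_code and ships a GPT-2-compatible
# tokenizer (the 50257-token vocabulary suggests one). The repo id below is a
# hypothetical placeholder:
#
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("<user>/TinyWay-1.0.0")          # hypothetical repo id
#   model = TinyWayForCausalLM.from_pretrained("<user>/TinyWay-1.0.0")   # hypothetical repo id
#   ids = tok("Once upon a time", return_tensors="pt").input_ids
#   out = model.generate(ids, max_new_tokens=32, do_sample=True, top_k=50)
#   print(tok.decode(out[0], skip_special_tokens=True))
#
# Quick local smoke test with randomly initialized weights (no checkpoint or
# tokenizer needed): a forward pass with labels should print full-vocabulary
# logits of shape (1, 8, 50257) and a scalar loss.
if __name__ == "__main__":
    cfg = TinyWayConfig()
    model = TinyWayForCausalLM(cfg)
    dummy = torch.randint(0, cfg.vocab_size, (1, 8))
    out = model(dummy, labels=dummy)
    print(out.logits.shape, out.loss.item())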