# RubiRLM-1B-Base / RubiRLM.py
"""Rubi-RLM: 1B-class Recursive Language Model (RLM) prototype.
Bu dosya, recursive düşünme + dual-loop öğrenme hedefiyle tasarlanmış bir
araştırma prototipi içerir.
Eklenen sohbet katmanı:
- İngilizce/Türkçe çift dilli chat şablonu
- HF tokenizer ile metin->id / id->metin köprüsü
- Tek mesaj veya interaktif chat CLI
"""
from __future__ import annotations
import argparse
import importlib
import importlib.util
from dataclasses import dataclass
from typing import List, Optional, Protocol, Sequence, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from rubi_train_stack import (
TrainStackConfig,
build_dataloader,
build_dataset,
build_optimizer,
train_demo_steps,
)
from xqs_moe import build_deepspeed_moe
from xqs_stack import choose_moe_backend, detect_xqs_backends, format_backend_report
from x_quantum_sparse_ops import (
build_linear,
causal_scaled_dot_product_attention,
fused_residual_add,
maybe_compile_module,
pack_rows,
scatter_rows,
)
class TextTokenizer(Protocol):
def encode(self, text: str, return_tensors: Optional[str] = None): ...
def decode(self, token_ids: Sequence[int], skip_special_tokens: bool = True) -> str: ...
@dataclass
class ChatTurn:
role: str
content: str
@dataclass
class RLMConfig:
vocab_size: int = 50_257
max_seq_len: int = 2_048
d_model: int = 2_048
n_layers: int = 14
n_heads: int = 16
ff_mult: int = 4
dropout: float = 0.1
recurse_steps: int = 6
critique_threshold: float = 0.20
tie_embeddings: bool = True
use_moe: bool = False
moe_num_experts: int = 0
moe_top_k: int = 2
moe_expert_hidden: int = 0
moe_router_jitter: float = 0.0
moe_aux_loss_weight: float = 0.01
use_layer_skip: bool = False
layer_skip_threshold: float = 0.50
layer_skip_target: float = 1.0
layer_skip_aux_weight: float = 0.01
use_ternary_weights: bool = False
use_flash_attention: bool = False
use_fused_ops: bool = False
packed_execution: bool = False
use_torch_compile: bool = False
moe_backend: str = "auto"
moe_ep_size: int = 1
@classmethod
def scale_1b(cls) -> "RLMConfig":
return cls(
vocab_size=50_257,
max_seq_len=2_048,
d_model=1_024,
n_layers=10,
n_heads=16,
ff_mult=4,
recurse_steps=6,
critique_threshold=0.20,
use_moe=True,
moe_num_experts=32,
moe_top_k=1,
moe_expert_hidden=1_280,
moe_router_jitter=0.01,
moe_aux_loss_weight=0.01,
use_layer_skip=True,
layer_skip_threshold=0.80,
layer_skip_target=0.03,
layer_skip_aux_weight=0.01,
use_ternary_weights=True,
use_flash_attention=True,
use_fused_ops=True,
packed_execution=True,
use_torch_compile=False,
moe_backend="auto",
moe_ep_size=1,
)
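# Minimal programmatic usage sketch (all names are defined in this module;
# kept as comments so nothing runs at import time):
#   cfg = RLMConfig.scale_1b()
#   print(estimate_parameters(cfg), estimate_active_parameters(cfg))
#   model = RubiRLM(cfg)
#   loss = model.training_loss(input_ids, target_ids)  # LongTensor [batch, seq] pairs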
class RMSNorm(nn.Module):
def __init__(self, d_model: int, eps: float = 1e-6):
super().__init__()
self.scale = nn.Parameter(torch.ones(d_model))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
return self.scale * (x / rms)
class DenseFeedForward(nn.Module):
def __init__(self, cfg: RLMConfig):
super().__init__()
hidden = cfg.d_model * cfg.ff_mult
self.up_proj = build_linear(cfg.d_model, hidden, ternary=cfg.use_ternary_weights)
self.down_proj = build_linear(hidden, cfg.d_model, ternary=cfg.use_ternary_weights)
self.dropout = nn.Dropout(cfg.dropout)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return self.dropout(self.down_proj(F.gelu(self.up_proj(x)))), x.new_zeros(())
class FastSelfAttention(nn.Module):
def __init__(self, cfg: RLMConfig):
super().__init__()
if cfg.d_model % cfg.n_heads != 0:
raise ValueError("d_model must be divisible by n_heads.")
self.n_heads = cfg.n_heads
self.head_dim = cfg.d_model // cfg.n_heads
self.dropout = cfg.dropout
self.use_flash_attention = cfg.use_flash_attention
self.q_proj = build_linear(cfg.d_model, cfg.d_model, bias=False, ternary=cfg.use_ternary_weights)
self.k_proj = build_linear(cfg.d_model, cfg.d_model, bias=False, ternary=cfg.use_ternary_weights)
self.v_proj = build_linear(cfg.d_model, cfg.d_model, bias=False, ternary=cfg.use_ternary_weights)
self.out_proj = build_linear(cfg.d_model, cfg.d_model, bias=False, ternary=cfg.use_ternary_weights)
    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Note: attn_mask is accepted for interface compatibility but is not used here;
        # causal_scaled_dot_product_attention applies its own causal mask.
bsz, seq_len, _ = x.shape
q = self.q_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
attn_out = causal_scaled_dot_product_attention(
q,
k,
v,
dropout_p=self.dropout,
training=self.training,
)
attn_out = attn_out.transpose(1, 2).contiguous().view(bsz, seq_len, self.n_heads * self.head_dim)
return self.out_proj(attn_out)
class MoEExpert(nn.Module):
def __init__(self, d_model: int, hidden: int):
super().__init__()
self.up_proj = build_linear(d_model, hidden, ternary=True)
self.down_proj = build_linear(hidden, d_model, ternary=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(F.gelu(self.up_proj(x)))
class MoEFeedForward(nn.Module):
def __init__(self, cfg: RLMConfig):
super().__init__()
if cfg.moe_num_experts <= 0:
raise ValueError("moe_num_experts must be positive when use_moe=True.")
if cfg.moe_top_k <= 0 or cfg.moe_top_k > cfg.moe_num_experts:
raise ValueError("moe_top_k must be in the range [1, moe_num_experts].")
self.num_experts = cfg.moe_num_experts
self.top_k = cfg.moe_top_k
self.router_jitter = cfg.moe_router_jitter
        requested_backend = cfg.moe_backend.lower()
        if requested_backend == "native":
            self.backend = "native"
        else:
            self.backend = choose_moe_backend(prefer_deepspeed=requested_backend in {"auto", "deepspeed"})
self.router = build_linear(cfg.d_model, cfg.moe_num_experts, ternary=cfg.use_ternary_weights)
self.experts = nn.ModuleList([MoEExpert(cfg.d_model, cfg.moe_expert_hidden) for _ in range(cfg.moe_num_experts)])
self.deepspeed_moe = None
if self.backend == "deepspeed":
self.deepspeed_moe = build_deepspeed_moe(
hidden_size=cfg.d_model,
expert=MoEExpert(cfg.d_model, cfg.moe_expert_hidden),
num_experts=cfg.moe_num_experts,
top_k=cfg.moe_top_k,
ep_size=cfg.moe_ep_size,
)
if self.deepspeed_moe is None:
self.backend = "native"
self.dropout = nn.Dropout(cfg.dropout)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if self.deepspeed_moe is not None:
out, aux_loss = self.deepspeed_moe(x)
return self.dropout(out), aux_loss
flat_x = x.reshape(-1, x.size(-1))
router_logits = self.router(flat_x)
if self.training and self.router_jitter > 0:
router_logits = router_logits + torch.randn_like(router_logits) * self.router_jitter
router_probs = F.softmax(router_logits, dim=-1)
topk_weights, topk_indices = torch.topk(router_probs, self.top_k, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
mixed = flat_x.new_zeros(flat_x.shape)
expert_load = router_probs.new_zeros(self.num_experts)
for expert_id, expert in enumerate(self.experts):
expert_mask = topk_indices == expert_id
if not expert_mask.any():
continue
token_indices, slot_indices = expert_mask.nonzero(as_tuple=True)
expert_inputs = flat_x.index_select(0, token_indices)
expert_outputs = expert(expert_inputs)
weights = topk_weights[token_indices, slot_indices].unsqueeze(-1)
mixed.index_add_(0, token_indices, expert_outputs * weights)
expert_load[expert_id] = float(token_indices.numel())
mixed = self.dropout(mixed.view_as(x))
importance = router_probs.mean(dim=0)
load = expert_load / max(1, flat_x.size(0) * self.top_k)
aux_loss = self.num_experts * torch.sum(importance * load)
return mixed, aux_loss
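# Note on the native-router auxiliary loss above: it follows the familiar
# load-balancing form aux = num_experts * sum_e(importance_e * load_e), where
# importance_e is the mean softmax probability assigned to expert e and load_e
# is the fraction of routed token slots it actually received. Because load is
# built from hard counts, gradients reach the router only through importance.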
class RecursiveBlock(nn.Module):
def __init__(self, cfg: RLMConfig):
super().__init__()
self.use_layer_skip = cfg.use_layer_skip
self.layer_skip_threshold = cfg.layer_skip_threshold
self.layer_skip_target = cfg.layer_skip_target
self.use_fused_ops = cfg.use_fused_ops
self.packed_execution = cfg.packed_execution
self.norm_attn = RMSNorm(cfg.d_model)
self.norm_ff = RMSNorm(cfg.d_model)
self.attn = FastSelfAttention(cfg)
self.ffn = MoEFeedForward(cfg) if cfg.use_moe else DenseFeedForward(cfg)
self.skip_router = build_linear(cfg.d_model, 1, ternary=cfg.use_ternary_weights) if cfg.use_layer_skip else None
self.state_fuse = build_linear(cfg.d_model * 2, cfg.d_model, ternary=cfg.use_ternary_weights)
self.state_update = build_linear(cfg.d_model, cfg.d_model, ternary=cfg.use_ternary_weights)
self.state_gate = build_linear(cfg.d_model * 2, cfg.d_model, ternary=cfg.use_ternary_weights)
def _run_core(
self,
x: torch.Tensor,
state: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
x_norm = self.norm_attn(x)
attn_out = self.attn(x_norm, attn_mask=attn_mask)
fuse_input = torch.cat([attn_out, state], dim=-1)
gate = torch.sigmoid(self.state_gate(fuse_input))
fused = self.state_fuse(fuse_input)
fused = gate * fused + (1.0 - gate) * state
if self.use_fused_ops:
x = fused_residual_add(x, fused)
else:
x = x + fused
ff_out, moe_aux_loss = self.ffn(self.norm_ff(x))
if self.use_fused_ops:
x = fused_residual_add(x, ff_out)
else:
x = x + ff_out
new_state = torch.tanh(self.state_update(x))
return x, new_state, moe_aux_loss
def forward(
self,
x: torch.Tensor,
state: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
exec_prob = x.new_ones((x.size(0),))
skip_aux_loss = x.new_zeros(())
if self.skip_router is None:
x, new_state, moe_aux_loss = self._run_core(x, state, attn_mask=attn_mask)
return x, new_state, moe_aux_loss, skip_aux_loss, exec_prob.mean()
router_input = x.mean(dim=1)
exec_prob = torch.sigmoid(self.skip_router(router_input)).squeeze(-1)
target = exec_prob.new_full(exec_prob.shape, self.layer_skip_target)
skip_aux_loss = F.mse_loss(exec_prob, target)
hard_gate = exec_prob >= self.layer_skip_threshold
if not torch.any(hard_gate):
return x, state, x.new_zeros(()), skip_aux_loss, exec_prob.mean()
if torch.all(hard_gate):
x_exec, state_exec, moe_aux_loss = self._run_core(x, state, attn_mask=attn_mask)
elif self.packed_execution:
active_indices = torch.nonzero(hard_gate, as_tuple=False).squeeze(-1)
x_active, state_active = pack_rows(active_indices, x, state)
x_active, state_active, moe_aux_loss = self._run_core(x_active, state_active, attn_mask=attn_mask)
x_exec = scatter_rows(x, active_indices, x_active)
state_exec = scatter_rows(state, active_indices, state_active)
else:
x_exec, state_exec, moe_aux_loss = self._run_core(x, state, attn_mask=attn_mask)
if self.training:
exec_gate = exec_prob + (hard_gate.to(exec_prob.dtype) - exec_prob).detach()
exec_scale = exec_gate.view(-1, 1, 1)
x_exec = x + exec_scale * (x_exec - x)
state_exec = state + exec_scale * (state_exec - state)
return x_exec, state_exec, moe_aux_loss, skip_aux_loss, exec_prob.mean()
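# Layer-skip gating above: at train time the hard per-sample gate is applied
# with a straight-through estimator (exec_prob + (hard_gate - exec_prob).detach()),
# so the forward pass uses the 0/1 decision while gradients flow to the skip
# router. At eval time the packed path restores skipped rows via scatter_rows;
# the dense fallback branch executes every row without re-applying the gate.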
class RubiRLM(nn.Module):
def __init__(self, cfg: RLMConfig):
super().__init__()
self.cfg = cfg
self._last_moe_aux_loss = torch.tensor(0.0)
self._last_layer_skip_aux_loss = torch.tensor(0.0)
self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
self.drop = nn.Dropout(cfg.dropout)
self.layers = nn.ModuleList([maybe_compile_module(RecursiveBlock(cfg), cfg.use_torch_compile) for _ in range(cfg.n_layers)])
self.final_norm = RMSNorm(cfg.d_model)
self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
if cfg.tie_embeddings:
self.lm_head.weight = self.tok_emb.weight
self.critique_head = nn.Sequential(
nn.Linear(cfg.d_model, cfg.d_model // 2),
nn.GELU(),
nn.Linear(cfg.d_model // 2, 1),
)
def _causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
mask = torch.full((seq_len, seq_len), float("-inf"), device=device)
return torch.triu(mask, diagonal=1)
def _embed(self, input_ids: torch.Tensor) -> torch.Tensor:
bsz, seq_len = input_ids.shape
if seq_len > self.cfg.max_seq_len:
raise ValueError(f"Girdi uzunluğu max_seq_len={self.cfg.max_seq_len} sınırını aşıyor.")
pos = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(bsz, seq_len)
return self.drop(self.tok_emb(input_ids) + self.pos_emb(pos))
def forward_recursive(
self,
input_ids: torch.Tensor,
steps: Optional[int] = None,
stop_on_critique: bool = True,
return_trace: bool = False,
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
steps = steps or self.cfg.recurse_steps
x = self._embed(input_ids)
bsz, seq_len, d_model = x.shape
states = [x.new_zeros((bsz, seq_len, d_model)) for _ in range(self.cfg.n_layers)]
mask = self._causal_mask(seq_len, x.device)
logits_trace: List[torch.Tensor] = []
critique_trace: List[torch.Tensor] = []
moe_aux_total = x.new_zeros(())
layer_skip_aux_total = x.new_zeros(())
for _ in range(steps):
h = x
new_states = []
for layer, st in zip(self.layers, states):
h, st_new, moe_aux, skip_aux, _ = layer(h, st, attn_mask=mask)
new_states.append(st_new)
moe_aux_total = moe_aux_total + moe_aux
layer_skip_aux_total = layer_skip_aux_total + skip_aux
states = new_states
h_norm = self.final_norm(h)
logits = self.lm_head(h_norm)
pooled = h_norm[:, -1, :]
critique = torch.sigmoid(self.critique_head(pooled)).squeeze(-1)
logits_trace.append(logits)
critique_trace.append(critique)
x = h
if stop_on_critique and torch.all(critique < self.cfg.critique_threshold):
break
denom = max(1, len(logits_trace) * len(self.layers))
self._last_moe_aux_loss = moe_aux_total / denom
self._last_layer_skip_aux_loss = layer_skip_aux_total / denom
final_logits = logits_trace[-1]
if return_trace:
return final_logits, logits_trace, critique_trace
return final_logits, [], critique_trace
def training_loss(
self,
input_ids: torch.Tensor,
target_ids: torch.Tensor,
steps: Optional[int] = None,
alpha_iterative: float = 0.30,
beta_correction: float = 0.10,
) -> torch.Tensor:
final_logits, trace, critique = self.forward_recursive(
input_ids, steps=steps, stop_on_critique=False, return_trace=True
)
final_loss = F.cross_entropy(
final_logits.view(-1, final_logits.size(-1)),
target_ids.view(-1),
ignore_index=-100,
)
if trace:
iterative = 0.0
for logits in trace[:-1]:
iterative = iterative + F.cross_entropy(
logits.view(-1, logits.size(-1)),
target_ids.view(-1),
ignore_index=-100,
)
iterative = iterative / max(1, len(trace) - 1)
else:
iterative = final_loss.new_tensor(0.0)
correction_bonus = 0.0
if len(critique) > 1:
start = critique[0].mean()
end = critique[-1].mean()
correction_bonus = torch.relu(end - start)
total_loss = final_loss + alpha_iterative * iterative + beta_correction * correction_bonus
if self.cfg.use_moe:
total_loss = total_loss + self.cfg.moe_aux_loss_weight * self._last_moe_aux_loss
if self.cfg.use_layer_skip:
total_loss = total_loss + self.cfg.layer_skip_aux_weight * self._last_layer_skip_aux_loss
return total_loss
@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
max_new_tokens: int = 64,
temperature: float = 0.8,
top_k: int = 50,
steps: Optional[int] = None,
) -> torch.Tensor:
self.eval()
out = input_ids
for _ in range(max_new_tokens):
context = out[:, -self.cfg.max_seq_len :]
logits, _, _ = self.forward_recursive(context, steps=steps, stop_on_critique=True, return_trace=False)
next_logits = logits[:, -1, :] / max(temperature, 1e-5)
if top_k > 0:
values, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
cutoff = values[:, [-1]]
next_logits = torch.where(next_logits < cutoff, torch.full_like(next_logits, -1e9), next_logits)
probs = F.softmax(next_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
out = torch.cat([out, next_token], dim=1)
return out
def generate_text(
self,
tokenizer: TextTokenizer,
prompt: str,
max_new_tokens: int = 128,
temperature: float = 0.7,
top_k: int = 50,
steps: Optional[int] = None,
device: Optional[torch.device] = None,
) -> str:
device = device or next(self.parameters()).device
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
output_ids = self.generate(
input_ids,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
steps=steps,
)
new_tokens = output_ids[0, input_ids.shape[1] :].tolist()
return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
def chat(
self,
tokenizer: TextTokenizer,
history: List[ChatTurn],
user_message: str,
lang: str = "auto",
max_new_tokens: int = 192,
temperature: float = 0.7,
top_k: int = 50,
steps: Optional[int] = None,
device: Optional[torch.device] = None,
) -> Tuple[str, List[ChatTurn]]:
prompt = build_chat_prompt(history, user_message, lang=lang)
assistant_reply = self.generate_text(
tokenizer=tokenizer,
prompt=prompt,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
steps=steps,
device=device,
)
updated = history + [ChatTurn(role="user", content=user_message), ChatTurn(role="assistant", content=assistant_reply)]
return assistant_reply, updated
def outer_sleep_phase_step(
self,
optimizer: torch.optim.Optimizer,
input_ids: torch.Tensor,
target_ids: torch.Tensor,
steps: Optional[int] = None,
) -> float:
self.train()
optimizer.zero_grad(set_to_none=True)
loss = self.training_loss(input_ids, target_ids, steps=steps)
loss.backward()
nn.utils.clip_grad_norm_(self.parameters(), 1.0)
optimizer.step()
return float(loss.detach().item())
def estimate_parameters(cfg: RLMConfig) -> int:
d = cfg.d_model
total = cfg.vocab_size * d + cfg.max_seq_len * d
attn_params = (4 * d * d) + (4 * d)
state_params = (5 * d * d) + (3 * d)
router_params = 0
layer_skip_params = 0
ff_params = (2 * d * d * cfg.ff_mult) + (d * cfg.ff_mult) + d
if cfg.use_moe:
router_params = (d * cfg.moe_num_experts) + cfg.moe_num_experts
expert_params = (2 * d * cfg.moe_expert_hidden) + cfg.moe_expert_hidden + d
ff_params = cfg.moe_num_experts * expert_params
if cfg.use_layer_skip:
layer_skip_params = d + 1
per_layer = attn_params + state_params + router_params + layer_skip_params + ff_params + (2 * d)
total += cfg.n_layers * per_layer
total += d * (d // 2) + (d // 2) + (d // 2) + 1 + d
if not cfg.tie_embeddings:
total += d * cfg.vocab_size
return total
def estimate_active_parameters(cfg: RLMConfig) -> int:
d = cfg.d_model
total = cfg.vocab_size * d + cfg.max_seq_len * d
attn_params = (4 * d * d) + (4 * d)
state_params = (5 * d * d) + (3 * d)
router_params = 0
layer_skip_params = 0
ff_params = (2 * d * d * cfg.ff_mult) + (d * cfg.ff_mult) + d
if cfg.use_moe:
router_params = (d * cfg.moe_num_experts) + cfg.moe_num_experts
expert_params = (2 * d * cfg.moe_expert_hidden) + cfg.moe_expert_hidden + d
ff_params = cfg.moe_top_k * expert_params
if cfg.use_layer_skip:
layer_skip_params = d + 1
routed_layer = attn_params + state_params + router_params + ff_params + (2 * d)
routed_layer = cfg.layer_skip_target * routed_layer
per_layer = layer_skip_params + routed_layer
total += cfg.n_layers * per_layer
total += d * (d // 2) + (d // 2) + (d // 2) + 1 + d
if not cfg.tie_embeddings:
total += d * cfg.vocab_size
return int(total)
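# Both estimators count: embeddings (token + positional), per-layer attention,
# state-fusion and FFN/MoE weights, the optional skip router, the two per-layer
# RMSNorm scales, the critique head and final norm, plus an untied lm_head when
# tie_embeddings is off. The "active" variant keeps only top_k experts per
# token and scales the routed layer cost by layer_skip_target.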
def language_system_prompt(lang: str) -> str:
    base = (
        "You are the Rubi-RLM assistant. Reason step-by-step internally, keep the final answer concise, "
        "and self-correct if needed."
    )
if lang == "tr":
return base + " Yanıtlarını Türkçe ver."
if lang == "en":
return base + " Reply in English."
return base + " Reply in the user's language (Turkish or English)."
def build_chat_prompt(history: List[ChatTurn], user_message: str, lang: str = "auto") -> str:
lines = [f"<|system|>\n{language_system_prompt(lang)}"]
for turn in history:
role = "user" if turn.role.lower() == "user" else "assistant"
lines.append(f"<|{role}|>\n{turn.content}")
lines.append(f"\n{user_message}")
lines.append("<|assistant|>\n")
return "\n".join(lines)
def load_hf_tokenizer(tokenizer_name: str):
if importlib.util.find_spec("transformers") is None:
raise RuntimeError("transformers yüklü değil. `pip install transformers` ile kurun.")
transformers = importlib.import_module("transformers")
tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name)
if getattr(tokenizer, "pad_token_id", None) is None and getattr(tokenizer, "eos_token", None) is not None:
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
def demo() -> None:
cfg = RLMConfig(
vocab_size=4096,
max_seq_len=128,
d_model=256,
n_layers=4,
n_heads=8,
ff_mult=4,
recurse_steps=4,
use_moe=True,
moe_num_experts=8,
moe_top_k=2,
moe_expert_hidden=384,
)
model = RubiRLM(cfg)
x = torch.randint(0, cfg.vocab_size, (2, 32))
y = torch.randint(0, cfg.vocab_size, (2, 32))
loss = model.training_loss(x, y)
print(f"demo_loss={loss.item():.4f}")
out = model.generate(x[:, :8], max_new_tokens=8, steps=3)
print("generated_shape=", tuple(out.shape))
def resolve_config(scale: str) -> RLMConfig:
if scale == "1b":
return RLMConfig.scale_1b()
return RLMConfig(d_model=512, n_layers=8, n_heads=8, vocab_size=50_257, max_seq_len=512)
def runtime_torch_compile_available() -> bool:
if not hasattr(torch, "compile"):
return False
if torch.cuda.is_available() and importlib.util.find_spec("triton") is None:
return False
return True
def apply_runtime_config_overrides(cfg: RLMConfig, args: argparse.Namespace) -> RLMConfig:
cfg.moe_backend = getattr(args, "moe_backend", cfg.moe_backend)
cfg.moe_ep_size = getattr(args, "moe_ep_size", cfg.moe_ep_size)
requested_compile = bool(getattr(args, "use_torch_compile", cfg.use_torch_compile))
cfg.use_torch_compile = requested_compile and runtime_torch_compile_available()
return cfg
def maybe_load_checkpoint(model: RubiRLM, checkpoint: Optional[str], device: torch.device) -> None:
if not checkpoint:
return
state = torch.load(checkpoint, map_location=device)
if isinstance(state, dict) and "model_state_dict" in state:
model.load_state_dict(state["model_state_dict"])
return
model.load_state_dict(state)
def run_single_chat(args: argparse.Namespace) -> None:
cfg = apply_runtime_config_overrides(resolve_config(args.scale), args)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RubiRLM(cfg).to(device)
maybe_load_checkpoint(model, args.checkpoint, device)
tokenizer = load_hf_tokenizer(args.tokenizer_name)
history: List[ChatTurn] = []
if args.interactive:
print("Interactive chat başladı. Çıkmak için /exit yaz.")
while True:
user_msg = input("You> ").strip()
if not user_msg:
continue
if user_msg.lower() in {"/exit", "exit", "quit"}:
break
reply, history = model.chat(
tokenizer=tokenizer,
history=history,
user_message=user_msg,
lang=args.lang,
max_new_tokens=args.max_new_tokens,
temperature=args.temperature,
top_k=args.top_k,
steps=args.steps,
device=device,
)
print(f"Rubi> {reply}")
return
if not args.prompt:
raise ValueError("--chat modunda --prompt veya --interactive gerekli.")
reply, _ = model.chat(
tokenizer=tokenizer,
history=[],
user_message=args.prompt,
lang=args.lang,
max_new_tokens=args.max_new_tokens,
temperature=args.temperature,
top_k=args.top_k,
steps=args.steps,
device=device,
)
print(reply)
def print_stack_report() -> None:
report = detect_xqs_backends()
print(format_backend_report(report))
def run_train_demo(args: argparse.Namespace) -> None:
cfg = apply_runtime_config_overrides(resolve_config(args.scale), args)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RubiRLM(cfg).to(device)
maybe_load_checkpoint(model, args.checkpoint, device)
train_cfg = TrainStackConfig(
optimizer_name=args.optimizer_name,
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=not args.disable_pin_memory,
prefetch_factor=args.prefetch_factor,
persistent_workers=not args.disable_persistent_workers,
max_seq_len=cfg.max_seq_len,
dataset_dir=args.dataset_dir,
use_bf16=not args.disable_bf16,
)
dataset = build_dataset(
dataset_dir=train_cfg.dataset_dir,
vocab_size=cfg.vocab_size,
max_seq_len=min(cfg.max_seq_len, args.train_seq_len),
synthetic_samples=max(args.train_steps * args.batch_size * 2, 32),
)
dataloader = build_dataloader(dataset, train_cfg, shuffle=True)
optimizer = build_optimizer(model, train_cfg)
mean_loss, total_tokens = train_demo_steps(
model=model,
optimizer=optimizer,
dataloader=dataloader,
device=device,
steps=args.train_steps,
use_bf16=train_cfg.use_bf16,
)
print(
f"train_demo optimizer={optimizer.__class__.__name__} steps={args.train_steps} "
f"mean_loss={mean_loss:.4f} tokens={total_tokens:,} device={device}"
)
def main() -> None:
parser = argparse.ArgumentParser(description="Rubi-RLM recursive language model")
parser.add_argument("--scale", choices=["1b", "tiny"], default="1b")
parser.add_argument("--estimate-only", action="store_true")
parser.add_argument("--demo", action="store_true")
parser.add_argument("--train-demo", action="store_true")
parser.add_argument("--stack-report", action="store_true")
parser.add_argument("--chat", action="store_true", help="Türkçe/İngilizce sohbet modunu açar")
parser.add_argument("--interactive", action="store_true", help="Interactive chat loop")
parser.add_argument("--prompt", type=str, default="")
parser.add_argument("--lang", choices=["auto", "tr", "en"], default="auto")
parser.add_argument("--tokenizer-name", type=str, default="gpt2")
parser.add_argument("--checkpoint", type=str, default=None)
parser.add_argument("--steps", type=int, default=None)
parser.add_argument("--max-new-tokens", type=int, default=192)
parser.add_argument("--temperature", type=float, default=0.7)
parser.add_argument("--top-k", type=int, default=50)
parser.add_argument("--optimizer-name", type=str, default="auto")
parser.add_argument("--moe-backend", choices=["auto", "native", "deepspeed"], default="auto")
parser.add_argument("--moe-ep-size", type=int, default=1)
parser.add_argument("--use-torch-compile", action="store_true")
parser.add_argument("--learning-rate", type=float, default=3e-4)
parser.add_argument("--weight-decay", type=float, default=0.01)
parser.add_argument("--batch-size", type=int, default=2)
parser.add_argument("--num-workers", type=int, default=2)
parser.add_argument("--prefetch-factor", type=int, default=4)
parser.add_argument("--dataset-dir", type=str, default="")
parser.add_argument("--train-steps", type=int, default=2)
parser.add_argument("--train-seq-len", type=int, default=256)
parser.add_argument("--disable-pin-memory", action="store_true")
parser.add_argument("--disable-persistent-workers", action="store_true")
parser.add_argument("--disable-bf16", action="store_true")
args = parser.parse_args()
if args.chat:
run_single_chat(args)
return
if args.stack_report:
print_stack_report()
return
if args.train_demo:
run_train_demo(args)
return
if args.demo:
demo()
return
cfg = apply_runtime_config_overrides(resolve_config(args.scale), args)
n_params = estimate_parameters(cfg)
active_params = estimate_active_parameters(cfg)
print(f"Scale={args.scale}, estimated_params={n_params:,}, estimated_active_params={active_params:,}")
if not args.estimate_only:
model = RubiRLM(cfg)
actual = sum(p.numel() for p in model.parameters())
print(f"actual_params={actual:,}")
if __name__ == "__main__":
main()