|
|
| """
|
| CUDA-optimized inference script for Ursa Minor Smashed model
|
| Requires CUDA-capable GPU
|
| """
|
|
|
| import torch
|
| import torch.nn.functional as F
|
| import argparse
|
| import tiktoken
|
| from typing import Optional, List, Tuple
|
| import warnings
|
| warnings.filterwarnings('ignore')
|
|
|
|
|
class GPTConfig:
    """Hyperparameter container for the GPT model (GPT-2 124M defaults)."""

    # Default values correspond to the GPT-2 small architecture.
    _DEFAULTS = {
        'block_size': 1024,   # maximum sequence length
        'vocab_size': 50304,  # padded GPT-2 vocab (multiple of 64)
        'n_layer': 12,        # transformer blocks
        'n_head': 12,         # attention heads per block
        'n_embd': 768,        # embedding / residual width
    }

    def __init__(self, **kwargs):
        # Any field not supplied falls back to the GPT-2-small default;
        # unknown keyword arguments are silently ignored, as before.
        for field, default in self._DEFAULTS.items():
            setattr(self, field, kwargs.get(field, default))
|
|
|
class CausalSelfAttention(torch.nn.Module):
    """Multi-head causal self-attention backed by PyTorch's fused SDPA kernel."""

    def __init__(self, config):
        super().__init__()
        # Embedding width must split evenly across heads.
        assert config.n_embd % config.n_head == 0
        # One projection produces query, key and value in a single matmul.
        self.c_attn = torch.nn.Linear(config.n_embd, 3 * config.n_embd)
        # Output projection back into the residual stream.
        self.c_proj = torch.nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        # Project once, then carve the result into q / k / v.
        query, key, value = self.c_attn(x).split(self.n_embd, dim=2)

        def split_heads(t):
            # (B, T, C) -> (B, heads, T, head_dim) as SDPA expects.
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        query = split_heads(query)
        key = split_heads(key)
        value = split_heads(value)

        # Fused attention; the causal mask is applied inside the kernel.
        out = F.scaled_dot_product_attention(query, key, value, is_causal=True)

        # Re-merge heads and project back to the residual width.
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.c_proj(out)
|
|
|
class MLP(torch.nn.Module):
    """Position-wise feed-forward block: expand 4x, GELU, contract back."""

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = torch.nn.Linear(config.n_embd, hidden)
        # The tanh approximation matches the original GPT-2 GELU.
        self.gelu = torch.nn.GELU(approximate='tanh')
        self.c_proj = torch.nn.Linear(hidden, config.n_embd)

    def forward(self, x):
        # Expand -> nonlinearity -> contract, as one pipeline.
        return self.c_proj(self.gelu(self.c_fc(x)))
|
|
|
class Block(torch.nn.Module):
    """One pre-norm transformer layer: attention then MLP, each with a residual."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = torch.nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = torch.nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        # Pre-LayerNorm residual connections (GPT-2 style): normalize the
        # sub-layer input, add the sub-layer output back onto the stream.
        x = x + self.attn(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))
|
|
|
class GPT(torch.nn.Module):
    """GPT-2-style decoder-only transformer with tied input/output embeddings."""

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = torch.nn.ModuleDict(dict(
            wte=torch.nn.Embedding(config.vocab_size, config.n_embd),  # token embeddings
            wpe=torch.nn.Embedding(config.block_size, config.n_embd),  # learned positions
            h=torch.nn.ModuleList(Block(config) for _ in range(config.n_layer)),
            ln_f=torch.nn.LayerNorm(config.n_embd),                    # final norm
        ))
        self.lm_head = torch.nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: the token embedding and the output head share one matrix.
        self.transformer.wte.weight = self.lm_head.weight

    def forward(self, idx):
        batch, seq_len = idx.size()
        assert seq_len <= self.config.block_size, \
            f"Sequence length {seq_len} exceeds block size {self.config.block_size}"

        # Sum token and position embeddings (positions broadcast over batch).
        positions = torch.arange(0, seq_len, dtype=torch.long, device=idx.device)
        x = self.transformer.wte(idx) + self.transformer.wpe(positions)

        # Run the stack of transformer layers.
        for layer in self.transformer.h:
            x = layer(x)

        # Final norm, then project onto the vocabulary.
        x = self.transformer.ln_f(x)
        return self.lm_head(x)
|
|
|
def apply_repetition_penalty(logits: torch.Tensor, token_ids: List[int], penalty: float = 1.1) -> torch.Tensor:
    """Apply a CTRL-style repetition penalty to logits, in-place.

    For each distinct token in ``token_ids`` the corresponding logit is made
    *less* likely: positive logits are divided by ``penalty`` and negative
    logits are multiplied by it.  The original code divided unconditionally,
    which made negative logits less negative and so *rewarded* repetition of
    those tokens — this fixes that sign bug.

    Args:
        logits: Tensor of shape (1, vocab_size); mutated in-place.
        token_ids: Token ids generated so far.
        penalty: Factor > 1.0; larger means a stronger penalty.

    Returns:
        The (mutated) logits tensor.
    """
    for token_id in set(token_ids):
        score = logits[0, token_id]
        # Push the logit toward "less likely" regardless of its sign.
        logits[0, token_id] = score * penalty if score < 0 else score / penalty
    return logits
|
|
|
def top_k_top_p_filtering(logits: torch.Tensor, top_k: int = 50, top_p: float = 0.9):
    """Mask logits for top-k and/or top-p (nucleus) sampling.

    Tokens outside the k most likely, and tokens beyond the smallest prefix
    whose cumulative probability exceeds ``top_p``, get their logit set to
    -inf.  Operates in-place on ``logits`` and also returns it.
    """
    neg_inf = float('-inf')

    # Top-k: drop everything strictly below the k-th largest logit.
    if top_k > 0:
        k = min(top_k, logits.size(-1))
        kth_value = torch.topk(logits, k).values[:, [-1]]
        logits[logits < kth_value] = neg_inf

    # Top-p: keep the smallest prefix of the sorted distribution whose
    # cumulative probability exceeds the threshold.
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

        remove_mask = cum_probs > top_p
        # Shift right so the first token that crosses the threshold is kept,
        # and the most likely token always survives.
        remove_mask[..., 1:] = remove_mask[..., :-1].clone()
        remove_mask[..., 0] = 0

        # Map the mask back from sorted order to vocabulary order.
        indices_to_remove = remove_mask.scatter(1, sorted_indices, remove_mask)
        logits[indices_to_remove] = neg_inf

    return logits
|
|
|
def generate_direct(
    model: GPT,
    prompt: str,
    max_new_tokens: int = 100,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.9,
    repetition_penalty: float = 1.1
):
    """Generate text from ``prompt`` using the CUDA-resident model.

    Samples autoregressively with temperature scaling, a repetition
    penalty over the 20 most recent generated tokens, and top-k/top-p
    filtering; generation stops early at GPT-2's end-of-text token.

    Args:
        model: A GPT model already moved to CUDA.
        prompt: Text to condition on.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Softmax temperature (must be > 0).
        top_k: Keep only the k most likely tokens (0 disables).
        top_p: Nucleus sampling threshold (1.0 disables).
        repetition_penalty: > 1.0 discourages repeating recent tokens.

    Returns:
        The prompt concatenated with the decoded continuation.
    """
    device = "cuda"

    enc = tiktoken.get_encoding("gpt2")

    tokens = enc.encode(prompt)
    x = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)

    # Fix: crop an over-long prompt *before* the first forward pass;
    # previously the crop only happened at the end of the loop body, so a
    # prompt longer than block_size tripped the model's assert immediately.
    if x.size(1) > model.config.block_size:
        x = x[:, -model.config.block_size:]

    model.eval()
    generated_tokens = []

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # torch.cuda.amp.autocast(dtype=...) is deprecated; the
            # torch.autocast form is the supported replacement.
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = model(x)

            # Only the final position's logits matter for sampling.
            logits = logits[:, -1, :] / temperature

            # Penalize the 20 most recently generated tokens.
            if repetition_penalty > 1.0 and len(generated_tokens) > 0:
                logits = apply_repetition_penalty(logits, generated_tokens[-20:], repetition_penalty)

            filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

            # Sample the next token from the filtered distribution.
            probs = F.softmax(filtered_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            x = torch.cat([x, next_token], dim=1)
            generated_tokens.append(next_token.item())

            # Stop at GPT-2's end-of-text token.
            if next_token.item() == enc.eot_token:
                break

            # Slide the context window once it outgrows the block size.
            if x.size(1) > model.config.block_size:
                x = x[:, -model.config.block_size:]

    all_tokens = tokens + generated_tokens
    return enc.decode(all_tokens)
|
|
|
def load_model_direct(checkpoint_path: str):
    """Load a GPT model from a PyTorch checkpoint onto the GPU.

    The checkpoint may have been pickled with a reference to a
    ``train_gpt2.GPTConfig`` class, so a stub ``train_gpt2`` module is
    installed temporarily for unpickling and any pre-existing
    ``sys.modules['train_gpt2']`` entry is restored afterwards (the
    original code deleted it outright, clobbering a genuinely imported
    module).

    Args:
        checkpoint_path: Path to a ``.pt`` file with 'model' (state dict)
            and 'config' entries.

    Returns:
        The loaded model on CUDA, wrapped with torch.compile when available.

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. Use inference_cpu.py for CPU inference.")

    device = "cuda"
    print(f"Loading model from checkpoint: {checkpoint_path}")

    import sys
    import types

    # Stub module so pickled references to train_gpt2.GPTConfig resolve.
    train_gpt2_module = types.ModuleType('train_gpt2')

    class DummyGPTConfig:
        def __init__(self, **kwargs):
            for k, v in kwargs.items():
                setattr(self, k, v)

    train_gpt2_module.GPTConfig = DummyGPTConfig

    # Remember whatever was there before so it can be put back.
    previous_module = sys.modules.get('train_gpt2')
    sys.modules['train_gpt2'] = train_gpt2_module

    try:
        # SECURITY: weights_only=False lets unpickling execute arbitrary
        # code — only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    finally:
        # Restore the previous module (or remove the stub) even on failure.
        if previous_module is not None:
            sys.modules['train_gpt2'] = previous_module
        else:
            sys.modules.pop('train_gpt2', None)

    # The config entry may be a pickled object or a plain dict.
    config_obj = checkpoint['config']
    config_dict = vars(config_obj) if hasattr(config_obj, '__dict__') else config_obj

    config = GPTConfig(**config_dict)
    model = GPT(config)
    model.load_state_dict(checkpoint['model'])
    model.to(device)

    # torch.compile (PyTorch 2.x) speeds up repeated forward passes.
    model = torch.compile(model) if hasattr(torch, 'compile') else model

    return model
|
|
|
def main():
    """Command-line entry point: load a checkpoint and print a sample."""
    parser = argparse.ArgumentParser(
        description="Generate text with Ursa Minor Smashed model (CUDA)")

    # (flag, kwargs) pairs keep the CLI definition compact and scannable.
    cli_options = [
        ("--model", dict(type=str, default="model_optimized.pt",
                         help="Path to model checkpoint (.pt file)")),
        ("--prompt", dict(type=str, default="Hello, I'm a language model",
                          help="Input prompt")),
        ("--max-tokens", dict(type=int, default=100,
                              help="Maximum number of tokens to generate")),
        ("--temperature", dict(type=float, default=0.8,
                               help="Sampling temperature (0.1=conservative, 1.0=creative)")),
        ("--top-k", dict(type=int, default=50,
                         help="Top-k sampling (0=disabled)")),
        ("--top-p", dict(type=float, default=0.9,
                         help="Top-p (nucleus) sampling")),
        ("--repetition-penalty", dict(type=float, default=1.1,
                                      help="Repetition penalty (1.0=disabled)")),
    ]
    for flag, option_kwargs in cli_options:
        parser.add_argument(flag, **option_kwargs)

    args = parser.parse_args()

    model = load_model_direct(args.model)

    result = generate_direct(
        model,
        args.prompt,
        args.max_tokens,
        args.temperature,
        args.top_k,
        args.top_p,
        args.repetition_penalty,
    )

    divider = "-" * 50
    print("\nGenerated text:")
    print(divider)
    print(result)
    print(divider)


if __name__ == "__main__":
    main()