| | import torch |
| | import tiktoken |
| | from model import ismail, ModelArgs |
| | from data import TurkishTokenizerWrapper, TURKISH_TOKENIZER_AVAILABLE |
| |
|
| |
|
| | |
| | |
| | |
| |
|
def generate_text_simple(model, idx, max_new_tokens, context_size):
    """
    Generate text using simple greedy decoding (argmax).

    Args:
        model: The transformer model
        idx: Input token indices of shape (batch_size, seq_len)
        max_new_tokens: Number of new tokens to generate
        context_size: Maximum context size the model can handle

    Returns:
        Generated token indices of shape (batch_size, seq_len + max_new_tokens)
    """
    for _ in range(max_new_tokens):
        # Keep only the most recent tokens that fit the model's context window.
        window = idx[:, -context_size:]

        # Inference only — no gradient tracking needed.
        with torch.no_grad():
            logits = model(window)

        # Only the distribution at the final position predicts the next token.
        last_logits = logits[:, -1, :]

        # Greedy choice: pick the single most likely token.
        next_token = torch.argmax(last_logits, dim=-1, keepdim=True)

        # Append and continue autoregressively.
        idx = torch.cat((idx, next_token), dim=1)

    return idx
| |
|
| |
|
def generate_text_with_sampling(model, idx, max_new_tokens, context_size, temperature=1.0, top_k=None):
    """
    Generate text using sampling with temperature and optional top-k filtering.

    Args:
        model: The transformer model
        idx: Input token indices of shape (batch_size, seq_len)
        max_new_tokens: Number of new tokens to generate
        context_size: Maximum context size the model can handle
        temperature: Sampling temperature (higher = more random, lower = more deterministic)
        top_k: If set, only sample from the top k most likely tokens

    Returns:
        Generated token indices of shape (batch_size, seq_len + max_new_tokens)
    """
    for _ in range(max_new_tokens):
        # Crop the running sequence to the model's context window.
        window = idx[:, -context_size:]

        with torch.no_grad():
            logits = model(window)

        # Next-token distribution lives at the last position.
        logits = logits[:, -1, :]

        # Clamp temperature away from zero to avoid division by zero,
        # then sharpen/flatten the distribution.
        safe_temp = max(temperature, 1e-8)
        logits = logits / safe_temp

        # Top-k filtering: everything below the k-th largest logit is
        # masked to -inf so softmax assigns it zero probability.
        if top_k is not None:
            k = min(top_k, logits.size(-1))
            top_values, _ = torch.topk(logits, k)
            cutoff = top_values[:, [-1]]
            logits[logits < cutoff] = -float('Inf')

        # Softmax in float32 for numerical stability.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)

        # Defensive fallback: degenerate distributions become uniform.
        if torch.isnan(probs).any() or torch.isinf(probs).any():
            probs = torch.ones_like(probs) / probs.size(-1)

        # Renormalize so multinomial receives a proper distribution.
        probs = probs / probs.sum(dim=-1, keepdim=True)

        # Draw one token per batch row from the distribution.
        next_token = torch.multinomial(probs, num_samples=1)

        idx = torch.cat((idx, next_token), dim=1)

    return idx
| |
|
| |
|
def text_to_token_ids(text, tokenizer):
    """
    Convert text to token IDs.

    Args:
        text: Input text string
        tokenizer: Tokenizer instance (tiktoken or TurkishTokenizerWrapper)

    Returns:
        Tensor of token IDs with shape (1, seq_len)
    """
    # The Turkish wrapper takes no special-token arguments; tiktoken needs
    # <|endoftext|> explicitly allowed so prompts containing it still encode.
    if isinstance(tokenizer, TurkishTokenizerWrapper):
        ids = tokenizer.encode(text)
    else:
        ids = tokenizer.encode(text, allowed_special={"<|endoftext|>"})

    # Add a leading batch dimension: (seq_len,) -> (1, seq_len).
    return torch.tensor(ids).unsqueeze(0)
| |
|
| |
|
def token_ids_to_text(token_ids, tokenizer):
    """
    Convert token IDs to text.

    Args:
        token_ids: Tensor of token IDs, can be 1D or 2D
        tokenizer: Tokenizer instance (tiktoken or TurkishTokenizerWrapper)

    Returns:
        Decoded text string
    """
    # Drop the leading batch dimension if present: (1, seq_len) -> (seq_len,).
    flat = token_ids.squeeze(0) if token_ids.dim() == 2 else token_ids
    return tokenizer.decode(flat.tolist())
| |
|
| |
|
def get_tokenizer(use_turkish=False, tokenizer_name="gpt2"):
    """
    Get the appropriate tokenizer based on user preference.

    Args:
        use_turkish: Whether to use Turkish tokenizer
        tokenizer_name: Name of tiktoken tokenizer to use if not using Turkish

    Returns:
        Tokenizer instance (TurkishTokenizerWrapper or tiktoken tokenizer)
    """
    # Default path: a tiktoken encoding looked up by name.
    if not use_turkish:
        tokenizer = tiktoken.get_encoding(tokenizer_name)
        print(f"📚 Using tiktoken tokenizer: {tokenizer_name} (vocab size: {tokenizer.n_vocab:,})")
        return tokenizer

    # Turkish path: fail loudly if the optional dependency is missing.
    if not TURKISH_TOKENIZER_AVAILABLE:
        raise ImportError(
            "Turkish tokenizer requested but not available. "
            "Install it with: pip install turkish-tokenizer"
        )
    tokenizer = TurkishTokenizerWrapper()
    print(f"🇹🇷 Using Turkish Tokenizer (vocab size: {tokenizer.n_vocab:,})")
    return tokenizer
| |
|
| |
|
| | |
| | |
| | |
| |
|
def load_checkpoint(model, checkpoint_path):
    """
    Load a trained checkpoint into the model.

    Args:
        model: The model instance
        checkpoint_path: Path to the checkpoint file (.pt)

    Returns:
        The loaded checkpoint dictionary with metadata
    """
    print(f"\n📦 Loading checkpoint: {checkpoint_path}")
    # map_location='cpu' so GPU-trained checkpoints load on CPU-only hosts.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    if 'model_state_dict' in checkpoint:
        # Full training checkpoint: weights plus optional metadata.
        model.load_state_dict(checkpoint['model_state_dict'])
        print("✅ Loaded model state from checkpoint")
        if 'step' in checkpoint:
            print(f" Training step: {checkpoint['step']:,}")
        if 'loss' in checkpoint:
            print(f" Loss: {checkpoint['loss']:.4f}")
    else:
        # The file is a bare state dict saved directly.
        model.load_state_dict(checkpoint)
        print("✅ Loaded model state (direct)")

    return checkpoint
| |
|
| |
|
if __name__ == "__main__":
    import json
    from pathlib import Path
    import sys

    # Flip to False to demo with a tiktoken (GPT-2) tokenizer instead.
    USE_TURKISH_TOKENIZER = True

    # Path to a trained checkpoint (.pt). None means random initialization.
    CHECKPOINT_PATH = None

    # A command-line argument overrides the hard-coded checkpoint path.
    if len(sys.argv) > 1:
        CHECKPOINT_PATH = sys.argv[1]
        print(f"🔧 Using checkpoint from command line: {CHECKPOINT_PATH}")

    # Load model hyperparameters from config.json when it exists.
    config_path = Path("config.json")
    if config_path.exists():
        with open(config_path) as f:
            config = json.load(f)
        print(f"✅ Loaded config from {config_path}")
        args = ModelArgs(**config["model"])
    else:
        print("⚠️ config.json not found, using default ModelArgs")
        args = ModelArgs()

    # The config may also request the Turkish tokenizer by name.
    tokenizer_name = getattr(args, "tokenizer_name", "gpt2")
    use_turkish = (tokenizer_name.lower() == "turkish") or USE_TURKISH_TOKENIZER

    tokenizer = get_tokenizer(
        use_turkish=use_turkish,
        tokenizer_name="gpt2" if use_turkish else tokenizer_name
    )

    # Keep the model's vocab size in sync with the Turkish tokenizer.
    if use_turkish and isinstance(tokenizer, TurkishTokenizerWrapper):
        if args.vocab_size != tokenizer.n_vocab:
            print(f"⚠️ Config vocab_size ({args.vocab_size:,}) doesn't match tokenizer ({tokenizer.n_vocab:,})")
            args.vocab_size = tokenizer.n_vocab
            print(f"📊 Updated vocab_size to {args.vocab_size:,} for Turkish tokenizer")

    print("\n🚀 Initializing model...")
    torch.manual_seed(123)
    model = ismail(args)

    # Load trained weights when a checkpoint path was given and exists.
    if CHECKPOINT_PATH:
        checkpoint_file = Path(CHECKPOINT_PATH)
        if checkpoint_file.exists():
            load_checkpoint(model, checkpoint_file)
        else:
            print(f"❌ Checkpoint not found: {CHECKPOINT_PATH}")
            print(" Using random initialization instead")
    else:
        print("ℹ️ No checkpoint specified, using random initialization")

    model.eval()

    banner = "=" * 60

    # ---- Example 1: greedy decoding ------------------------------------
    print(f"\n{banner}")
    print("EXAMPLE 1: GREEDY GENERATION (ARGMAX)")
    print(f"{banner}")

    start_context = "Merhaba, ben" if USE_TURKISH_TOKENIZER else "Hello, I am"
    print(f"\nInput: '{start_context}'")

    token_ids = text_to_token_ids(start_context, tokenizer)
    print(f"Token IDs shape: {token_ids.shape}")

    greedy_ids = generate_text_simple(
        model=model,
        idx=token_ids,
        max_new_tokens=20,
        context_size=args.max_seq_len
    )

    print(f"\nGenerated: '{token_ids_to_text(greedy_ids, tokenizer)}'")
    print(f"Total tokens: {greedy_ids.shape[1]}")

    # ---- Example 2: temperature sweep ----------------------------------
    print(f"\n{banner}")
    print("EXAMPLE 2: SAMPLING WITH TEMPERATURE")
    print(f"{banner}")

    start_context = "Bir varmış bir yokmuş" if USE_TURKISH_TOKENIZER else "Once upon a time"
    print(f"\nInput: '{start_context}'")

    token_ids = text_to_token_ids(start_context, tokenizer)

    # Same prompt at several temperatures to show diversity increasing.
    for temp in (0.5, 1.0, 1.5):
        print(f"\n--- Temperature: {temp} ---")
        sampled_ids = generate_text_with_sampling(
            model=model,
            idx=token_ids.clone(),
            max_new_tokens=20,
            context_size=args.max_seq_len,
            temperature=temp
        )
        print(f"Generated: '{token_ids_to_text(sampled_ids, tokenizer)}'")

    # ---- Example 3: top-k sampling -------------------------------------
    print(f"\n{banner}")
    print("EXAMPLE 3: TOP-K SAMPLING")
    print(f"{banner}")

    start_context = "Yapay zekanın geleceği" if USE_TURKISH_TOKENIZER else "The future of AI is"
    print(f"\nInput: '{start_context}'")

    token_ids = text_to_token_ids(start_context, tokenizer)

    topk_ids = generate_text_with_sampling(
        model=model,
        idx=token_ids,
        max_new_tokens=30,
        context_size=args.max_seq_len,
        temperature=0.8,
        top_k=50
    )

    print(f"Generated: '{token_ids_to_text(topk_ids, tokenizer)}'")

    print(f"\n{banner}")
    print("Generation examples completed!")
    print(f"{banner}\n")
|