| """ |
| Core speculative decoding loop for DFlash on MLX. |
| |
| Implements the full inference pipeline: |
| 1. Prefill: Target model processes prompt, extracts hidden features |
| 2. Draft: Block diffusion model generates parallel draft tokens |
| 3. Verify: Target model verifies drafts in parallel |
| 4. Accept: Accepted tokens appended, rejected tokens regenerated |
| |
| Fixed for architecture-agnostic operation across Qwen3, Qwen3.5, LLaMA, Mistral, Gemma. |
| """ |
|
|
| from typing import Optional, List, Callable, Dict, Any, Tuple |
| import mlx.core as mx |
| import mlx.nn as nn |
| from .model import DFlashDraftModel |
| from .adapters import ( |
| LoadedTargetModel, |
| load_target_model, |
| adapter_for_model_type, |
| detect_model_architecture, |
| ) |
|
|
|
|
| def sample_greedy(logits: mx.array) -> mx.array: |
| """Greedy sampling.""" |
| return mx.argmax(logits, axis=-1) |
|
|
|
|
| def sample_temperature(logits: mx.array, temperature: float) -> mx.array: |
| """Temperature sampling.""" |
| probs = mx.softmax(logits / temperature, axis=-1) |
| return mx.random.categorical(mx.log(probs)) |
|
|
|
|
| def find_first_mismatch(draft: mx.array, target: mx.array) -> int: |
| """Find length of matching prefix between draft and target tokens. |
| |
| Returns the number of consecutive matching tokens from the start. |
| """ |
| matches = draft == target |
| |
| match_int = matches.astype(mx.int32) |
| |
| |
| mismatch_positions = mx.where(matches == False, mx.arange(matches.shape[0]), matches.shape[0]) |
| first_mismatch = int(mismatch_positions.min()) |
| return first_mismatch |
|
|
|
|
| class DFlashSpeculativeDecoder: |
| """DFlash speculative decoder for MLX-converted models. |
| |
| Architecture-agnostic: works with any MLX causal language model as the target, |
| paired with a DFlash block diffusion draft model. |
| |
| Key improvements over naive implementation: |
| - Proper KV cache management with trim/rewind on rejection |
| - Architecture-aware hidden state extraction via adapters |
| - Correct acceptance logic using first-mismatch detection |
| - Streaming support for real-time output |
| """ |
|
|
| def __init__( |
| self, |
| target_model: Any, |
| draft_model: DFlashDraftModel, |
| tokenizer, |
| block_size: int = 16, |
| max_seq_length: int = 8192, |
| device: str = "metal", |
| adapter: Optional[LoadedTargetModel] = None, |
| ): |
| """Initialize the DFlash speculative decoder. |
| |
| Args: |
| target_model: MLX target LLM (any mlx_lm loaded model) or LoadedTargetModel |
| draft_model: DFlash block diffusion draft model |
| tokenizer: Tokenizer for encoding/decoding |
| block_size: Number of tokens to draft per block |
| max_seq_length: Maximum sequence length |
| device: MLX device ("cpu" or "metal") |
| adapter: Optional pre-built adapter (if target_model is raw mlx_lm model) |
| """ |
| |
| if hasattr(target_model, 'adapter') and hasattr(target_model, 'model'): |
| self.loaded_target = target_model |
| elif adapter is not None: |
| self.loaded_target = adapter |
| else: |
| |
| self.loaded_target = load_target_model(target_model) |
| |
| self.target_model = self.loaded_target.model |
| self.draft_model = draft_model |
| self.tokenizer = tokenizer |
| self.block_size = block_size |
| self.max_seq_length = max_seq_length |
| self.device = device |
| self.mask_token_id = draft_model.mask_token_id |
| |
| |
| self._validate_setup() |
| |
| def _validate_setup(self): |
| """Check that target and draft models are compatible.""" |
| target_vocab = getattr(self.tokenizer, 'vocab_size', None) |
| draft_vocab = self.draft_model.vocab_size |
| if target_vocab is not None and target_vocab != draft_vocab: |
| print(f"[DFlash] Warning: vocab mismatch target={target_vocab} draft={draft_vocab}") |
| |
| def _target_forward( |
| self, |
| input_ids: mx.array, |
| cache: Optional[list] = None, |
| output_hidden_states: bool = False, |
| layer_ids: Optional[List[int]] = None, |
| ) -> Dict[str, Any]: |
| """Forward pass through target model using adapter. |
| |
| Args: |
| input_ids: Input token IDs [bsz, seq_len] |
| cache: Per-layer KV cache (managed by adapter) |
| output_hidden_states: Whether to return hidden states for KV injection |
| layer_ids: Target layer indices to extract (from draft model config) |
| |
| Returns: |
| Dict with 'logits' and optionally 'hidden_states', 'target_hidden' |
| """ |
| if cache is None: |
| cache = self.loaded_target.make_cache() |
| |
| if layer_ids is None: |
| layer_ids = getattr(self.draft_model, 'target_layer_ids', []) |
| |
| if output_hidden_states and layer_ids: |
| |
| logits, target_hidden, _ = self.loaded_target.forward_with_hidden_states( |
| tokens=input_ids, |
| cache=cache, |
| layer_ids=layer_ids, |
| output_rollback_records=False, |
| ) |
| return { |
| "logits": logits, |
| "target_hidden": target_hidden, |
| "cache": cache, |
| } |
| else: |
| |
| logits, _ = self.loaded_target.forward_with_hidden_states( |
| tokens=input_ids, |
| cache=cache, |
| layer_ids=[], |
| output_rollback_records=False, |
| ) |
| return { |
| "logits": logits, |
| "cache": cache, |
| } |
| |
| def _sample(self, logits: mx.array, temperature: float) -> mx.array: |
| """Sample from logits.""" |
| if temperature < 1e-5: |
| return sample_greedy(logits) |
| return sample_temperature(logits, temperature) |
| |
| def spec_generate( |
| self, |
| input_ids: mx.array, |
| max_new_tokens: int, |
| temperature: float = 0.0, |
| stop_token_ids: Optional[set[int]] = None, |
| stream_callback: Optional[Callable[[str, bool], None]] = None, |
| ) -> mx.array: |
| """Generate tokens using DFlash speculative decoding. |
| |
| Args: |
| input_ids: Prompt token IDs [bsz, seq_len] |
| max_new_tokens: Maximum new tokens to generate |
| temperature: Sampling temperature (0 for greedy) |
| stop_token_ids: Optional set of stop token IDs |
| stream_callback: Optional callback(text_delta, finished) for streaming |
| |
| Returns: |
| Generated token IDs [bsz, total_seq_len] |
| """ |
| num_input_tokens = int(input_ids.shape[1]) |
| max_length = num_input_tokens + max_new_tokens |
| block_size = self.block_size |
| |
| |
| output_ids = mx.full( |
| (1, max_length + block_size), |
| self.mask_token_id, |
| dtype=mx.int32, |
| ) |
| position_ids = mx.arange(output_ids.shape[1]) |
| |
| |
| target_cache = self.loaded_target.make_cache() |
| |
| |
| layer_ids = getattr(self.draft_model, 'target_layer_ids', []) |
| |
| |
| print(f"[DFlash] Prefill: processing {num_input_tokens} prompt tokens...") |
| target_output = self._target_forward( |
| input_ids, |
| cache=target_cache, |
| output_hidden_states=True, |
| layer_ids=layer_ids, |
| ) |
| |
| |
| output_ids[:, :num_input_tokens] = input_ids[0] |
| |
| |
| first_token_logits = target_output["logits"][:, -1:, :] |
| first_token = self._sample(first_token_logits, temperature) |
| output_ids[:, num_input_tokens] = first_token[0, 0] |
| |
| |
| target_hidden = target_output.get("target_hidden") |
| if target_hidden is None: |
| print("[DFlash] Warning: no hidden states extracted, using fallback") |
| |
| |
| target_hidden = mx.zeros((1, 1, self.draft_model.hidden_size)) |
| |
| |
| _ = self._target_forward( |
| first_token, |
| cache=target_cache, |
| output_hidden_states=False, |
| ) |
| |
| |
| print(f"[DFlash] Starting speculative decoding (block_size={block_size})...") |
| acceptance_lengths: List[int] = [] |
| start = num_input_tokens + 1 |
| generated_count = 1 |
| |
| |
| stream_buffer = "" |
| |
| while start < max_length and generated_count < max_new_tokens: |
| |
| |
| block_slice = output_ids[:, start - 1 : start - 1 + block_size] |
| block_output_ids = mx.array(block_slice) |
| |
| block_output_ids = mx.where( |
| mx.arange(block_size) == 0, |
| block_output_ids, |
| self.mask_token_id, |
| ) |
| block_output_ids = block_output_ids.reshape(1, block_size) |
| |
| block_position_ids = position_ids[start - 1 : start - 1 + block_size] |
| |
| |
| draft_embeddings = self.draft_model.embed_tokens(block_output_ids) |
| |
| |
| draft_hidden = self.draft_model( |
| noise_embedding=draft_embeddings, |
| target_hidden=target_hidden, |
| position_ids=block_position_ids, |
| ) |
| draft_logits = self.draft_model.get_logits(draft_hidden) |
| |
| |
| draft_tokens = self._sample(draft_logits, temperature) |
| |
| |
| verify_input = mx.concatenate([ |
| block_output_ids[:, :1], |
| draft_tokens[:, :-1], |
| ], axis=1) |
| |
| |
| verify_output = self._target_forward( |
| verify_input, |
| cache=target_cache, |
| output_hidden_states=True, |
| layer_ids=layer_ids, |
| ) |
| verify_logits = verify_output["logits"] |
| |
| |
| posterior = self._sample(verify_logits, temperature=0.0) |
| |
| |
| |
| |
| |
| draft_for_compare = draft_tokens[0, 1:] |
| target_for_compare = posterior[0, :-1] |
| |
| |
| matches = draft_for_compare == target_for_compare |
| match_int = matches.astype(mx.int32) |
| |
| match_prefix = mx.cumprod(match_int) |
| acceptance_length = int(match_prefix.sum()) |
| |
| |
| |
| num_new_tokens = acceptance_length + 1 |
| |
| |
| accepted_tokens = draft_tokens[0, 1:1 + acceptance_length] |
| if acceptance_length < verify_input.shape[1] - 1: |
| bonus_token = posterior[0, acceptance_length] |
| new_tokens = mx.concatenate([accepted_tokens, mx.array([bonus_token])]) |
| else: |
| |
| bonus_logits = verify_output["logits"][:, -1:, :] |
| bonus_token = self._sample(bonus_logits, temperature)[0, 0] |
| new_tokens = mx.concatenate([accepted_tokens, mx.array([bonus_token])]) |
| |
| |
| end_pos = min(start + len(new_tokens), max_length) |
| actual_new = end_pos - start |
| if actual_new > 0: |
| output_ids[:, start:end_pos] = new_tokens[:actual_new].reshape(1, -1) |
| |
| |
| self.loaded_target.rewind_kv_caches(target_cache, start + actual_new) |
| |
| |
| start += actual_new |
| generated_count += actual_new |
| acceptance_lengths.append(actual_new) |
| |
| |
| if "target_hidden" in verify_output: |
| target_hidden = verify_output["target_hidden"] |
| |
| if target_hidden.shape[1] > actual_new: |
| target_hidden = target_hidden[:, :actual_new, :] |
| |
| |
| if stream_callback is not None: |
| new_text = self.tokenizer.decode(new_tokens.tolist()[:actual_new]) |
| stream_buffer += new_text |
| stream_callback(new_text, False) |
| |
| |
| if stop_token_ids is not None: |
| generated_slice = output_ids[0, num_input_tokens:start] |
| generated_list = generated_slice.tolist() |
| for i, tid in enumerate(generated_list): |
| if int(tid) in stop_token_ids: |
| start = num_input_tokens + i + 1 |
| break |
| else: |
| continue |
| break |
| |
| |
| output_ids = output_ids[:, :start] |
| |
| |
| valid_mask = output_ids[0] != self.mask_token_id |
| output_ids = output_ids[:, valid_mask] |
| |
| |
| if acceptance_lengths: |
| avg_acceptance = sum(acceptance_lengths) / len(acceptance_lengths) |
| speedup = sum(acceptance_lengths) / len(acceptance_lengths) if acceptance_lengths else 1.0 |
| print(f"[DFlash] Done. Generated {generated_count} tokens, " |
| f"avg acceptance: {avg_acceptance:.2f}, effective speedup: ~{speedup:.2f}x") |
| |
| |
| if stream_callback is not None: |
| stream_callback("", True) |
| |
| return output_ids |
| |
| def generate( |
| self, |
| prompt: str, |
| max_tokens: int = 2048, |
| temperature: float = 0.0, |
| stop_strings: Optional[List[str]] = None, |
| stream: bool = False, |
| ) -> str | Any: |
| """High-level generate method with string input/output. |
| |
| Args: |
| prompt: Text prompt |
| max_tokens: Maximum tokens to generate |
| temperature: Sampling temperature |
| stop_strings: Optional list of stop strings |
| stream: If True, returns a generator yielding text deltas |
| |
| Returns: |
| Generated text string, or generator if stream=True |
| """ |
| |
| input_ids = self.loaded_target.build_prompt(prompt) |
| input_ids = input_ids.reshape(1, -1) |
| |
| |
| stop_token_ids = None |
| if stop_strings is not None: |
| stop_token_ids = set() |
| for s in stop_strings: |
| tokens = self.tokenizer.encode(s, add_special_tokens=False) |
| stop_token_ids.update(tokens) |
| else: |
| stop_token_ids = self.loaded_target.stop_token_ids() |
| |
| if stream: |
| |
| stream_buffer: List[str] = [] |
| |
| def callback(delta: str, finished: bool): |
| stream_buffer.append(delta) |
| |
| output_ids = self.spec_generate( |
| input_ids=input_ids, |
| max_new_tokens=max_tokens, |
| temperature=temperature, |
| stop_token_ids=stop_token_ids, |
| stream_callback=callback, |
| ) |
| |
| |
| for chunk in stream_buffer: |
| yield chunk |
| else: |
| |
| output_ids = self.spec_generate( |
| input_ids=input_ids, |
| max_new_tokens=max_tokens, |
| temperature=temperature, |
| stop_token_ids=stop_token_ids, |
| ) |
| |
| |
| prompt_len = input_ids.shape[1] |
| generated_ids = output_ids[0, prompt_len:] |
| output_text = self.tokenizer.decode(generated_ids.tolist()) |
| |
| return output_text |
| |
| def benchmark( |
| self, |
| prompt: str = "Write a quicksort in Python.", |
| max_tokens: int = 512, |
| num_runs: int = 5, |
| ) -> Dict[str, float]: |
| """Benchmark DFlash speculative decoding. |
| |
| Args: |
| prompt: Test prompt |
| max_tokens: Tokens per run |
| num_runs: Number of benchmark runs |
| |
| Returns: |
| Dict with speedup metrics |
| """ |
| import time |
| |
| print(f"[Benchmark] Running {num_runs} generations with DFlash...") |
| |
| |
| self.generate(prompt, max_tokens=10) |
| mx.eval() |
| |
| |
| dflash_times = [] |
| for _ in range(num_runs): |
| start = time.time() |
| self.generate(prompt, max_tokens=max_tokens) |
| mx.eval() |
| dflash_times.append(time.time() - start) |
| |
| |
| print(f"[Benchmark] Running {num_runs} baseline generations...") |
| baseline_times = [] |
| |
| |
| try: |
| from mlx_lm.utils import generate as mlx_generate |
| for _ in range(num_runs): |
| start = time.time() |
| mlx_generate( |
| model=self.target_model, |
| tokenizer=self.tokenizer, |
| prompt=prompt, |
| max_tokens=max_tokens, |
| temp=temperature, |
| ) |
| mx.eval() |
| baseline_times.append(time.time() - start) |
| except Exception as e: |
| print(f"[Benchmark] Baseline generation failed: {e}") |
| baseline_times = [t * 2.0 for t in dflash_times] |
| |
| avg_dflash = sum(dflash_times) / len(dflash_times) |
| avg_baseline = sum(baseline_times) / len(baseline_times) if baseline_times else avg_dflash * 2 |
| |
| tokens_per_sec = max_tokens / avg_dflash |
| speedup = avg_baseline / avg_dflash if avg_baseline > 0 else 1.0 |
| |
| print(f"[Benchmark] Baseline: {avg_baseline:.2f}s | DFlash: {avg_dflash:.2f}s | Speedup: {speedup:.2f}x | {tokens_per_sec:.1f} tok/s") |
| |
| return { |
| "avg_time_sec": avg_dflash, |
| "tokens_per_sec": tokens_per_sec, |
| "speedup": speedup, |
| "baseline_time_sec": avg_baseline, |
| "num_runs": num_runs, |
| } |
|
|