# diffusionGPT / infer-base.py
# thejagstudio's picture
# Upload 10 files
# 486838c verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer
from dataclasses import dataclass
import os
import math
# ============== Model Architecture ==============
class RMSNorm(nn.Module):
    """Root-mean-square layer normalisation with a learnable per-feature gain.

    The input is divided by the RMS of its last dimension (``eps`` is added to
    the mean square before the reciprocal square root) and then multiplied by
    ``weight``, which starts out as all ones.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Mean of squares over the feature (last) axis, kept for broadcasting.
        mean_square = torch.pow(x, 2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalised = x * inv_rms
        return self.weight * normalised
class RotaryEmbedding(nn.Module):
    """Build the cosine/sine tables used for rotary position embeddings (RoPE).

    The inverse frequencies are created lazily on the device of the first
    input. ``scaling_factor`` rescales the base as
    ``base * scaling_factor ** (dim / (dim - 2))``. Finished tables are
    memoised per ``(seq_len, device, dtype)``; at most 10 entries are kept and
    the oldest inserted entry is dropped first.

    Note: the tables are computed in the dtype of ``inv_freq``; the input's
    dtype only takes part in the cache key.
    """

    def __init__(self, dim, max_position_embeddings=16384, base=100000, scaling_factor=1.0):
        super().__init__()
        self.dim = dim
        self.base = base
        self.scaling_factor = scaling_factor
        self.max_position_embeddings = max_position_embeddings
        self.inv_freq = None
        self._cache = {}

    def _update_freqs(self, device):
        """(Re)compute ``inv_freq`` on ``device``."""
        scaled_base = self.base * (self.scaling_factor ** (self.dim / (self.dim - 2)))
        exponents = torch.arange(0, self.dim, 2).float().to(device) / self.dim
        self.inv_freq = 1.0 / (scaled_base ** exponents)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)``, each shaped ``(1, 1, seq_len, dim)``.

        Args:
            x: Tensor whose device selects where the tables live; when
                ``seq_len`` is None its second-to-last dimension is used.
            seq_len: Number of positions to generate.
        """
        if seq_len is None:
            seq_len = x.shape[-2]

        needs_refresh = self.inv_freq is None or self.inv_freq.device != x.device
        if needs_refresh:
            self._update_freqs(x.device)

        key = (seq_len, x.device, x.dtype)
        hit = self._cache.get(key)
        if hit is not None:
            return hit

        positions = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        # Outer product position x frequency -> rotation angle per pair.
        angles = positions[:, None] * self.inv_freq[None, :]
        table = torch.cat((angles, angles), dim=-1)
        cos = table.cos()[None, None, :, :]
        sin = table.sin()[None, None, :, :]

        self._cache[key] = (cos, sin)
        if len(self._cache) > 10:
            oldest_key = next(iter(self._cache))
            self._cache.pop(oldest_key)
        return cos, sin
def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate query and key tensors by the given rotary cos/sin tables.

    Each tensor ``t`` becomes ``t * cos + rotate_half(t) * sin`` where
    ``rotate_half`` swaps the two halves of the last dimension and negates
    the (originally) second half.

    Returns:
        Tuple ``(q_embed, k_embed)``.
    """

    def _rotated(t):
        half = t.shape[-1] // 2
        first, second = torch.split(t, [half, t.shape[-1] - half], dim=-1)
        swapped = torch.cat((-second, first), dim=-1)
        return (t * cos) + (swapped * sin)

    return _rotated(q), _rotated(k)
class DiffusionAttention(nn.Module):
    """Non-causal multi-head self-attention with grouped key/value heads.

    Queries use ``num_attention_heads`` heads while keys/values use
    ``num_key_value_heads`` heads that are repeated to match. Attention itself
    is computed with ``F.scaled_dot_product_attention`` (``is_causal=False``,
    no dropout). ``use_flash_attn`` is stored from the config but not read
    anywhere in this class.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        n_heads = config.num_attention_heads
        n_kv_heads = config.num_key_value_heads

        self.hidden_size = hidden
        self.num_heads = n_heads
        self.head_dim = hidden // n_heads
        self.num_key_value_heads = n_kv_heads
        self.num_key_value_groups = n_heads // n_kv_heads
        self.use_flash_attn = config.use_flash_attn

        q_width = n_heads * self.head_dim
        kv_width = n_kv_heads * self.head_dim
        self.q_proj = nn.Linear(hidden, q_width, bias=False)
        self.k_proj = nn.Linear(hidden, kv_width, bias=False)
        self.v_proj = nn.Linear(hidden, kv_width, bias=False)
        self.o_proj = nn.Linear(q_width, hidden, bias=False)

    def forward(self, hidden_states, freqs_cis, attention_mask=None, past_kv=None):
        """Attend over ``hidden_states`` (plus cached keys/values, if given).

        Args:
            hidden_states: Tensor of shape ``(batch, q_len, hidden_size)``.
            freqs_cis: ``(cos, sin)`` rotary tables; only the first ``q_len``
                positions along dim 2 are used.
            attention_mask: Optional ``(batch, key_len)`` tensor, 1 for keys to
                attend to and 0 for keys to ignore.
            past_kv: Optional ``(key, value)`` pair that is prepended along the
                sequence axis before attention.

        Returns:
            ``(output, (key, value))`` where the key/value pair is taken after
            concatenation with ``past_kv`` and before head repetition.
        """
        bsz, q_len, _ = hidden_states.size()

        def _heads(projection, n_heads):
            # (batch, seq, n_heads * head_dim) -> (batch, n_heads, seq, head_dim)
            projected = projection(hidden_states)
            return projected.view(bsz, q_len, n_heads, self.head_dim).transpose(1, 2)

        q = _heads(self.q_proj, self.num_heads)
        k = _heads(self.k_proj, self.num_key_value_heads)
        v = _heads(self.v_proj, self.num_key_value_heads)

        cos, sin = freqs_cis
        q, k = apply_rotary_pos_emb(q, k, cos[:, :, :q_len, :], sin[:, :, :q_len, :])

        if past_kv is not None:
            cached_k, cached_v = past_kv
            k = torch.cat([cached_k, k], dim=2)
            v = torch.cat([cached_v, v], dim=2)
        current_kv = (k, v)

        # Expand the shared key/value heads so every query head has a partner.
        groups = self.num_key_value_groups
        k = k.repeat_interleave(groups, dim=1)
        v = v.repeat_interleave(groups, dim=1)

        additive_mask = None
        if attention_mask is not None:
            keep = attention_mask[:, None, None, :].to(dtype=q.dtype)
            # 0 where a key is kept, most-negative finite value where ignored.
            additive_mask = (1.0 - keep) * torch.finfo(q.dtype).min

        attended = F.scaled_dot_product_attention(
            q, k, v, attn_mask=additive_mask, dropout_p=0.0, is_causal=False
        )
        merged = attended.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
        return self.o_proj(merged), current_kv
class MLP(nn.Module):
    """Bias-free gated feed-forward layer: ``down(silu(gate(x)) * up(x))``."""

    def __init__(self, config):
        super().__init__()
        hidden, inner = config.hidden_size, config.intermediate_size
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        value = self.up_proj(x)
        return self.down_proj(gate * value)
class BlockDiffusionBlock(nn.Module):
    """Pre-norm transformer layer: attention and MLP, each with a residual.

    ``use_activation_checkpointing`` is copied from the config but is not
    consulted by ``forward``, which always delegates directly to ``_forward``.
    """

    def __init__(self, config):
        super().__init__()
        self.self_attn = DiffusionAttention(config)
        self.mlp = MLP(config)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.use_activation_checkpointing = config.use_activation_checkpointing

    def forward(self, hidden_states, freqs_cis, attention_mask, past_kv):
        """Run the layer; returns ``(hidden_states, new_kv)``."""
        return self._forward(hidden_states, freqs_cis, attention_mask, past_kv)

    def _forward(self, hidden_states, freqs_cis, attention_mask, past_kv):
        # Attention sub-layer: normalise, attend, add back the input.
        attn_out, new_kv = self.self_attn(
            self.input_layernorm(hidden_states), freqs_cis, attention_mask, past_kv
        )
        after_attention = hidden_states + attn_out

        # Feed-forward sub-layer with its own norm and residual.
        mlp_out = self.mlp(self.post_attention_layernorm(after_attention))
        return after_attention + mlp_out, new_kv
@dataclass
class ModelConfig:
    """Model architecture configuration.

    An instance of this class is read from the checkpoint by ``load_model``
    and handed to ``DiffusionLLM``; field names and defaults therefore need to
    stay compatible with saved checkpoints.
    """
    # Size of the token embedding table and of the (weight-tied) output head.
    vocab_size: int = 151936
    # Width of the residual stream.
    hidden_size: int = 1024
    # Inner width of the gated MLP.
    intermediate_size: int = 2816
    # Number of stacked transformer blocks.
    num_hidden_layers: int = 16
    # Query heads; head_dim is hidden_size // num_attention_heads.
    num_attention_heads: int = 16
    # Key/value heads, repeated to match the query heads.
    num_key_value_heads: int = 4
    # Passed to RotaryEmbedding, which stores it without otherwise reading it.
    max_position_embeddings: int = 16384
    # Epsilon used by every RMSNorm.
    rms_norm_eps: float = 1e-6
    # NOTE(review): not read anywhere in this file; RotaryEmbedding is built
    # with its own default base.
    rope_theta: float = 100000.0
    # Padding token id (embedding padding_idx and batch padding value).
    pad_token_id: int = 0
    # Token id used to fill not-yet-generated positions of a block.
    mask_token_id: int = 1
    # Copied onto the attention module but not consulted by it.
    use_flash_attn: bool = True
    # Copied onto each block but not consulted by it.
    use_activation_checkpointing: bool = False
    # NOTE(review): the two dropout fields are not read in this file;
    # attention is called with dropout_p=0.0.
    attention_dropout: float = 0.0
    hidden_dropout: float = 0.0
class DiffusionLLM(nn.Module):
    """Complete diffusion language model.

    Token embeddings feed a stack of ``BlockDiffusionBlock`` layers, followed
    by a final RMSNorm and an output projection whose weight is tied to the
    embedding matrix.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        # An out-of-range pad id cannot be used as padding_idx.
        pad_idx = config.pad_token_id if config.pad_token_id < config.vocab_size else None
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=pad_idx)
        self.layers = nn.ModuleList([BlockDiffusionBlock(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # NOTE(review): config.rope_theta is not forwarded here, so the rotary
        # base is RotaryEmbedding's default. Left unchanged to stay consistent
        # with however existing checkpoints were trained -- confirm.
        self.rotary_emb = RotaryEmbedding(
            config.hidden_size // config.num_attention_heads,
            config.max_position_embeddings
        )
        # Weight tying: the output head shares the embedding matrix.
        self.lm_head.weight = self.embed_tokens.weight

    def forward(self, input_ids, attention_mask=None, past_key_values=None):
        """Compute logits for ``input_ids``.

        Args:
            input_ids: Integer tensor of shape ``(batch, seq_len)``.
            attention_mask: Optional ``(batch, key_len)`` tensor, 1 for real
                tokens and 0 for padding. When ``past_key_values`` is given it
                must cover the cached positions as well as the new ones.
            past_key_values: Optional per-layer list of ``(key, value)`` pairs
                (or ``None`` entries) as returned by a previous call.

        Returns:
            ``(logits, new_kvs)`` where ``logits`` has shape
            ``(batch, seq_len, vocab_size)`` and ``new_kvs`` holds one
            ``(key, value)`` pair per layer.
        """
        _, seqlen = input_ids.shape
        hidden_states = self.embed_tokens(input_ids)

        if past_key_values is None:
            past_key_values = [None] * len(self.layers)

        # The new tokens continue after whatever is already cached, so their
        # rotary positions must start at the cached length rather than at 0.
        first_cached = next((kv for kv in past_key_values if kv is not None), None)
        past_len = 0 if first_cached is None else first_cached[0].shape[2]

        cos, sin = self.rotary_emb(hidden_states, seq_len=past_len + seqlen)
        freqs_cis = (cos[:, :, past_len:, :], sin[:, :, past_len:, :])

        new_kvs = []
        for layer, layer_past in zip(self.layers, past_key_values):
            hidden_states, kv = layer(hidden_states, freqs_cis, attention_mask, layer_past)
            new_kvs.append(kv)

        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)
        return logits, new_kvs

    def get_num_params(self, trainable_only=True):
        """Count parameters; by default only those with ``requires_grad``."""
        return sum(
            p.numel()
            for p in self.parameters()
            if p.requires_grad or not trainable_only
        )
# ============== Inference Functions ==============
def load_model(model_path: str, device: str = 'cuda'):
    """Load a saved model (fp16 or fp32) for inference.

    The checkpoint is expected to be a mapping with a ``'config'`` entry
    (passed to ``DiffusionLLM``) and a ``'model_state'`` state dict. All
    weights are converted to float32 before loading, then the model is moved
    to ``device`` and put in eval mode.

    Args:
        model_path: Path to the checkpoint file.
        device: Target device for the returned model.

    Returns:
        Tuple ``(model, config)``.

    Raises:
        KeyError: If the checkpoint lacks ``'config'`` or ``'model_state'``.
    """
    print(f"Loading model from {model_path}...")

    # SECURITY: weights_only=False unpickles arbitrary Python objects (it is
    # needed here because the checkpoint stores a config object). Only load
    # checkpoints from a trusted source.
    # The file is mapped to CPU first: the state dict is copied into a
    # CPU-constructed model anyway, so mapping straight to the accelerator
    # would only hold a second, temporary copy of the weights there.
    checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)

    missing = [key for key in ('config', 'model_state') if key not in checkpoint]
    if missing:
        raise KeyError(f"Checkpoint {model_path} is missing required entries: {missing}")

    config = checkpoint['config']
    model = DiffusionLLM(config)

    state_dict = {name: tensor.float() for name, tensor in checkpoint['model_state'].items()}
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    num_params = model.get_num_params() / 1e6
    file_size = os.path.getsize(model_path) / 1e6
    print(f"✓ Model loaded: {num_params:.1f}M params from {file_size:.1f} MB file")
    return model, config
def visualize_diffusion_state(tokenizer, context_ids, mask_blocks, is_masked_list, config, clear=True, block_colors=None):
    """Visualize the current state of diffusion generation with multiple blocks.

    Prints the decoded context followed by every block. Positions that are
    still masked are drawn as gray ``██``; revealed tokens are decoded one by
    one and drawn in the colour assigned to their block.

    Args:
        tokenizer: Object with ``decode(ids, skip_special_tokens=...)``.
        context_ids: Token ids of the prompt plus finished blocks; row 0 is
            decoded.
        mask_blocks: Either a single block tensor (1, block_size) or list of block tensors
        is_masked_list: Either a single mask tensor (1, block_size) or list of mask tensors
        config: Unused; kept so existing callers keep working.
        clear: If True, clear the terminal before printing.
        block_colors: List of ANSI color codes for each block. If None or
            empty, uses defaults.
    """
    import os
    import sys

    # Default colors for different blocks (green, cyan, yellow, magenta)
    DEFAULT_COLORS = ['\033[92m', '\033[96m', '\033[93m', '\033[95m']
    MASK_COLOR = '\033[90m'  # Gray for masked tokens
    RESET = '\033[0m'

    # Normalize inputs to lists
    if not isinstance(mask_blocks, list):
        mask_blocks = [mask_blocks]
        is_masked_list = [is_masked_list]
    # An empty colour list would otherwise cause a modulo-by-zero below.
    if not block_colors:
        block_colors = DEFAULT_COLORS

    # Decode context (prompt + previously generated blocks) and replace newlines
    context_text = tokenizer.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')

    # Build visualization for all blocks
    rendered_blocks = []
    for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)):
        color = block_colors[block_idx % len(block_colors)]
        token_ids = mask_block[0].tolist()
        # Pull the flags to the host once instead of once per token.
        hidden_flags = is_masked[0].tolist()
        pieces = []
        for token_id, hidden in zip(token_ids, hidden_flags):
            if hidden:
                # Masked positions share one gray placeholder in every block.
                pieces.append(f'{MASK_COLOR}██{RESET}')
            else:
                # Decode individual token; use block color for revealed tokens
                token_text = tokenizer.decode([token_id], skip_special_tokens=False)
                pieces.append(f'{color}{token_text}{RESET}')
        rendered_blocks.append(''.join(pieces))

    # Blocks are concatenated directly, without a separator.
    blocks_combined = ''.join(rendered_blocks)

    # Clear entire terminal
    if clear:
        clear_cmd = 'cls' if os.name == 'nt' else 'clear'
        # os.system reports failure through its return value, not an
        # exception, so check it and fall back to ANSI "erase screen + home".
        if os.system(clear_cmd) != 0:
            sys.stdout.write('\033[2J\033[H')
            sys.stdout.flush()

    # Print legend for parallel blocks
    if len(mask_blocks) > 1:
        legend_parts = [
            f'{block_colors[i % len(block_colors)]}Block {i+1}{RESET}'
            for i in range(len(mask_blocks))
        ]
        print(f"Generating: {' | '.join(legend_parts)}\n")

    # Print the full context with colored blocks
    print(f"{context_text}{blocks_combined}", flush=True)
def demo_visualize_truncation():
    """Demo for visualize_diffusion_state without a full model.

    Renders a context that is longer than a typical terminal line plus one
    partially revealed block, eight times in a row, rolling the block's tokens
    between frames. Nothing is asserted; the output is meant to be inspected
    by eye (e.g. for duplicated lines when content wraps).
    """
    import time

    class MockTokenizer:
        """Minimal stand-in tokenizer mapping ids to single ASCII letters."""

        def __init__(self):
            # Map token id to token text (simple ASCII characters and spaces)
            self.vocab = {i: chr(65 + (i % 26)) for i in range(256)}
            self.vocab[32] = ' '
            self.eos_token = '\n'
            self.pad_token = ' '

        def decode(self, ids, skip_special_tokens=True):
            # ids can be tensor or list
            if isinstance(ids, torch.Tensor):
                ids = ids.tolist()
            if isinstance(ids, (list, tuple)):
                return ''.join(self.vocab.get(int(i) % 256, '?') for i in ids)
            return str(ids)

    tok = MockTokenizer()

    # Make the context longer than an assumed 80-column terminal.
    term_width = 80
    context_len = term_width + 40
    long_context_ids = torch.tensor([[i % 26 + 65 for i in range(context_len)]], dtype=torch.long)

    block_size = 32
    mask_block = torch.full((1, block_size), 32, dtype=torch.long)  # spaces
    is_masked = torch.ones(1, block_size, dtype=torch.bool)
    # Reveal every third position with a letter.
    for i in range(0, block_size, 3):
        is_masked[0, i] = False
        mask_block[0, i] = 65 + (i % 26)

    frame_delay = 0.08
    print('\nRunning demo: long prompt + block to test truncation\n')
    for frame in range(8):
        visualize_diffusion_state(tok, long_context_ids, [mask_block], [is_masked], ModelConfig(), clear=(frame > 0))
        # rotate some tokens to simulate diffusion
        mask_block = torch.roll(mask_block, shifts=1, dims=1)
        time.sleep(frame_delay)
    print('\n\nDemo completed.')
@torch.no_grad()
def generate_block_diffusion(
    model,
    tokenizer,
    prompt: str,
    steps: int = 16,
    block_size: int = 64,
    max_new_tokens: int = 256,
    device: str = 'cuda',
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.9,
    repetition_penalty: float = 1.2,
    no_repeat_ngram_size: int = 3,
    visualize: bool = False,
    parallel_blocks: int = 1,  # Number of blocks to generate in parallel
):
    """Generate text using block diffusion with proper sampling and repetition control.

    The continuation is produced in whole blocks of ``block_size`` tokens:
    ``max_new_tokens // block_size`` blocks are generated, so a remainder
    smaller than one block is dropped and ``max_new_tokens < block_size``
    generates nothing.

    Args:
        model: Model exposing ``config`` (directly, via ``.module`` or via
            ``._orig_mod``) and callable as ``model(ids, attention_mask=...)``.
        tokenizer: Tokenizer with ``encode(..., return_tensors="pt")`` and
            ``decode``.
        prompt: Text to continue.
        steps: Denoising steps per block.
        block_size: Tokens per block.
        max_new_tokens: Upper bound on generated tokens (see above).
        device: Device the token tensors are placed on.
        temperature, top_k, top_p, repetition_penalty, no_repeat_ngram_size:
            Sampling controls forwarded to the block generators.
        visualize: If True, stream output in real-time showing the diffusion effect.
        parallel_blocks: Number of blocks to generate in parallel (1-4 recommended).

    Returns:
        The decoded prompt plus generated text.

    Raises:
        ValueError: If ``block_size`` or ``steps`` is smaller than 1.
    """
    # Fail early with a clear message instead of a ZeroDivisionError
    # (block_size == 0) or blocks that are never unmasked (steps < 1).
    if block_size < 1:
        raise ValueError(f"block_size must be >= 1, got {block_size}")
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")

    model.eval()
    prompt_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Unwrap wrapper modules to reach the model config.
    config = model.module.config if hasattr(model, 'module') else model.config
    if hasattr(model, '_orig_mod'):
        config = model._orig_mod.config

    num_blocks = max_new_tokens // block_size
    # Can't parallelize more than total blocks; never go below one.
    parallel_blocks = max(1, min(parallel_blocks, num_blocks))

    if not visualize:
        if parallel_blocks > 1:
            print(f"Generating {num_blocks} blocks of {block_size} tokens each ({parallel_blocks} blocks in parallel)...")
        else:
            print(f"Generating {num_blocks} blocks of {block_size} tokens each...")
    else:
        print("\n\033[94mStarting diffusion generation...\033[0m\n")
        print(prompt, end='', flush=True)

    context_ids = prompt_ids
    # Running set of every token id seen so far; handed to the block
    # generators as part of their signature.
    all_generated_tokens = set(prompt_ids[0].tolist())

    # Process blocks in batches of parallel_blocks
    blocks_generated = 0
    while blocks_generated < num_blocks:
        # Determine how many blocks to generate this iteration
        current_parallel = min(parallel_blocks, num_blocks - blocks_generated)

        if current_parallel > 1:
            # Parallel block generation
            generated_blocks = _generate_parallel_blocks(
                model, tokenizer, context_ids, config, device,
                current_parallel, block_size, steps, temperature,
                top_k, top_p, repetition_penalty, no_repeat_ngram_size,
                all_generated_tokens, visualize
            )
            # Concatenate all generated blocks to context
            context_ids = torch.cat([context_ids, *generated_blocks], dim=1)
            for block in generated_blocks:
                all_generated_tokens.update(block[0].tolist())
            if not visualize:
                print(f" Blocks {blocks_generated + 1}-{blocks_generated + current_parallel}/{num_blocks} complete")
            blocks_generated += current_parallel
        else:
            # Single block generation
            mask_block, _ = _generate_single_block(
                model, tokenizer, context_ids, config, device,
                block_size, steps, temperature, top_k, top_p,
                repetition_penalty, no_repeat_ngram_size,
                all_generated_tokens, visualize
            )
            context_ids = torch.cat([context_ids, mask_block], dim=1)
            all_generated_tokens.update(mask_block[0].tolist())
            if not visualize:
                print(f" Block {blocks_generated + 1}/{num_blocks} complete")
            blocks_generated += 1

    if visualize:
        # Final newline after visualization
        print("\n")

    generated_ids = context_ids[0].tolist()
    return tokenizer.decode(generated_ids, skip_special_tokens=True)
def _generate_single_block(
    model, tokenizer, context_ids, config, device,
    block_size, steps, temperature, top_k, top_p,
    repetition_penalty, no_repeat_ngram_size,
    all_generated_tokens, visualize
):
    """Generate a single block using diffusion.

    The block starts fully masked. On every step the model scores
    ``context + block``, one candidate token is sampled for each position and
    the still-masked positions whose sampled token has the highest probability
    are revealed: ``max(1, block_size // steps)`` per step, and everything
    that is left on the final step.

    Args:
        model: Callable returning ``(logits, kv)`` for
            ``model(ids, attention_mask=...)``.
        tokenizer: Only used for visualisation.
        context_ids: ``(1, context_len)`` tensor of already fixed token ids.
        config: Provides ``mask_token_id``.
        device: Device for the block tensors.
        block_size: Number of tokens in the block.
        steps: Maximum number of denoising steps.
        temperature, top_k, top_p, repetition_penalty, no_repeat_ngram_size:
            Forwarded to ``_apply_sampling_controls``.
        all_generated_tokens: Set that is updated in place with every revealed
            token id.
        visualize: If True, redraw the state after each step.

    Returns:
        ``(mask_block, block_token_history)``: the ``(1, block_size)`` token
        tensor and the revealed token ids in the order they were revealed.

    Raises:
        ValueError: If ``steps`` is smaller than 1 (the block could never be
            unmasked).
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")

    mask_block = torch.full((1, block_size), config.mask_token_id, device=device)
    is_masked = torch.ones(1, block_size, dtype=torch.bool, device=device)
    block_token_history = []
    per_step = max(1, block_size // steps)

    for step_idx in range(steps):
        remaining = int(is_masked.sum().item())
        if remaining == 0:
            # Everything is revealed (happens when steps > block_size);
            # further forward passes could not change the block.
            break

        full_input = torch.cat([context_ids, mask_block], dim=1)
        attention_mask = torch.ones_like(full_input, dtype=torch.float32)
        logits, _ = model(full_input, attention_mask=attention_mask)

        block_logits = logits[:, -block_size:, :]
        block_logits = _apply_sampling_controls(
            block_logits, context_ids, mask_block, is_masked,
            repetition_penalty, temperature, top_k, top_p,
            no_repeat_ngram_size, block_token_history
        )

        # Sanitise and renormalise so multinomial always gets a valid
        # distribution.
        probs = F.softmax(block_logits, dim=-1)
        probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
        probs = probs.clamp(min=1e-10)
        probs = probs / probs.sum(dim=-1, keepdim=True)

        sampled_tokens = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1)
        sampled_tokens = sampled_tokens.view(1, block_size)
        confidence = probs.gather(-1, sampled_tokens.unsqueeze(-1)).squeeze(-1)

        # The last step reveals whatever is still hidden.
        is_last_step = step_idx == steps - 1
        num_to_unmask = remaining if is_last_step else min(per_step, remaining)

        # Already revealed positions get -1 so topk can never pick them
        # (real confidences are strictly positive after the clamp above).
        masked_confidence = confidence.masked_fill(~is_masked, -1.0)
        _, top_indices = torch.topk(masked_confidence.view(-1), num_to_unmask)

        chosen = sampled_tokens[0, top_indices]
        mask_block[0, top_indices] = chosen
        is_masked[0, top_indices] = False
        # One host transfer for the whole step instead of one per token.
        chosen_ids = chosen.tolist()
        block_token_history.extend(chosen_ids)
        all_generated_tokens.update(chosen_ids)

        if visualize:
            visualize_diffusion_state(tokenizer, context_ids, [mask_block], [is_masked], config, clear=(step_idx > 0))

    return mask_block, block_token_history
def _generate_parallel_blocks(
    model, tokenizer, context_ids, config, device,
    num_parallel, block_size, steps, temperature,
    top_k, top_p, repetition_penalty, no_repeat_ngram_size,
    all_generated_tokens, visualize
):
    """Generate multiple blocks in parallel using batched computation.

    Batch row ``b`` contains the context followed by blocks ``0..b`` in their
    current (partly masked) state; the rest of the row is padding that is
    excluded through the attention mask:
    - Block 0: context + [block0]
    - Block 1: context + [block0] + [block1]
    - Block 2: context + [block0] + [block1] + [block2]
    - etc.
    The logits for block ``b`` are read from row ``b``. Each step reveals up
    to ``max(1, block_size // steps)`` positions per block, and everything
    that is left on the final step.

    Args:
        model: Callable returning ``(logits, kv)`` for
            ``model(ids, attention_mask=...)``.
        tokenizer: Only used for visualisation.
        context_ids: ``(1, context_len)`` tensor of already fixed token ids.
        config: Provides ``mask_token_id`` and ``pad_token_id``.
        device: Device for the block tensors.
        num_parallel: Number of blocks generated together.
        block_size: Number of tokens per block.
        steps: Maximum number of denoising steps.
        temperature, top_k, top_p, repetition_penalty, no_repeat_ngram_size:
            Forwarded to ``_apply_sampling_controls``.
        all_generated_tokens: Accepted for signature parity with
            ``_generate_single_block``; not modified here.
        visualize: If True, redraw all blocks after each step.

    Returns:
        List of ``num_parallel`` tensors, each of shape ``(1, block_size)``.

    Raises:
        ValueError: If ``steps`` is smaller than 1 (the blocks could never be
            unmasked).
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")

    batch_size = num_parallel
    context_len = context_ids.shape[1]
    max_seq_len = context_len + (num_parallel * block_size)

    # Initialize mask blocks for all parallel blocks: (num_parallel, block_size)
    mask_blocks = torch.full((batch_size, block_size), config.mask_token_id, device=device)
    is_masked = torch.ones(batch_size, block_size, dtype=torch.bool, device=device)
    block_token_histories = [[] for _ in range(batch_size)]

    # Loop invariants: row b sees the first context_len + (b + 1) * block_size
    # positions; everything after that is padding.
    positions = torch.arange(max_seq_len, device=device)
    visible_lens = context_len + block_size * torch.arange(1, batch_size + 1, device=device)
    visible = positions.unsqueeze(0) < visible_lens.unsqueeze(1)  # (batch, max_seq_len)
    attention_mask = visible.to(torch.float32)  # 1 for real tokens, 0 for padding
    per_step = max(1, block_size // steps)

    for step_idx in range(steps):
        if not bool(is_masked.any()):
            # Every block is fully revealed (happens when steps > block_size);
            # further forward passes could not change anything.
            break

        # context + block_0 + ... + block_{n-1}, then blank out what each row
        # must not see with the pad token.
        sequence = torch.cat([context_ids[0], mask_blocks.reshape(-1)], dim=0)
        padding = torch.full_like(sequence, config.pad_token_id)
        full_input = torch.where(visible, sequence.unsqueeze(0), padding.unsqueeze(0))

        # Single forward pass for all blocks
        logits, _ = model(full_input, attention_mask=attention_mask)

        # Apply sampling controls per block. Block b's logits live in row b at
        # positions [context_len + b*block_size, context_len + (b+1)*block_size).
        controlled = []
        for b in range(batch_size):
            start_pos = context_len + (b * block_size)
            end_pos = start_pos + block_size
            # Context for the repetition controls also covers earlier blocks.
            extended_context = context_ids
            if b > 0:
                extended_context = torch.cat([context_ids, mask_blocks[:b].reshape(1, -1)], dim=1)
            controlled.append(_apply_sampling_controls(
                logits[b:b+1, start_pos:end_pos, :],
                extended_context,
                mask_blocks[b:b+1],
                is_masked[b:b+1],
                repetition_penalty, temperature, top_k, top_p,
                no_repeat_ngram_size, block_token_histories[b]
            ))
        block_logits = torch.cat(controlled, dim=0)  # (batch, block_size, vocab)

        # Sanitise and renormalise so multinomial always gets a valid
        # distribution.
        probs = F.softmax(block_logits, dim=-1)
        probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
        probs = probs.clamp(min=1e-10)
        probs = probs / probs.sum(dim=-1, keepdim=True)

        # Sample for all batches
        sampled_tokens = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1)
        sampled_tokens = sampled_tokens.view(batch_size, block_size)
        confidence = probs.gather(-1, sampled_tokens.unsqueeze(-1)).squeeze(-1)

        # The last step reveals whatever is still hidden.
        tokens_to_unmask = block_size if step_idx == steps - 1 else per_step

        # Unmask for each batch item
        for b in range(batch_size):
            remaining = int(is_masked[b].sum().item())
            if remaining == 0:
                continue
            num_to_unmask = min(tokens_to_unmask, remaining)
            # Already revealed positions get -1 so topk can never pick them.
            masked_confidence = confidence[b].masked_fill(~is_masked[b], -1.0)
            _, top_indices = torch.topk(masked_confidence, num_to_unmask)
            chosen = sampled_tokens[b, top_indices]
            mask_blocks[b, top_indices] = chosen
            is_masked[b, top_indices] = False
            # One host transfer per block and step instead of one per token.
            block_token_histories[b].extend(chosen.tolist())

        if visualize:
            # Visualize all blocks with different colors
            block_list = [mask_blocks[b:b+1] for b in range(batch_size)]
            is_masked_list = [is_masked[b:b+1] for b in range(batch_size)]
            visualize_diffusion_state(
                tokenizer, context_ids, block_list, is_masked_list,
                config, clear=(step_idx > 0)
            )

    # Return list of generated blocks
    return [mask_blocks[b:b+1] for b in range(batch_size)]
def _apply_sampling_controls(
block_logits, context_ids, mask_block, is_masked,
repetition_penalty, temperature, top_k, top_p,
no_repeat_ngram_size, block_token_history
):
"""Apply repetition penalty, temperature, top-k, top-p, and n-gram blocking."""
if repetition_penalty != 1.0:
seen_tokens = set(context_ids[0].tolist())
for i in range(mask_block.shape[1]):
if not is_masked[0, i]:
seen_tokens.add(mask_block[0, i].item())
for token_id in seen_tokens:
if token_id < block_logits.shape[-1]:
if block_logits[0, :, token_id].mean() > 0:
block_logits[:, :, token_id] /= repetition_penalty
else:
block_logits[:, :, token_id] *= repetition_penalty
block_logits = block_logits / temperature
if top_k > 0:
top_k_logits, top_k_indices = torch.topk(block_logits, top_k, dim=-1)
block_logits = torch.full_like(block_logits, float('-inf'))
block_logits.scatter_(-1, top_k_indices, top_k_logits)
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(block_logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
block_logits[indices_to_remove] = float('-inf')
if no_repeat_ngram_size > 0 and len(block_token_history) >= no_repeat_ngram_size - 1:
recent_ngram = tuple(block_token_history[-(no_repeat_ngram_size-1):])
full_history = context_ids[0].tolist() + block_token_history
for i in range(len(full_history) - no_repeat_ngram_size + 1):
if tuple(full_history[i:i+no_repeat_ngram_size-1]) == recent_ngram:
blocked_token = full_history[i + no_repeat_ngram_size - 1]
if blocked_token < block_logits.shape[-1]:
block_logits[:, :, blocked_token] = float('-inf')
# Safety check: if all logits are -inf, reset to uniform distribution
all_inf_mask = torch.isinf(block_logits).all(dim=-1)
if all_inf_mask.any():
block_logits[all_inf_mask] = 0.0
return block_logits
# ============== Main Entry Point ==============
def main():
    """Script entry point: run the visual demo or load the model and generate.

    With ``demo`` as the first command-line argument only the visualisation
    demo runs; otherwise the tokenizer and checkpoint are loaded and one
    hard-coded prompt is continued and printed.
    """
    import sys

    # Configuration
    model_path = "../extra-final-boss/checkpoints/model_fp32.pt"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Quick demo mode: exercise the visualisation without loading the model.
    demo_requested = len(sys.argv) > 1 and sys.argv[1] == 'demo'
    if demo_requested:
        demo_visualize_truncation()
        return

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model, config = load_model(model_path, device)

    banner = "=" * 50
    print("\n" + banner)
    print("Text Generation")
    print(banner)

    prompt = "Barrack Obama was born in "
    print(f"Prompt: {prompt}\n")

    generation_settings = dict(
        prompt=prompt,
        steps=64,
        block_size=64,
        max_new_tokens=512,
        device=device,
        temperature=1,
        top_k=40,
        top_p=0.9,
        repetition_penalty=1.3,
        no_repeat_ngram_size=3,
        visualize=True,      # stream the diffusion process to the terminal
        parallel_blocks=4,   # number of blocks denoised together
    )
    generated = generate_block_diffusion(model, tokenizer, **generation_settings)
    print(f"\nGenerated text:\n{generated}")
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()