import torch
import torch.nn.functional as F


def get_last_attn(attn_map):
    """Keep only the last query position of each layer's attention map.

    Every ``layer`` in ``attn_map`` is sliced on its second-to-last axis down
    to the final position, keeping that axis with size 1, so a layer of shape
    ``(..., q, k)`` becomes ``(..., 1, k)``. The expected per-layer layout is
    presumably ``(batch, heads, query_len, key_len)`` — TODO confirm with
    callers.

    A list (or any other item-assignable sequence) is updated in place and
    the same object is returned. A tuple cannot be item-assigned, so a new
    tuple of sliced layers is returned instead.
    """
    if isinstance(attn_map, tuple):
        return tuple(layer[..., -1:, :] for layer in attn_map)
    for i, layer in enumerate(attn_map):
        # `-1:` keeps the sliced axis (size 1); same as `[:, :, -1, :].unsqueeze(2)` for 4-D
        attn_map[i] = layer[..., -1:, :]
    return attn_map


def sample_token(logits, top_k=None, top_p=None, temperature=1.0):
    """Pick the next token id(s) from ``logits`` (last dim = vocabulary).

    If neither filter is active (``top_k`` is ``None``/``<= 0`` and ``top_p``
    is ``None``), decoding is greedy. Otherwise the logits are divided by
    ``temperature``, optionally restricted to the ``top_k`` largest, then
    optionally restricted by nucleus (top-p) filtering. Finally one token is
    sampled for every row along the last dimension.

    Args:
        logits: Tensor whose last dimension indexes the vocabulary.
        top_k: Keep only the ``top_k`` highest logits (clamped to the
            vocabulary size). ``None`` or a value ``<= 0`` disables it.
        top_p: Keep the smallest prefix of the probability-sorted candidates
            whose cumulative probability reaches ``top_p``. The most likely
            candidate is always kept. Applied after top-k, with probabilities
            renormalized over the top-k survivors. ``None`` disables it.
        temperature: Divisor applied to the logits before sampling. ``0``
            selects the argmax instead of sampling.

    Returns:
        Greedy mode: ``logits.argmax(dim=-1).squeeze()`` (0-d for one row).
        Sampling mode (including ``temperature == 0``): integer token ids of
        shape ``logits.shape[:-1] + (1,)``.
    """
    if top_k is not None and top_k <= 0:
        # torch.topk(k=0) yields an empty candidate set; treat as "disabled"
        top_k = None
    if top_k is None and top_p is None:
        # Dividing by a positive temperature never changes the argmax, so skip it.
        return logits.argmax(dim=-1).squeeze()
    if temperature == 0:
        # Zero-temperature limit of sampling is argmax; dividing would give inf/NaN.
        return logits.argmax(dim=-1, keepdim=True)

    logits = logits / temperature

    # Candidates sorted by descending logit, with their vocabulary ids.
    if top_k is not None:
        top_k = min(top_k, logits.size(-1))  # ensure top_k <= vocab size
        values, indices = torch.topk(logits, top_k)
    else:
        values, indices = torch.sort(logits, dim=-1, descending=True)

    if top_p is not None:
        sorted_probs = F.softmax(values, dim=-1)
        # Probability mass strictly before each candidate; once it exceeds
        # top_p, the nucleus is already complete without that candidate.
        mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
        drop = mass_before > top_p
        drop[..., 0] = False  # never drop every candidate
        values = values.masked_fill(drop, float("-inf"))

    probs = F.softmax(values, dim=-1)
    # torch.multinomial only accepts 1-D/2-D input, so flatten leading dims.
    flat_choice = torch.multinomial(probs.reshape(-1, probs.size(-1)), 1)
    choice = flat_choice.reshape(*probs.shape[:-1], 1)
    # Map each row's sampled candidate position back to its vocabulary id.
    return torch.gather(indices, -1, choice)