| import logging |
| import math |
| from dataclasses import dataclass |
| from typing import Any |
| from typing import Dict |
| from typing import List |
| from typing import Literal |
| from typing import Optional |
| from typing import Tuple |
|
|
| import torch |
| import torch.nn.functional as F |
| from transformers.cache_utils import DynamicCache |
|
|
| from .sliding_utils import drop_tokens_from_cache |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| @dataclass |
| class DuplexWindowConfig: |
| """双工滑窗配置 |
| |
| 滑窗模式: |
| - "off": 禁用滑窗 |
| - "basic": 基础滑窗(按 cache 长度触发) |
| - "context": 带 context 的滑窗(按 unit 数量触发,保留生成文本到 previous) |
| """ |
|
|
| |
| sliding_window_mode: str = "off" |
|
|
| |
| basic_window_high_tokens: int = 4000 |
| basic_window_low_tokens: int = 3500 |
|
|
| |
| context_previous_max_tokens: int = 500 |
| context_max_units: int = 24 |
|
|
| |
| verify_mode: bool = False |
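| |
| # Usage sketch (illustrative only; `decoder` stands for a StreamDecoder |
| # instance as defined below): pick a mode and thresholds, then install it. |
| # |
| #     cfg = DuplexWindowConfig( |
| #         sliding_window_mode="basic", |
| #         basic_window_high_tokens=4000, |
| #         basic_window_low_tokens=3500, |
| #     ) |
| #     decoder.set_window_config(cfg) |
| #     decoder.set_window_enabled(True) |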
|
|
|
|
| def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("inf")): |
| """Filter logits with top-k and/or nucleus (top-p) filtering. |
| |
| Note: the top-p branch writes through `logits[0, ...]`, so this helper |
| assumes a batch size of 1. |
| """ |
| logits = logits.clone() |
|
|
| |
| if top_k > 0: |
| top_k = min(top_k, logits.size(-1)) |
| indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] |
| logits[indices_to_remove] = filter_value |
|
|
| |
| if top_p > 0.0: |
| sorted_logits, sorted_indices = torch.sort(logits, descending=True) |
| probs = F.softmax(sorted_logits, dim=-1) |
| cumulative_probs = torch.cumsum(probs, dim=-1) |
|
|
| sorted_indices_to_remove = cumulative_probs > top_p |
| |
| sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() |
| sorted_indices_to_remove[..., 0] = 0 |
|
|
| indices_to_remove = sorted_indices[sorted_indices_to_remove] |
| logits[0, indices_to_remove] = filter_value |
|
|
| return logits |
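| |
| |
| # Example (a minimal sketch): filter [1, vocab] logits, then sample from the |
| # renormalized distribution. Note the batch-size-1 assumption described above. |
| # |
| #     filtered = top_k_top_p_filtering(logits / temperature, top_k=20, top_p=0.8) |
| #     probs = F.softmax(filtered, dim=-1) |
| #     next_id = torch.multinomial(probs, num_samples=1) |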
|
|
|
|
| class StreamDecoder: |
| def __init__(self, llm, tokenizer, special_token_ids=None, forbidden_token_ids=None): |
| self.m = llm |
| self.tokenizer = tokenizer |
| # <|listen|> is represented by the tokenizer's eos token id. |
| self.listen_id = self.tokenizer.eos_token_id |
|
|
| self.chunk_eos_id = self.tokenizer.convert_tokens_to_ids("<|chunk_eos|>") |
| self.chunk_tts_eos_id = self.tokenizer.convert_tokens_to_ids("<|chunk_tts_eos|>") |
| self.turn_eos_id = self.tokenizer.convert_tokens_to_ids("<|turn_eos|>") |
| self.speak_id = self.tokenizer.convert_tokens_to_ids("<|speak|>") |
|
|
| self.special_token_ids = special_token_ids if special_token_ids is not None else [] |
|
|
| |
| self._all_special_ids = set() |
| self._all_special_tokens_text = set() |
| if self.tokenizer: |
| if hasattr(self.tokenizer, "all_special_ids"): |
| self._all_special_ids = set(self.tokenizer.all_special_ids) |
| if hasattr(self.tokenizer, "all_special_tokens"): |
| self._all_special_tokens_text = set(self.tokenizer.all_special_tokens) |
|
|
| custom_special_tokens = [ |
| "<unit>", |
| "</unit>", |
| "<image>", |
| "</image>", |
| "<slice>", |
| "</slice>", |
| "<|listen|>", |
| "<|speak|>", |
| "<|tts_bos|>", |
| "<|tts_eos|>", |
| "<|audio_start|>", |
| "<|audio_end|>", |
| "<|chunk_eos|>", |
| "<|chunk_tts_eos|>", |
| "<|turn_eos|>", |
| "<|audio_start|>", |
| "<|audio_end|>", |
| ] |
| self._all_special_tokens_text.update(custom_special_tokens) |
| for token in custom_special_tokens: |
| token_id = self.tokenizer.convert_tokens_to_ids(token) |
| if token_id is not None and token_id != self.tokenizer.unk_token_id: |
| self._all_special_ids.add(token_id) |
|
|
| if forbidden_token_ids is None: |
| self.forbidden_token_ids = [] |
| elif isinstance(forbidden_token_ids, int): |
| self.forbidden_token_ids = [forbidden_token_ids] |
| else: |
| # Copy so that appending below does not mutate the caller's list. |
| self.forbidden_token_ids = list(forbidden_token_ids) |
| self.forbidden_token_ids.append(self.chunk_eos_id) |
|
|
| assert isinstance(self.forbidden_token_ids, list) |
|
|
| self.cache = None |
| self.context = "" |
| self.generated_tokens = [] |
| self.generated_special_tokens = [] |
| self.reset() |
| self.embeds = None |
| self.system_embeds = None |
|
|
| |
| self._unit_history: List[Dict[str, Any]] = [] |
| self._next_unit_id: int = 0 |
| self._pending_unit_id: Optional[int] = None |
| self._pending_unit_start_cache_len: int = 0 |
| self._system_preserve_length: int = 0 |
| self._position_offset: int = 0 |
| self._window_config = DuplexWindowConfig() |
| self._window_enabled: bool = True |
| self._rope_inv_freq_cache: Dict[Tuple, torch.Tensor] = {} |
|
|
| |
| |
| |
| |
| self._preserve_prefix_length: int = 0 |
| self._previous_content_length: int = 0 |
| self._suffix_token_ids: List[int] = [] |
|
|
| |
| self._previous_marker: str = "\n\nprevious: " |
| self._previous_marker_token_ids: List[int] = [] |
| self._has_previous: bool = False |
|
|
| |
| self._previous_text: str = "" |
| self._previous_token_ids: List[int] = [] |
|
|
| |
| self._sliding_event_count: int = 0 |
| self._total_dropped_tokens: int = 0 |
| self._total_dropped_units: int = 0 |
|
|
| def sliding_embeds(self): |
| """Placeholder for embedding-level sliding; intentionally a no-op.""" |
| pass |
|
|
| def reset(self): |
| self.context = "" |
| self.cache = None |
| self.generated_tokens = [] |
| self.generated_special_tokens = [] |
| self.embeds = None |
| self.system_embeds = None |
|
|
| |
| old_unit_count = len(self._unit_history) if hasattr(self, "_unit_history") else 0 |
| self._unit_history = [] |
| self._next_unit_id = 0 |
| self._pending_unit_id = None |
| self._pending_unit_start_cache_len = 0 |
| self._system_preserve_length = 0 |
| self._position_offset = 0 |
| self._rope_inv_freq_cache = {} |
|
|
| |
| self._preserve_prefix_length = 0 |
| self._previous_content_length = 0 |
| self._suffix_token_ids = [] |
| self._previous_marker = "\n\nprevious: " |
| self._previous_marker_token_ids = [] |
| self._has_previous = False |
| self._previous_text = "" |
| self._previous_token_ids = [] |
|
|
| |
| self._sliding_event_count = 0 |
| self._total_dropped_tokens = 0 |
| self._total_dropped_units = 0 |
|
|
| if old_unit_count > 0: |
| logger.info("[SW] reset: cleared %d units, all sliding window state reset", old_unit_count) |
|
|
| def get_cache_length(self) -> int: |
| if self.cache is None: |
| return 0 |
| if isinstance(self.cache, DynamicCache): |
| if len(self.cache.key_cache) > 0 and self.cache.key_cache[0].numel() > 0: |
| return self.cache.key_cache[0].shape[2] |
| return 0 |
| |
| return self.cache[0][0].shape[2] |
|
|
| def get_total_generated_tokens(self) -> int: |
| return sum(len(u.get("generated_tokens", [])) for u in self._unit_history) |
|
|
| def register_unit_start(self) -> int: |
| self._pending_unit_id = self._next_unit_id |
| self._pending_unit_start_cache_len = self.get_cache_length() |
| logger.info( |
| "[SW] unit_start: pending_unit_id=%d, cache_len=%d, preserve=%d, units=%d", |
| self._pending_unit_id, |
| self._pending_unit_start_cache_len, |
| self._system_preserve_length, |
| len(self._unit_history), |
| ) |
| return self._pending_unit_id |
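| |
| # Typical per-unit flow (a sketch; `unit_embeds`, `ids`, and `txt` are |
| # assumed to come from the caller's prefill/generation loop): |
| # |
| #     decoder.register_unit_start() |
| #     decoder.feed(unit_embeds)  # prefill the unit, ending with </unit> |
| #     decoder.register_unit_end("audio", generated_tokens=ids, generated_text=txt) |
| #     decoder.enforce_window()   # slide if the cache passed the high-water mark |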
|
|
| def register_unit_end( |
| self, |
| input_type: str, |
| generated_tokens: Optional[List[int]] = None, |
| is_listen: bool = False, |
| generated_text: Optional[str] = None, |
| ): |
| """在 unit 结束时调用,记录该 unit 的信息 |
| |
| 应在 feed </unit> token 之后调用 |
| |
| Args: |
| input_type: "audio" / "video" / "omni" / "system" |
| generated_tokens: 该 unit 生成的 tokens(token ids) |
| is_listen: 是否是 listen 状态 |
| generated_text: 该 unit 生成的文本(用于 context 保留模式) |
| """ |
| if self._pending_unit_id is None: |
| logger.warning("register_unit_end called without register_unit_start") |
| return |
|
|
| |
| current_cache_len = self.get_cache_length() |
| unit_len = current_cache_len - self._pending_unit_start_cache_len |
|
|
| if unit_len > 0: |
| entry = { |
| "unit_id": self._pending_unit_id, |
| "length": unit_len, |
| "type": input_type, |
| "generated_tokens": generated_tokens or [], |
| "generated_text": generated_text or "", |
| "is_listen": is_listen, |
| } |
| self._unit_history.append(entry) |
| gen_count = len(generated_tokens) if generated_tokens else 0 |
| gen_text_preview = ( |
| (generated_text[:30] + "...") if generated_text and len(generated_text) > 30 else (generated_text or "") |
| ) |
| logger.info( |
| "[SW] unit_end: unit_id=%d type=%s len=%d gen_tokens=%d is_listen=%s | " |
| "cache=%d preserve=%d total_units=%d | text='%s'", |
| self._pending_unit_id, |
| input_type, |
| unit_len, |
| gen_count, |
| is_listen, |
| current_cache_len, |
| self._system_preserve_length, |
| len(self._unit_history), |
| gen_text_preview, |
| ) |
| else: |
| logger.warning( |
| "[SW] unit_end: unit_id=%d has zero length (start=%d, current=%d), not recorded", |
| self._pending_unit_id, |
| self._pending_unit_start_cache_len, |
| current_cache_len, |
| ) |
|
|
| self._pending_unit_id = None |
| self._pending_unit_start_cache_len = 0 |
| self._next_unit_id += 1 |
|
|
| def register_system_prompt(self): |
| """在 system prompt prefill 完成后调用,记录保护长度""" |
| self._system_preserve_length = self.get_cache_length() |
| logger.info( |
| "[SW] system_prompt registered: preserve_length=%d (will be protected from sliding)", |
| self._system_preserve_length, |
| ) |
|
|
| |
|
|
| def _get_rope_theta(self) -> float: |
| """获取模型的 rope_theta 配置""" |
| return float(getattr(self.m.config, "rope_theta", 10000.0)) |
|
|
| def _drop_tokens_from_cache(self, length: int) -> bool: |
| """从 cache 中移除指定数量的 tokens(保护 system prompt) |
| |
| 移除位于 [preserve, preserve + length) 区间的 tokens |
| 支持 DynamicCache 和 tuple cache 两种格式 |
| """ |
| if self.cache is None or length <= 0: |
| logger.warning("[SW] _drop_tokens_from_cache: cache is None or length<=0 (length=%d)", length) |
| return False |
|
|
| cache_type = "DynamicCache" if isinstance(self.cache, DynamicCache) else "TupleCache" |
| cache_len_before = self.get_cache_length() |
| offset_before = self._position_offset |
|
|
| logger.debug( |
| "[SW] _drop_tokens_from_cache: type=%s, drop=%d tokens from [%d, %d), cache=%d, preserve=%d", |
| cache_type, |
| length, |
| self._system_preserve_length, |
| self._system_preserve_length + length, |
| cache_len_before, |
| self._system_preserve_length, |
| ) |
|
|
| new_cache, new_offset, success = drop_tokens_from_cache( |
| cache=self.cache, |
| length=length, |
| preserve=self._system_preserve_length, |
| position_offset=self._position_offset, |
| rope_theta=self._get_rope_theta(), |
| inv_freq_cache=self._rope_inv_freq_cache, |
| ) |
| if success: |
| self.cache = new_cache |
| self._position_offset = new_offset |
|
|
| if success: |
| logger.debug( |
| "[SW] _drop_tokens_from_cache: SUCCESS cache %d -> %d, offset %d -> %d (RoPE reindexed)", |
| cache_len_before, |
| self.get_cache_length(), |
| offset_before, |
| self._position_offset, |
| ) |
| else: |
| logger.error( |
| "[SW] _drop_tokens_from_cache: FAILED to drop %d tokens (cache=%d, preserve=%d)", |
| length, |
| cache_len_before, |
| self._system_preserve_length, |
| ) |
|
|
| return success |
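| |
| # Worked example: with preserve=100 and unit layout [u1: 50][u2: 60], calling |
| # _drop_tokens_from_cache(50) removes cache rows [100, 150) (all of u1) and |
| # RoPE-realigns u2 so its keys read as positions 100..159 instead of 150..209. |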
|
|
| def _drop_unit(self, unit_id: int) -> bool: |
| """移除指定 unit""" |
| entries = [u for u in self._unit_history if u["unit_id"] == unit_id] |
| if not entries: |
| logger.warning("[SW] _drop_unit: unit_id=%d not found", unit_id) |
| return False |
|
|
| total_len = sum(e["length"] for e in entries) |
| if total_len <= 0: |
| logger.warning("[SW] _drop_unit: unit_id=%d has zero total length, removing from history", unit_id) |
| for e in entries: |
| self._unit_history.remove(e) |
| return False |
|
|
| cache_before = self.get_cache_length() |
| if not self._drop_tokens_from_cache(total_len): |
| logger.error( |
| "[SW] _drop_unit: failed to drop %d tokens for unit_id=%d from cache (cache=%d, preserve=%d)", |
| total_len, |
| unit_id, |
| cache_before, |
| self._system_preserve_length, |
| ) |
| return False |
|
|
| cache_after = self.get_cache_length() |
| for e in entries: |
| gen_count = len(e.get("generated_tokens", [])) |
| logger.info( |
| "[SW] 🗑️ DROPPED unit_id=%d type=%s len=%d gen_tokens=%d | cache %d -> %d, offset=%d", |
| e["unit_id"], |
| e["type"], |
| e["length"], |
| gen_count, |
| cache_before, |
| cache_after, |
| self._position_offset, |
| ) |
| self._unit_history.remove(e) |
|
|
| return True |
|
|
| def _drop_next_unit(self) -> bool: |
| """移除最早的一个非 system unit""" |
| for entry in self._unit_history: |
| unit_id = entry.get("unit_id") |
| if unit_id is None: |
| continue |
| |
| if entry.get("type") == "system": |
| logger.debug("[SW] _drop_next_unit: skipping system unit_id=%d", unit_id) |
| continue |
| logger.debug("[SW] _drop_next_unit: attempting to drop unit_id=%d", unit_id) |
| if self._drop_unit(unit_id): |
| return True |
| logger.debug("[SW] _drop_next_unit: no droppable unit found in %d units", len(self._unit_history)) |
| return False |
|
|
| def enforce_window(self) -> bool: |
| """强制执行滑窗策略(与单工保持一致,只看 cache 长度) |
| |
| 当 cache 长度超过高水位线时,循环移除最早的 unit, |
| 直到 cache 长度降到低水位线以下。 |
| """ |
| if not self._window_enabled: |
| logger.info("[SW] enforce_window: window disabled, skip") |
| return False |
|
|
| cfg = self._window_config |
| cache_len_before = self.get_cache_length() |
|
|
| if cache_len_before <= cfg.basic_window_high_tokens: |
| logger.debug( |
| "[SW] enforce_window: cache=%d <= high_water=%d, no sliding needed", |
| cache_len_before, |
| cfg.basic_window_high_tokens, |
| ) |
| return False |
|
|
| |
| logger.info( |
| "[SW] ⚡ SLIDING TRIGGERED: cache=%d > high_water=%d, target=low_water=%d", |
| cache_len_before, |
| cfg.basic_window_high_tokens, |
| cfg.basic_window_low_tokens, |
| ) |
|
|
| dropped_count = 0 |
| cache_len = cache_len_before |
| while cache_len > cfg.basic_window_low_tokens: |
| if not self._drop_next_unit(): |
| logger.warning("[SW] enforce_window: no more units to drop, stopping") |
| break |
| dropped_count += 1 |
| cache_len = self.get_cache_length() |
|
|
| if dropped_count > 0: |
| |
| self._sliding_event_count += 1 |
| self._total_dropped_tokens += cache_len_before - cache_len |
| self._total_dropped_units += dropped_count |
|
|
| |
| expected = self._system_preserve_length + sum(u["length"] for u in self._unit_history) |
| is_consistent = expected == cache_len |
| logger.info( |
| "[SW] ✅ SLIDING DONE: cache %d -> %d, dropped %d units, remaining %d units | " |
| "consistency: expected=%d actual=%d %s", |
| cache_len_before, |
| cache_len, |
| dropped_count, |
| len(self._unit_history), |
| expected, |
| cache_len, |
| "✓" if is_consistent else "✗ MISMATCH!", |
| ) |
| if not is_consistent: |
| logger.error( |
| "[SW] ❌ CONSISTENCY ERROR! preserve=%d + sum(units)=%d != cache=%d, offset=%d", |
| self._system_preserve_length, |
| sum(u["length"] for u in self._unit_history), |
| cache_len, |
| self._position_offset, |
| ) |
|
|
| return dropped_count > 0 |
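| |
| # Watermark example: with high=4000 and low=3500, a cache of 4200 tokens |
| # triggers sliding; the oldest non-system units are dropped one by one until |
| # the cache length is at or below 3500 (or no droppable unit remains). |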
|
|
| |
|
|
| def register_system_prompt_with_context( |
| self, |
| suffix_token_ids: Optional[List[int]] = None, |
| context_previous_marker: str = "\n\nprevious: ", |
| ): |
| """注册 system prompt(带 context 保留模式) |
| |
| 初始化时 Cache 布局: [prefix] [suffix] [units...] |
| 首次滑窗后布局: [prefix] [context_previous_marker + content] [suffix] [units...] |
| |
| 调用此方法时,cache 中应该只有 prefix(不含 previous 标志) |
| suffix 会在后续 feed 进去 |
| |
| Args: |
| suffix_token_ids: suffix 的 token ids(如 <|im_end|> 的 id) |
| context_previous_marker: previous 标志前缀,如 "\\n\\nprevious: " |
| """ |
| |
| self._preserve_prefix_length = self.get_cache_length() |
| self._previous_content_length = 0 |
| self._suffix_token_ids = suffix_token_ids or [] |
| |
| self._system_preserve_length = self._preserve_prefix_length + len(self._suffix_token_ids) |
|
|
| |
| self._previous_marker = context_previous_marker |
| self._previous_marker_token_ids = ( |
| self.tokenizer.encode(context_previous_marker, add_special_tokens=False) if self.tokenizer else [] |
| ) |
| self._has_previous = False |
| self._previous_text = "" |
| self._previous_token_ids = [] |
|
|
| logger.info( |
| "[SW-CTX] system_prompt registered: prefix_len=%d, suffix_len=%d, marker='%s' (%d tokens)", |
| self._preserve_prefix_length, |
| len(self._suffix_token_ids), |
| context_previous_marker.replace("\n", "\\n"), |
| len(self._previous_marker_token_ids), |
| ) |
| self.log_cache_layout("After register_system_prompt") |
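| |
| # Context-mode setup sketch (illustrative; `prefix_embeds` and `im_end_id` |
| # are assumptions standing in for the caller's prompt pieces): |
| # |
| #     decoder.feed(prefix_embeds)                      # cache = [prefix] |
| #     decoder.register_system_prompt_with_context( |
| #         suffix_token_ids=[im_end_id], |
| #         context_previous_marker="\n\nprevious: ", |
| #     ) |
| #     decoder.feed(decoder.embed_tokens([im_end_id]))  # cache = [prefix][suffix] |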
|
|
| def _extract_generated_text(self, units: List[Dict[str, Any]]) -> Tuple[str, List[int]]: |
| """从 units 中提取生成的文本和 token ids |
| |
| Args: |
| units: 要提取的 unit 列表 |
| |
| Returns: |
| (text, token_ids): 拼接后的文本和 token ids(过滤掉 special tokens) |
| """ |
| text_parts = [] |
| token_ids = [] |
|
|
| for u in units: |
| |
| if u.get("is_listen", False): |
| continue |
| gen_text = u.get("generated_text", "") |
| gen_tokens = u.get("generated_tokens", []) |
|
|
| |
| if gen_text: |
| clean_text = gen_text |
| for st in self._all_special_tokens_text: |
| clean_text = clean_text.replace(st, "") |
| if clean_text.strip(): |
| text_parts.append(clean_text) |
|
|
| |
| if gen_tokens: |
| filtered_tokens = [t for t in gen_tokens if t not in self._all_special_ids] |
| token_ids.extend(filtered_tokens) |
|
|
| return "".join(text_parts), token_ids |
|
|
| def _rebuild_cache_with_previous( |
| self, |
| new_previous_tokens: List[int], |
| units_to_keep_len: Optional[int] = None, |
| ) -> bool: |
| """重建 cache,把新的 previous 内容插入到 prefix 和 suffix 之间 |
| |
| Cache 布局变化: |
| [prefix] [old_prev] [suffix] [old_units] → [prefix] [new_prev] [suffix] [remaining_units] |
| |
| Args: |
| new_previous_tokens: 新的 previous token ids |
| units_to_keep_len: 需要保留的 units 长度(从 cache 末尾往回算) |
| 如果为 None,根据 unit_history 计算 |
| |
| Returns: |
| 是否成功重建 |
| """ |
| if self.cache is None: |
| logger.warning("[SW-CTX] _rebuild_cache_with_previous: cache is None") |
| return False |
|
|
| old_previous_len = self._previous_content_length |
| new_previous_len = len(new_previous_tokens) |
| suffix_len = len(self._suffix_token_ids) |
| total_cache_len = self.get_cache_length() |
|
|
| |
| if units_to_keep_len is None: |
| units_to_keep_len = sum(u["length"] for u in self._unit_history) |
|
|
| |
| |
| if new_previous_len == 0 and old_previous_len == 0: |
| |
| |
| preserve_len = self._preserve_prefix_length + suffix_len |
|
|
| |
| |
| if units_to_keep_len > 0: |
| |
| prefix_suffix_cache = self._slice_cache(0, preserve_len) |
| units_cache = self._slice_cache(total_cache_len - units_to_keep_len, None) |
|
|
| |
| dropped_tokens = total_cache_len - preserve_len - units_to_keep_len |
|
|
| |
| |
| if dropped_tokens > 0: |
| old_start = preserve_len + dropped_tokens |
| new_start = preserve_len |
| logger.debug( |
| "[SW-CTX] RoPE reindex (no-op path): old_pos=[%d:%d] -> new_pos=[%d:%d], length=%d", |
| old_start, |
| old_start + units_to_keep_len, |
| new_start, |
| new_start + units_to_keep_len, |
| units_to_keep_len, |
| ) |
| units_cache = self._reindex_rope_for_cache(units_cache, old_start, new_start, units_to_keep_len) |
|
|
| self.cache = self._concat_caches(prefix_suffix_cache, units_cache) |
| else: |
| self.cache = self._slice_cache(0, preserve_len) |
|
|
| logger.info( |
| "[SW-CTX] _rebuild_cache_with_previous (no-op): previous unchanged (0->0), " |
| "just removed unit from cache, cache=%d, units_kept=%d", |
| self.get_cache_length(), |
| units_to_keep_len, |
| ) |
| return True |
|
|
| |
| prefix_end = self._preserve_prefix_length |
| prefix_cache = self._slice_cache(0, prefix_end) |
|
|
| |
| units_start_in_old_cache = total_cache_len - units_to_keep_len |
| units_cache = None |
| if units_to_keep_len > 0: |
| units_cache = self._slice_cache(units_start_in_old_cache, None) |
|
|
| |
| |
| prev_suffix_tokens = new_previous_tokens + self._suffix_token_ids |
| prev_suffix_len = len(prev_suffix_tokens) |
|
|
| new_prefix_prev_suffix_cache = prefix_cache |
| if prev_suffix_len > 0: |
| |
| prev_suffix_embeds = self.embed_tokens(prev_suffix_tokens) |
| |
| start_pos = self._preserve_prefix_length + self._position_offset |
|
|
| |
| with torch.no_grad(): |
| device = prev_suffix_embeds.device |
| position_ids = torch.arange( |
| start_pos, |
| start_pos + prev_suffix_len, |
| device=device, |
| ).unsqueeze(0) |
|
|
| |
| outputs = self.m( |
| inputs_embeds=( |
| prev_suffix_embeds.unsqueeze(0) if prev_suffix_embeds.dim() == 2 else prev_suffix_embeds |
| ), |
| position_ids=position_ids, |
| past_key_values=prefix_cache, |
| use_cache=True, |
| return_dict=True, |
| ) |
| |
| new_prefix_prev_suffix_cache = outputs.past_key_values |
|
|
| |
| |
| |
| new_system_total = prefix_end + new_previous_len + suffix_len |
| if units_cache is not None and self._get_cache_len(units_cache) > 0: |
| old_start = units_start_in_old_cache |
| new_start = new_system_total |
|
|
| if old_start != new_start: |
| units_cache = self._reindex_rope_for_cache(units_cache, old_start, new_start, units_to_keep_len) |
|
|
| |
| if units_cache is not None and self._get_cache_len(units_cache) > 0: |
| self.cache = self._concat_caches(new_prefix_prev_suffix_cache, units_cache) |
| else: |
| self.cache = new_prefix_prev_suffix_cache |
|
|
| |
| self._previous_content_length = new_previous_len |
| |
| self._system_preserve_length = prefix_end + new_previous_len + suffix_len |
|
|
| |
| prev_text_preview = self._previous_text[:50] + "..." if len(self._previous_text) > 50 else self._previous_text |
| suffix_preview = self.tokenizer.decode(self._suffix_token_ids) if self._suffix_token_ids else "" |
| logger.info( |
| "[SW-CTX] _rebuild_cache_with_previous:\n" |
| " prefix_len=%d | previous: %d tokens '%s' | suffix: %d tokens '%s'\n" |
| " cache: %d -> %d, units_kept=%d, preserve=%d", |
| self._preserve_prefix_length, |
| new_previous_len, |
| prev_text_preview, |
| suffix_len, |
| suffix_preview, |
| old_previous_len + self._preserve_prefix_length + suffix_len + units_to_keep_len, |
| self.get_cache_length(), |
| units_to_keep_len, |
| self._system_preserve_length, |
| ) |
| return True |
|
|
| def _slice_cache(self, start: int, end: Optional[int], clone: bool = True): |
| """切片 cache |
| |
| Args: |
| start: 起始位置 |
| end: 结束位置(None 表示到末尾) |
| clone: 是否克隆(默认 True,防止共享内存问题) |
| """ |
| if self.cache is None: |
| return None |
| if isinstance(self.cache, DynamicCache): |
| |
| new_key_cache = [ |
| k[:, :, start:end, :].clone() if clone else k[:, :, start:end, :] for k in self.cache.key_cache |
| ] |
| new_value_cache = [ |
| v[:, :, start:end, :].clone() if clone else v[:, :, start:end, :] for v in self.cache.value_cache |
| ] |
| new_cache = DynamicCache() |
| new_cache.key_cache = new_key_cache |
| new_cache.value_cache = new_value_cache |
| return new_cache |
| else: |
| |
| if clone: |
| return tuple( |
| (layer[0][:, :, start:end, :].clone(), layer[1][:, :, start:end, :].clone()) for layer in self.cache |
| ) |
| else: |
| return tuple((layer[0][:, :, start:end, :], layer[1][:, :, start:end, :]) for layer in self.cache) |
|
|
| def _get_cache_len(self, cache) -> int: |
| """获取 cache 长度""" |
| if cache is None: |
| return 0 |
| if isinstance(cache, DynamicCache): |
| if len(cache.key_cache) > 0 and cache.key_cache[0].numel() > 0: |
| return cache.key_cache[0].shape[2] |
| return 0 |
| |
| if cache and cache[0] and cache[0][0] is not None: |
| return cache[0][0].shape[2] |
| return 0 |
|
|
| def _concat_caches(self, cache1, cache2): |
| """拼接两个 cache""" |
| if cache1 is None: |
| return cache2 |
| if cache2 is None: |
| return cache1 |
|
|
| if isinstance(cache1, DynamicCache): |
| new_cache = DynamicCache() |
| new_cache.key_cache = [torch.cat([k1, k2], dim=2) for k1, k2 in zip(cache1.key_cache, cache2.key_cache)] |
| new_cache.value_cache = [ |
| torch.cat([v1, v2], dim=2) for v1, v2 in zip(cache1.value_cache, cache2.value_cache) |
| ] |
| return new_cache |
| else: |
| |
| return tuple( |
| ( |
| torch.cat([layer1[0], layer2[0]], dim=2), |
| torch.cat([layer1[1], layer2[1]], dim=2), |
| ) |
| for layer1, layer2 in zip(cache1, cache2) |
| ) |
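| |
| # Invariant sketch: slicing at any point p and concatenating reproduces the |
| # original cache length, which the rebuild paths above rely on: |
| # |
| #     left, right = self._slice_cache(0, p), self._slice_cache(p, None) |
| #     assert self._get_cache_len(self._concat_caches(left, right)) == self.get_cache_length() |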
|
|
| def _reindex_rope_for_cache(self, cache, old_start: int, new_start: int, length: int): |
| """对 cache 进行 RoPE 位置调整""" |
| if cache is None or length <= 0: |
| return cache |
|
|
| device = None |
| if isinstance(cache, DynamicCache): |
| device = cache.key_cache[0].device if cache.key_cache else None |
| else: |
| device = cache[0][0].device if cache and cache[0] else None |
|
|
| if device is None: |
| return cache |
|
|
| old_positions = torch.arange(old_start, old_start + length, device=device, dtype=torch.long) |
| new_positions = torch.arange(new_start, new_start + length, device=device, dtype=torch.long) |
|
|
| from .sliding_utils import realign_rotary_suffix |
|
|
| rope_theta = self._get_rope_theta() |
|
|
| if isinstance(cache, DynamicCache): |
| new_key_cache = [] |
| for k in cache.key_cache: |
| new_k = realign_rotary_suffix(k, old_positions, new_positions, rope_theta, self._rope_inv_freq_cache) |
| new_key_cache.append(new_k) |
| cache.key_cache = new_key_cache |
| return cache |
| else: |
| new_cache = [] |
| for layer in cache: |
| new_k = realign_rotary_suffix( |
| layer[0], old_positions, new_positions, rope_theta, self._rope_inv_freq_cache |
| ) |
| new_cache.append((new_k, layer[1])) |
| return tuple(new_cache) |
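| |
| # Why key-only realignment is enough (a sketch of the RoPE argument): a key |
| # cached at position p was rotated by angle p * theta_i at each frequency i, |
| # so moving it to position p' only needs an extra rotation by (p' - p) * theta_i, |
| # which is what realign_rotary_suffix applies. Value vectors carry no |
| # positional rotation, so they are left untouched. |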
|
|
| def _update_previous( |
| self, |
| new_text: str, |
| new_tokens: List[int], |
| max_tokens: int, |
| ) -> None: |
| """更新 previous 上下文(同时更新 cache) |
| |
| 首次滑窗时动态添加 marker + 文本,后续滑窗追加文本 |
| 超过 max_tokens 时截断内容(保留 marker) |
| 同时重建 cache 以保持一致 |
| |
| Args: |
| new_text: 新增的文本 |
| new_tokens: 新增的 token ids |
| max_tokens: previous 内容的最大 token 数(不含 marker) |
| """ |
| marker_len = len(self._previous_marker_token_ids) |
| tokens_to_drop = 0 |
|
|
| |
| if not new_tokens and not new_text: |
| logger.info("[SW-CTX] _update_previous: no new content, skip adding to previous") |
| |
| self._rebuild_cache_with_previous(self._previous_token_ids) |
| return |
|
|
| if not self._has_previous: |
| |
| self._previous_text = new_text |
| self._previous_token_ids = self._previous_marker_token_ids.copy() + new_tokens |
| self._has_previous = True |
| logger.info( |
| "[SW-CTX] _update_previous: first slide with content, added marker + %d tokens", |
| len(new_tokens), |
| ) |
| else: |
| |
| self._previous_text += new_text |
| self._previous_token_ids.extend(new_tokens) |
|
|
| |
| content_token_count = len(self._previous_token_ids) - marker_len |
|
|
| |
| if content_token_count > max_tokens: |
| |
| tokens_to_drop = content_token_count - max_tokens |
| old_text = self._previous_text |
| |
| content_tokens = self._previous_token_ids[marker_len + tokens_to_drop :] |
| self._previous_token_ids = self._previous_marker_token_ids.copy() + content_tokens |
| |
| try: |
| self._previous_text = self.tokenizer.decode( |
| content_tokens, |
| skip_special_tokens=True, |
| ) |
| except Exception as e: |
| logger.warning("[SW-CTX] _update_previous: decode failed: %s", e) |
|
|
| |
| logger.info( |
| "[SW-CTX] ⚠️ LEFT TRUNCATION: previous exceeded max_tokens=%d\n" |
| " before: %d content tokens, text='%s'\n" |
| " after: %d content tokens, text='%s'\n" |
| " dropped %d tokens from left", |
| max_tokens, |
| content_token_count, |
| old_text[:60] + "..." if len(old_text) > 60 else old_text, |
| len(content_tokens), |
| self._previous_text[:60] + "..." if len(self._previous_text) > 60 else self._previous_text, |
| tokens_to_drop, |
| ) |
|
|
| |
| self._rebuild_cache_with_previous(self._previous_token_ids) |
|
|
| prev_preview = self._previous_text[:80] + "..." if len(self._previous_text) > 80 else self._previous_text |
| content_len = len(self._previous_token_ids) - marker_len |
| if tokens_to_drop > 0: |
| logger.info( |
| "[SW-CTX] _update_previous: +%d tokens, -%d truncated -> %d content tokens (marker=%d) | '%s'", |
| len(new_tokens), |
| tokens_to_drop, |
| content_len, |
| marker_len, |
| prev_preview, |
| ) |
| else: |
| logger.info( |
| "[SW-CTX] _update_previous: +%d tokens -> %d content tokens (marker=%d) | '%s'", |
| len(new_tokens), |
| content_len, |
| marker_len, |
| prev_preview, |
| ) |
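| |
| # Truncation example: with max_tokens=500 and 520 accumulated content tokens, |
| # the 20 oldest content tokens are dropped and the marker is re-attached, so |
| # previous is always `marker + (at most max_tokens of the newest content)`. |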
|
|
| def _drop_unit_with_context( |
| self, |
| unit_id: int, |
| max_previous_tokens: int, |
| ) -> Tuple[bool, str, List[int]]: |
| """移除指定 unit 并返回其生成内容(用于 context 保留) |
| |
| 流程: |
| 1. 提取 unit 的生成内容 |
| 2. 先从 cache 移除 unit(不包括 prefix+previous) |
| 3. 追加生成内容到 previous |
| 4. 重建 cache(在 _update_previous 中完成) |
| |
| Args: |
| unit_id: 要移除的 unit ID |
| max_previous_tokens: previous 最大 token 数 |
| |
| Returns: |
| (success, extracted_text, extracted_tokens): 是否成功,提取的文本和 tokens |
| """ |
| entries = [u for u in self._unit_history if u["unit_id"] == unit_id] |
| if not entries: |
| logger.warning("[SW-CTX] _drop_unit_with_context: unit_id=%d not found", unit_id) |
| return False, "", [] |
|
|
| |
| extracted_text, extracted_tokens = self._extract_generated_text(entries) |
|
|
| |
| total_len = sum(e["length"] for e in entries) |
| if total_len <= 0: |
| logger.warning("[SW-CTX] _drop_unit_with_context: unit_id=%d has zero length", unit_id) |
| for e in entries: |
| self._unit_history.remove(e) |
| return False, extracted_text, extracted_tokens |
|
|
| cache_before = self.get_cache_length() |
|
|
| |
| for e in entries: |
| self._unit_history.remove(e) |
|
|
| |
| |
|
|
| |
| self._update_previous(extracted_text, extracted_tokens, max_previous_tokens) |
|
|
| cache_after = self.get_cache_length() |
| for e in entries: |
| logger.info( |
| "[SW-CTX] 🗑️ DROPPED unit_id=%d type=%s len=%d, extracted=%d chars | cache %d -> %d", |
| e["unit_id"], |
| e["type"], |
| e["length"], |
| len(extracted_text), |
| cache_before, |
| cache_after, |
| ) |
|
|
| return True, extracted_text, extracted_tokens |
|
|
| def _drop_next_unit_with_context(self, max_previous_tokens: int) -> bool: |
| """移除最早的一个非 system unit(带 context 保留)""" |
| for entry in self._unit_history: |
| unit_id = entry.get("unit_id") |
| if unit_id is None: |
| continue |
| if entry.get("type") == "system": |
| continue |
| success, _, _ = self._drop_unit_with_context(unit_id, max_previous_tokens) |
| if success: |
| return True |
| return False |
|
|
| def enforce_window_with_context(self) -> bool: |
| """带 context 保留的滑窗执行 |
| |
| 当 unit 数量超过 max_units 时,移除最早的 unit, |
| 并将其生成内容累积到 previous。 |
| Cache 会在 _update_previous 中自动重建。 |
| |
| Returns: |
| 是否执行了滑窗 |
| """ |
| if not self._window_enabled: |
| logger.info("[SW-CTX] enforce_window_with_context: window disabled, skip") |
| return False |
|
|
| cfg = self._window_config |
|
|
| if cfg.sliding_window_mode != "context": |
| |
| return self.enforce_window() |
|
|
| cache_len_before = self.get_cache_length() |
| units_before = len(self._unit_history) |
|
|
| |
| |
| if units_before <= cfg.context_max_units: |
| logger.debug( |
| "[SW-CTX] enforce_window_with_context: no sliding needed (units=%d/%d)", |
| units_before, |
| cfg.context_max_units, |
| ) |
| self.log_cache_layout("No sliding (units=%d/%d)" % (units_before, cfg.context_max_units)) |
| return False |
|
|
| slide_tag = "slide #%d" % (self._sliding_event_count + 1) |
| logger.info( |
| "[SW-CTX] ⚡ SLIDING TRIGGERED (%s): units=%d > max_units=%d, previous=%d tokens", |
| slide_tag, |
| units_before, |
| cfg.context_max_units, |
| len(self._previous_token_ids), |
| ) |
| self.log_cache_layout("Before %s" % slide_tag) |
|
|
| |
| dropped_count = 0 |
| while len(self._unit_history) > cfg.context_max_units: |
| if not self._drop_next_unit_with_context(cfg.context_previous_max_tokens): |
| logger.warning("[SW-CTX] enforce_window_with_context: no more units to drop") |
| break |
|
|
| dropped_count += 1 |
|
|
| cache_len_after = self.get_cache_length() |
|
|
| if dropped_count > 0: |
| |
| self._sliding_event_count += 1 |
| self._total_dropped_tokens += cache_len_before - cache_len_after |
| self._total_dropped_units += dropped_count |
|
|
| |
| expected = self._system_preserve_length + sum(u["length"] for u in self._unit_history) |
| is_consistent = expected == cache_len_after |
| logger.info( |
| "[SW-CTX] ✅ SLIDING DONE: cache %d -> %d, dropped %d units, remaining %d units, " |
| "previous=%d tokens | consistency: %s", |
| cache_len_before, |
| cache_len_after, |
| dropped_count, |
| len(self._unit_history), |
| len(self._previous_token_ids), |
| "✓" if is_consistent else "✗ MISMATCH!", |
| ) |
| self.log_cache_layout("After slide #%d" % self._sliding_event_count) |
|
|
| return dropped_count > 0 |
|
|
| def get_previous_context(self) -> Tuple[str, List[int]]: |
| """获取当前累积的 previous context |
| |
| Returns: |
| (previous_text, previous_token_ids): 当前累积的文本和 token ids |
| """ |
| return self._previous_text, self._previous_token_ids.copy() |
|
|
| |
|
|
| def log_cache_layout(self, tag: str = "") -> None: |
| """打印当前 cache 布局(调试用) |
| |
| 根据滑窗模式显示不同的布局信息: |
| - context 模式:[prefix] [previous] [suffix] [units...] |
| - 其他模式:[system] [units...] |
| """ |
| cache_len = self.get_cache_length() |
| units_len = sum(u["length"] for u in self._unit_history) |
|
|
| if self._window_config.sliding_window_mode == "context": |
| |
| prefix_len = self._preserve_prefix_length |
| prev_len = len(self._previous_token_ids) |
| suffix_len = len(self._suffix_token_ids) |
|
|
| |
| prev_full = "" |
| if prev_len > 0 and self.tokenizer: |
| prev_full = self.tokenizer.decode(self._previous_token_ids) |
| suffix_text = "" |
| if suffix_len > 0 and self.tokenizer: |
| suffix_text = self.tokenizer.decode(self._suffix_token_ids) |
|
|
| logger.info( |
| "[SW-CTX] %s Cache Layout:\n" |
| " [prefix: %d tokens] [previous: %d tokens] [suffix: %d tokens] [units: %d tokens]\n" |
| " preserve=%d | cache=%d | has_previous=%s\n" |
| " previous_full: %s\n" |
| " suffix: %s", |
| tag, |
| prefix_len, |
| prev_len, |
| suffix_len, |
| units_len, |
| self._system_preserve_length, |
| cache_len, |
| self._has_previous, |
| repr(prev_full) if prev_full else "(empty)", |
| repr(suffix_text) if suffix_text else "(empty)", |
| ) |
| else: |
| |
| logger.info( |
| "[SW] %s Cache Layout: [system: %d] [units: %d] | cache=%d", |
| tag, |
| self._system_preserve_length, |
| units_len, |
| cache_len, |
| ) |
|
|
| def get_window_stats(self) -> Dict[str, Any]: |
| """获取滑窗统计信息""" |
| unit_lengths = [u["length"] for u in self._unit_history] |
| return { |
| "cache_length": self.get_cache_length(), |
| "unit_count": len(self._unit_history), |
| "unit_lengths": unit_lengths, |
| "unit_total_length": sum(unit_lengths), |
| "system_preserve_length": self._system_preserve_length, |
| "position_offset": self._position_offset, |
| "window_enabled": self._window_enabled, |
| "total_generated_tokens": self.get_total_generated_tokens(), |
| "pending_unit_id": self._pending_unit_id, |
| "next_unit_id": self._next_unit_id, |
| "config": { |
| "sliding_window_mode": self._window_config.sliding_window_mode, |
| "basic_window_high_tokens": self._window_config.basic_window_high_tokens, |
| "basic_window_low_tokens": self._window_config.basic_window_low_tokens, |
| "context_previous_max_tokens": self._window_config.context_previous_max_tokens, |
| "context_max_units": self._window_config.context_max_units, |
| }, |
| |
| "preserve_prefix_length": self._preserve_prefix_length, |
| "previous_content_length": self._previous_content_length, |
| "suffix_token_count": len(self._suffix_token_ids), |
| "previous_text_length": len(self._previous_text), |
| "previous_token_count": len(self._previous_token_ids), |
| "has_system_template": self._system_prompt_template is not None, |
| } |
|
|
| def _verify_consistency(self) -> bool: |
| """验证 unit 历史与 cache 长度一致""" |
| expected = self._system_preserve_length + sum(u["length"] for u in self._unit_history) |
| actual = self.get_cache_length() |
| return expected == actual |
|
|
| def dump_unit_history(self, prefix: str = "") -> None: |
| """打印当前 unit 历史(调试用)""" |
| cache_len = self.get_cache_length() |
| unit_sum = sum(u["length"] for u in self._unit_history) |
| expected = self._system_preserve_length + unit_sum |
|
|
| logger.info( |
| "[SW] %s=== UNIT HISTORY DUMP === cache=%d, preserve=%d, units=%d, offset=%d", |
| prefix + " " if prefix else "", |
| cache_len, |
| self._system_preserve_length, |
| len(self._unit_history), |
| self._position_offset, |
| ) |
| logger.info( |
| "[SW] Consistency: preserve(%d) + sum(units)(%d) = %d, actual=%d, %s", |
| self._system_preserve_length, |
| unit_sum, |
| expected, |
| cache_len, |
| "✓ MATCH" if expected == cache_len else "✗ MISMATCH!", |
| ) |
| for i, u in enumerate(self._unit_history): |
| gen_count = len(u.get("generated_tokens", [])) |
| logger.info( |
| "[SW] [%d] unit_id=%d type=%-6s len=%4d gen=%3d listen=%s", |
| i, |
| u["unit_id"], |
| u["type"], |
| u["length"], |
| gen_count, |
| u.get("is_listen", False), |
| ) |
|
|
| def print_verification_summary(self) -> Dict[str, Any]: |
| """打印验证摘要(用于对比 off/basic/context 模式) |
| |
| Returns: |
| 包含关键验证数据的字典 |
| """ |
| cfg = self._window_config |
|
|
| |
| all_generated_text = [] |
| all_generated_tokens = [] |
| for u in self._unit_history: |
| if not u.get("is_listen", False): |
| gen_text = u.get("generated_text", "") |
| gen_tokens = u.get("generated_tokens", []) |
| if gen_text: |
| all_generated_text.append(gen_text) |
| if gen_tokens: |
| all_generated_tokens.extend(gen_tokens) |
|
|
| combined_text = "".join(all_generated_text) |
|
|
| summary = { |
| "mode": cfg.sliding_window_mode, |
| "final_cache_length": self.get_cache_length(), |
| "final_unit_count": len(self._unit_history), |
| "sliding_event_count": self._sliding_event_count, |
| "total_dropped_tokens": self._total_dropped_tokens, |
| "total_dropped_units": self._total_dropped_units, |
| "total_generated_tokens": len(all_generated_tokens), |
| "generated_text": combined_text, |
| "previous_text": self._previous_text, |
| "previous_token_count": len(self._previous_token_ids), |
| "position_offset": self._position_offset, |
| "system_preserve_length": self._system_preserve_length, |
| } |
|
|
| logger.info("=" * 70) |
| logger.info("[VERIFY] === SLIDING WINDOW VERIFICATION SUMMARY ===") |
| logger.info("[VERIFY] Mode: %s", cfg.sliding_window_mode) |
| logger.info("[VERIFY] Final cache length: %d", summary["final_cache_length"]) |
| logger.info("[VERIFY] Final unit count: %d", summary["final_unit_count"]) |
| logger.info("[VERIFY] Sliding events: %d", summary["sliding_event_count"]) |
| logger.info( |
| "[VERIFY] Total dropped: %d tokens, %d units", |
| summary["total_dropped_tokens"], |
| summary["total_dropped_units"], |
| ) |
| logger.info("[VERIFY] Total generated tokens: %d", summary["total_generated_tokens"]) |
| logger.info( |
| "[VERIFY] Generated text: '%s'", combined_text[:100] + "..." if len(combined_text) > 100 else combined_text |
| ) |
| if cfg.sliding_window_mode == "context": |
| logger.info( |
| "[VERIFY] Previous content: %d tokens, '%s'", |
| summary["previous_token_count"], |
| self._previous_text[:50] + "..." if len(self._previous_text) > 50 else self._previous_text, |
| ) |
| logger.info("[VERIFY] Position offset: %d", summary["position_offset"]) |
| logger.info("[VERIFY] System preserve length: %d", summary["system_preserve_length"]) |
| logger.info("=" * 70) |
|
|
| return summary |
|
|
| def set_window_config(self, config: DuplexWindowConfig) -> None: |
| """设置滑窗配置""" |
| self._window_config = config |
| logger.info( |
| "[SW] Window config set: mode=%s, high_water=%d, low_water=%d", |
| config.sliding_window_mode, |
| config.basic_window_high_tokens, |
| config.basic_window_low_tokens, |
| ) |
|
|
| def set_window_enabled(self, enabled: bool) -> None: |
| """启用/禁用滑窗""" |
| old_enabled = self._window_enabled |
| self._window_enabled = enabled |
| if old_enabled != enabled: |
| logger.info("[SW] Window enabled: %s -> %s", old_enabled, enabled) |
|
|
| def get_context(self): |
| return self.context |
|
|
| def embed_token(self, tid): |
| if isinstance(tid, int): |
| tid = torch.tensor([tid], device=self.m.device) |
| return self.m.model.embed_tokens(tid) |
|
|
| def embed_tokens(self, token_ids: List[int]) -> torch.Tensor: |
| """批量嵌入多个 tokens |
| |
| Args: |
| token_ids: token id 列表 |
| |
| Returns: |
| embeddings tensor [L, H] |
| """ |
| if not token_ids: |
| return torch.empty(0, self.m.config.hidden_size, device=self.m.device) |
| tids = torch.tensor(token_ids, device=self.m.device) |
| return self.m.model.embed_tokens(tids) |
|
|
| @torch.no_grad() |
| def feed(self, embeds: torch.Tensor, return_logits: bool = False): |
| """ |
| embeds: [L, H] new embedding sequence fed to the model in one pass |
| """ |
| L = embeds.size(0) |
| device = embeds.device |
|
|
| past_len = self.get_cache_length() |
| pos_ids = torch.arange(past_len, past_len + L, device=device).unsqueeze(0) |
|
|
| out = self.m( |
| inputs_embeds=embeds.unsqueeze(0), |
| position_ids=pos_ids, |
| past_key_values=self.cache, |
| |
| return_dict=True, |
| output_hidden_states=True, |
| |
| ) |
| self.cache = out.past_key_values |
|
|
| if return_logits: |
| logits = self.m.lm_head(out.hidden_states[-1])[:, -1] |
| return logits, out.hidden_states[-1] |
|
|
| @torch.no_grad() |
| def decode( |
| self, |
| logits, |
| mode: Literal["sampling", "greedy"] = "sampling", |
| temperature=0.7, |
| top_k=20, |
| top_p=0.8, |
| listen_top_k=None, |
| listen_prob_scale=1.0, |
| text_repetition_penalty=1.05, |
| text_repetition_window_size=512, |
| debug_print_top5=False, |
| ): |
| """ |
| Args: |
| logits: |
| mode: sampling or greedy |
| temperature: |
| top_k: |
| top_p: |
| listen_top_k: force listen_id to be in top-k to keep |
| listen_prob_scale: multiply listen_id probability by a weight (<1 means decrease, >1 means increase) |
| text_repetition_penalty: repetition penalty coefficient, >1.0 means decrease repetition, <1.0 means increase repetition |
| text_repetition_window_size: repetition penalty window size |
| debug_print_top5: whether to print debug information for top 5 tokens |
| |
| Sampling strategy: |
| 1. first sample all tokens with original logits (apply temperature) |
| 2. if sampled chunk_eos, return directly (keep the original model's decision of when to stop) |
| 3. if not sampled chunk_eos, mask it (set logit to -inf), continue sampling text tokens |
| 4. apply repetition penalty, top-k, top-p, etc. to the text tokens for the final sampling |
| """ |
|
|
| logits = logits.clone() |
|
|
| |
| eos_id = self.chunk_eos_id |
|
|
| with torch.no_grad(): |
| if mode == "greedy": |
| sampled_token = torch.argmax(logits[0]).item() |
| else: |
| original_probs = F.softmax(logits[0], dim=-1) |
| sampled_token = torch.multinomial(original_probs, num_samples=1).item() |
|
|
| |
| if sampled_token == eos_id: |
| # Respect the model's own decision about when to stop. |
| next_token_id = torch.tensor([eos_id], device=logits.device) |
| return next_token_id |
|
|
| |
| if self.forbidden_token_ids: |
| logits[:, self.forbidden_token_ids] = float("-inf") |
|
|
| |
| if debug_print_top5: |
| print("🔵" * 30) |
| print("【BEFORE repetition penalty】施加重复惩罚之前的 Top-k logits") |
| logits_before_penalty = logits[0] / temperature if mode == "sampling" else logits[0] |
| topk_logits_before, topk_indices_before = torch.topk( |
| logits_before_penalty, k=min(5, logits_before_penalty.size(-1)) |
| ) |
|
|
| for i, (token_id, logit_val) in enumerate(zip(topk_indices_before.tolist(), topk_logits_before.tolist())): |
| token_str = self.tokenizer.decode([token_id]) |
| |
| if token_str == "\n": |
| display_str = "\\n" |
| elif token_str == " ": |
| display_str = "[SPACE]" |
| elif token_str == "": |
| display_str = "[EMPTY]" |
| elif token_str == "\t": |
| display_str = "\\t" |
| else: |
| display_str = token_str |
|
|
| |
| special_mark = "" |
| if token_id == self.listen_id: |
| special_mark = " 🎧[LISTEN]" |
| elif token_id == self.tokenizer.eos_token_id: |
| special_mark = " 🛑[EOS]" |
|
|
| print(f" {i + 1:2d}. {display_str:10s}{special_mark:15s} (id={token_id:5d}): logit={logit_val:.4f}") |
| print("🔵" * 30) |
|
|
| |
| if text_repetition_penalty != 1.0 and len(self.generated_tokens) > 0: |
| |
| recent_tokens = self.generated_tokens[-text_repetition_window_size:] |
|
|
| |
| recent_tokens = list(set(recent_tokens)) |
|
|
| |
| # CTRL-style penalty: shrink positive logits and amplify negative ones so |
| # that a penalty > 1.0 consistently lowers the probability of repeated |
| # tokens regardless of the logit's sign. |
| for token_id in recent_tokens: |
| if token_id < logits.size(-1): |
| if logits[0, token_id] > 0: |
| logits[0, token_id] /= text_repetition_penalty |
| else: |
| logits[0, token_id] *= text_repetition_penalty |
|
|
| if listen_prob_scale != 1.0: |
| # Adding log(s) to a logit multiplies its softmax probability by s (up to |
| # renormalization), matching the documented semantics for any logit sign. |
| logits[0, self.listen_id] += math.log(listen_prob_scale) |
|
|
| listen_rank = (logits[0] > logits[0, self.listen_id]).sum().item() |
|
|
| |
| if debug_print_top5: |
| |
| logits_before_softmax = logits[0] / temperature if mode == "sampling" else logits[0] |
| top5_logits_before, top5_indices_before = torch.topk( |
| logits_before_softmax, k=min(5, logits_before_softmax.size(-1)) |
| ) |
|
|
| print("=" * 20) |
|
|
| print("\n📊 Top 5 tokens BEFORE softmax (temperature={:.2f}, mode={}):".format(temperature, mode)) |
| for i, (token_id, logit_val) in enumerate(zip(top5_indices_before.tolist(), top5_logits_before.tolist())): |
| token_str = self.tokenizer.decode([token_id]) |
| |
| if token_str == "\n": |
| display_str = "\\n" |
| elif token_str == " ": |
| display_str = "[SPACE]" |
| elif token_str == "": |
| display_str = "[EMPTY]" |
| elif token_str == "\t": |
| display_str = "\\t" |
| else: |
| display_str = token_str |
|
|
| |
| special_mark = "" |
| if token_id == self.listen_id: |
| special_mark = " 🎧[LISTEN]" |
| elif token_id == self.tokenizer.eos_token_id: |
| special_mark = " 🛑[EOS]" |
|
|
| print(f" {i + 1}. {display_str:10s}{special_mark:15s} (id={token_id:5d}): logit={logit_val:.4f}") |
|
|
| |
| probs = F.softmax(logits[0] / temperature if mode == "sampling" else logits[0], dim=-1) |
| top5_probs, top5_indices = torch.topk(probs, k=min(5, probs.size(-1))) |
|
|
| print("\n📊 Top 5 tokens AFTER softmax (temperature={:.2f}, mode={}):".format(temperature, mode)) |
| for i, (token_id, prob) in enumerate(zip(top5_indices.tolist(), top5_probs.tolist())): |
| token_str = self.tokenizer.decode([token_id]) |
| |
| if token_str == "\n": |
| display_str = "\\n" |
| elif token_str == " ": |
| display_str = "[SPACE]" |
| elif token_str == "": |
| display_str = "[EMPTY]" |
| elif token_str == "\t": |
| display_str = "\\t" |
| else: |
| display_str = token_str |
|
|
| |
| special_mark = "" |
| if token_id == self.listen_id: |
| special_mark = " 🎧[LISTEN]" |
| elif token_id == self.tokenizer.eos_token_id: |
| special_mark = " 🛑[EOS]" |
|
|
| print( |
| f" {i + 1}. {display_str:10s}{special_mark:15s} (id={token_id:5d}): {prob:.4f} ({prob * 100:.2f}%)" |
| ) |
| |
| if self.listen_id not in top5_indices.tolist(): |
| listen_prob = probs[self.listen_id].item() |
| print(f" ... <|listen|> 🎧 rank={listen_rank + 1}, prob={listen_prob:.6f} ({listen_prob * 100:.4f}%)") |
|
|
| if listen_top_k is not None and listen_rank < listen_top_k: |
| next_token_id = torch.tensor([self.listen_id], device=logits.device) |
| next_token_str = self.tokenizer.decode(next_token_id) |
|
|
| if next_token_str == "<|listen|>": |
| self.context += " " |
| else: |
| self.context += next_token_str |
|
|
| return next_token_id |
|
|
| if mode == "greedy": |
| next_token_id = torch.argmax(logits, dim=-1) |
| elif mode == "sampling": |
| logits = logits / temperature |
| logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) |
| probs = F.softmax(logits, dim=-1) |
| next_token_id = torch.multinomial(probs, num_samples=1).squeeze(1) |
| else: |
| raise ValueError("Unsupported decode mode") |
|
|
| if next_token_id.item() not in self.special_token_ids: |
| self.generated_tokens.append(next_token_id.item()) |
| else: |
| self.generated_special_tokens.append(next_token_id.item()) |
|
|
| return next_token_id |
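| |
| |
| # Streaming loop sketch (illustrative; `chunk_embeds` is assumed to be the |
| # [L, H] embeddings of one incoming unit): |
| # |
| #     logits, _ = decoder.feed(chunk_embeds, return_logits=True) |
| #     while True: |
| #         tok = decoder.decode(logits, mode="sampling", temperature=0.7) |
| #         if tok.item() == decoder.chunk_eos_id: |
| #             break |
| #         logits, _ = decoder.feed(decoder.embed_token(tok.item()), return_logits=True) |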
|
|