| """ |
| 推理脚本 |
| 可视化扩散翻译过程 |
| """ |
|
|
| import os |
| import argparse |
| import torch |
| import torch.nn.functional as F |
| from typing import Optional, Tuple, List |
|
|
| from config import Config |
| from tokenizer import Tokenizer |
| from embedding import DualLanguageEmbedding, DualOutputProjection |
| from model import create_model |
| from diffusion import get_diffusion |
| from switcher import create_switcher |
|
|
|
|
class Translator:
    """Chinese <-> English translator built on a diffusion model.

    The source sentence is embedded, noised to the last diffusion timestep,
    and then denoised step by step. During the reverse process the language
    tag passed to the model is switched from the source language to the
    target language halfway through, and the final state is projected onto
    the target vocabulary.
    """

    def __init__(self, config: "Config", checkpoint_path: Optional[str] = None):
        """Build tokenizers, embedding, model, switcher and samplers.

        Args:
            config: Project configuration. Reads ``project_dir``,
                ``data.cache_dir`` and ``model.d_model`` / ``model.max_len``
                here; the whole object is also handed to the factory
                functions.
            checkpoint_path: Checkpoint file to load weights from. When
                falsy, the components keep their freshly constructed weights.
        """
        self.config = config
        self.device = torch.device("cpu")

        # Tokenizers are loaded from the JSON files in the data cache dir.
        cache_dir = os.path.join(config.project_dir, config.data.cache_dir)
        self.zh_tokenizer = Tokenizer.load(os.path.join(cache_dir, "tokenizer_zh.json"))
        self.en_tokenizer = Tokenizer.load(os.path.join(cache_dir, "tokenizer_en.json"))

        # Embedding and output projection are sized from the loaded vocabularies.
        self.embedding = DualLanguageEmbedding(
            vocab_size_zh=self.zh_tokenizer.vocab_size_actual,
            vocab_size_en=self.en_tokenizer.vocab_size_actual,
            d_model=config.model.d_model,
            max_len=config.model.max_len,
            dropout=0.0,
        )

        self.output_proj = DualOutputProjection(
            d_model=config.model.d_model,
            vocab_size_zh=self.zh_tokenizer.vocab_size_actual,
            vocab_size_en=self.en_tokenizer.vocab_size_actual,
        )

        self.model = create_model(config)
        self.switcher = create_switcher(config)

        self.diffusion, self.ddim_sampler = get_diffusion(config)

        if checkpoint_path:
            self._load_checkpoint(checkpoint_path)

    def _load_checkpoint(self, path: str):
        """Load component weights from the checkpoint file at ``path``.

        The checkpoint must contain the keys ``embedding``, ``output_proj``,
        ``model`` and ``switcher``.
        """
        # NOTE(review): weights_only=False unpickles arbitrary Python objects,
        # so only load checkpoints from a trusted source.
        state = torch.load(path, map_location=self.device, weights_only=False)

        self.embedding.load_state_dict(state['embedding'])
        self.output_proj.load_state_dict(state['output_proj'])
        self.model.load_state_dict(state['model'])
        self.switcher.load_state_dict(state['switcher'])

        print(f"已加载检查点: {path}")

    @staticmethod
    def _detect_lang(text: str) -> str:
        """Return "zh" if ``text`` has a char in U+4E00..U+9FFF, else "en"."""
        if any('\u4e00' <= c <= '\u9fff' for c in text):
            return "zh"
        return "en"

    def _tokenizer_for(self, lang: str) -> "Tokenizer":
        """Return the Chinese tokenizer for "zh", the English one otherwise."""
        return self.zh_tokenizer if lang == "zh" else self.en_tokenizer

    def _encode(self, text: str, lang: str) -> torch.Tensor:
        """Tokenize ``text`` (with SOS/EOS) into a long tensor with a leading batch dim of 1."""
        ids = self._tokenizer_for(lang).encode(text, add_sos=True, add_eos=True)
        return torch.tensor(ids, dtype=torch.long).unsqueeze(0)

    def _decode(self, ids: torch.Tensor, lang: str) -> str:
        """Decode the first row of ``ids`` to text, skipping special tokens."""
        token_ids = ids[0].tolist()
        return self._tokenizer_for(lang).decode(token_ids, skip_special=True)

    def _embed_to_tokens(self, x: torch.Tensor, lang: str) -> torch.Tensor:
        """Project ``x`` through the output projection and take the argmax over the last dim."""
        logits = self.output_proj(x, lang)
        return logits.argmax(dim=-1)

    def _print_step(self, t, x_t: torch.Tensor, lang: str) -> None:
        """Print the current state decoded as ``lang`` text, truncated to 50 chars."""
        current_ids = self._embed_to_tokens(x_t, lang)
        current_text = self._decode(current_ids, lang)
        if len(current_text) > 50:
            current_text = current_text[:50] + "..."
        print(f" Step {t:4d} → {current_text}")

    def _finalize(self, x_t: torch.Tensor, target_lang: str, verbose: bool) -> str:
        """Decode the final state as target-language text, printing it if verbose."""
        final_ids = self._embed_to_tokens(x_t, target_lang)
        result = self._decode(final_ids, target_lang)

        if verbose:
            print(f"\n输出: {result}")

        return result

    @torch.no_grad()
    def translate(
        self,
        text: str,
        source_lang: str,
        verbose: bool = True,
        ddim: bool = True,
    ) -> str:
        """Translate ``text`` into the other language.

        Args:
            text: Input text.
            source_lang: Source language, "zh" or "en". Any value other than
                "zh" is tokenized as English and translated to Chinese.
            verbose: Whether to print the diffusion process.
            ddim: Use the DDIM reverse loop when True, otherwise the full
                DDPM loop over every timestep.

        Returns:
            The translated text.
        """
        for module in (self.model, self.embedding, self.output_proj, self.switcher):
            module.eval()

        target_lang = "en" if source_lang == "zh" else "zh"

        if verbose:
            print(f"\n翻译模式: {source_lang.upper()} → {target_lang.upper()}")
            print(f"输入: {text}")
            print("\n扩散过程:")

        # Tokenize and embed the source sentence.
        source_ids = self._encode(text, source_lang)
        source_len = torch.tensor([source_ids.size(1)])
        source_emb = self.embedding(source_ids, source_lang, source_len)

        if verbose:
            print(f" 前向扩散: {source_lang} → 噪声空间")

        # Forward diffusion: noise the source embedding at the last timestep.
        batch_size = source_emb.size(0)
        t_full = torch.full(
            (batch_size,), self.config.diffusion.timesteps - 1, dtype=torch.long
        )
        noise = torch.randn_like(source_emb)
        x_t, _ = self.diffusion.q_sample(source_emb, t_full, noise)

        # Reverse diffusion with the selected sampler.
        reverse = self._ddim_reverse if ddim else self._ddpm_reverse
        return reverse(x_t, source_lang, target_lang, verbose)

    def _ddim_reverse(
        self,
        x_t: torch.Tensor,
        source_lang: str,
        target_lang: str,
        verbose: bool,
    ) -> str:
        """Run the reverse process over the DDIM sampler's timestep schedule.

        The first half of the schedule is run with the source-language tag,
        the second half with the target-language tag.
        """
        timesteps = self.ddim_sampler.ddim_timesteps
        switch_point = len(timesteps) // 2

        # Walk consecutive (t, t_prev) pairs of the schedule.
        for i, (t, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
            current_lang = source_lang if i < switch_point else target_lang

            t_tensor = torch.full((x_t.size(0),), t, dtype=torch.long)
            predicted_noise = self.model(x_t, t_tensor, lang=current_lang)

            if verbose:
                # Preview what the current state decodes to before stepping.
                self._print_step(t, x_t, current_lang)

            x_t = self.ddim_sampler.ddim_step(x_t, t, t_prev, predicted_noise, eta=0.0)

        return self._finalize(x_t, target_lang, verbose)

    def _ddpm_reverse(
        self,
        x_t: torch.Tensor,
        source_lang: str,
        target_lang: str,
        verbose: bool,
    ) -> str:
        """Run the reverse process over every timestep (standard DDPM, slower).

        Timesteps above the midpoint use the source-language tag; the
        midpoint and below use the target-language tag.
        """
        total_steps = self.config.diffusion.timesteps
        switch_point = total_steps // 2

        for t in reversed(range(total_steps)):
            current_lang = source_lang if t > switch_point else target_lang

            t_tensor = torch.full((x_t.size(0),), t, dtype=torch.long)
            predicted_noise = self.model(x_t, t_tensor, lang=current_lang)

            if verbose:
                self._print_step(t, x_t, current_lang)

            x_t = self.diffusion.p_sample(x_t, t_tensor, predicted_noise)

        return self._finalize(x_t, target_lang, verbose)

    def interactive(self):
        """Run a read-translate-print loop on stdin.

        Input prefixed with ``zh:`` or ``en:`` selects the source language
        explicitly; otherwise the language is auto-detected. ``quit``,
        ``exit`` or ``q`` ends the loop, as do Ctrl-C and end-of-input.
        """
        print("\n" + "=" * 50)
        print("Diffutslator 交互翻译模式")
        print("=" * 50)
        print("输入 'zh: 文本' 翻译中文到英文")
        print("输入 'en: text' 翻译英文到中文")
        print("输入 'quit' 或 'exit' 退出")
        print("=" * 50 + "\n")

        while True:
            try:
                user_input = input(">>> ").strip()

                if user_input.lower() in ['quit', 'exit', 'q']:
                    print("再见!")
                    break

                if not user_input:
                    continue

                # An explicit language prefix wins over auto-detection.
                lowered = user_input.lower()
                if lowered.startswith('zh:'):
                    text = user_input[3:].strip()
                    source_lang = "zh"
                elif lowered.startswith('en:'):
                    text = user_input[3:].strip()
                    source_lang = "en"
                else:
                    text = user_input
                    source_lang = self._detect_lang(user_input)

                self.translate(text, source_lang, verbose=True)

            except (KeyboardInterrupt, EOFError):
                # EOFError must exit too: once stdin is exhausted every
                # further input() call raises again, so treating it as an
                # ordinary error would loop forever.
                print("\n再见!")
                break
            except Exception as e:
                # Best-effort REPL: report the failure and keep prompting.
                print(f"错误: {e}")
|
|
|
|
def _find_default_checkpoint(checkpoint_dir: str) -> Optional[str]:
    """Return the checkpoint to use when none is given, or None if there is none.

    Prefers ``best.pt``; otherwise falls back to the most recently modified
    ``*.pt`` file in ``checkpoint_dir``.
    """
    best_path = os.path.join(checkpoint_dir, "best.pt")
    if os.path.exists(best_path):
        return best_path

    # A missing directory simply means nothing has been trained yet.
    if not os.path.isdir(checkpoint_dir):
        return None

    candidates = [
        os.path.join(checkpoint_dir, name)
        for name in os.listdir(checkpoint_dir)
        if name.endswith('.pt')
    ]
    if not candidates:
        return None

    # os.listdir returns entries in arbitrary order, so choose explicitly.
    return max(candidates, key=os.path.getmtime)


def main():
    """Command-line entry point: parse arguments, load a model and translate.

    With ``--text`` (and without ``--interactive``) a single translation is
    run; in every other case the interactive loop is started. If neither
    ``--zh`` nor ``--en`` is given the source language is auto-detected, and
    ``--zh`` wins when both are given.
    """
    parser = argparse.ArgumentParser(description="Diffutslator 推理脚本")

    parser.add_argument("--checkpoint", type=str, default=None, help="检查点路径")
    parser.add_argument("--text", type=str, default=None, help="要翻译的文本")
    parser.add_argument("--zh", action="store_true", help="输入是中文")
    parser.add_argument("--en", action="store_true", help="输入是英文")
    parser.add_argument("--interactive", "-i", action="store_true", help="交互模式")
    parser.add_argument("--quiet", "-q", action="store_true", help="安静模式,不打印过程")
    parser.add_argument("--ddim-steps", type=int, default=50, help="DDIM步数")

    args = parser.parse_args()

    # The step count is set before the Translator is built from this config.
    config = Config()
    config.diffusion.ddim_steps = args.ddim_steps

    checkpoint_path = args.checkpoint
    if checkpoint_path is None:
        checkpoint_dir = os.path.join(config.project_dir, config.training.checkpoint_dir)
        checkpoint_path = _find_default_checkpoint(checkpoint_dir)

    if checkpoint_path is None:
        print("错误: 未找到检查点,请先训练模型")
        return

    translator = Translator(config, checkpoint_path)

    # Interactive mode is both the explicit choice and the default.
    if args.interactive or not args.text:
        translator.interactive()
        return

    if args.zh:
        source_lang = "zh"
    elif args.en:
        source_lang = "en"
    elif any('\u4e00' <= c <= '\u9fff' for c in args.text):
        # Auto-detect: any character in U+4E00..U+9FFF means Chinese input.
        source_lang = "zh"
    else:
        source_lang = "en"

    result = translator.translate(args.text, source_lang, verbose=not args.quiet)
    if args.quiet:
        # In verbose mode translate() has already printed the output.
        print(result)


if __name__ == "__main__":
    main()
|
|