vAIbe_diffutslator / inference.py
forthezero's picture
Upload 28 files
2651102 verified
"""
推理脚本
可视化扩散翻译过程
"""
import os
import argparse
import torch
import torch.nn.functional as F
from typing import Optional, Tuple, List
from config import Config
from tokenizer import Tokenizer
from embedding import DualLanguageEmbedding, DualOutputProjection
from model import create_model
from diffusion import get_diffusion
from switcher import create_switcher
class Translator:
"""翻译器"""
def __init__(self, config: Config, checkpoint_path: Optional[str] = None):
self.config = config
self.device = torch.device("cpu")
# 加载分词器
cache_dir = os.path.join(config.project_dir, config.data.cache_dir)
self.zh_tokenizer = Tokenizer.load(os.path.join(cache_dir, "tokenizer_zh.json"))
self.en_tokenizer = Tokenizer.load(os.path.join(cache_dir, "tokenizer_en.json"))
# 初始化模型组件
self.embedding = DualLanguageEmbedding(
vocab_size_zh=self.zh_tokenizer.vocab_size_actual,
vocab_size_en=self.en_tokenizer.vocab_size_actual,
d_model=config.model.d_model,
max_len=config.model.max_len,
dropout=0.0, # 推理时不使用dropout
)
self.output_proj = DualOutputProjection(
d_model=config.model.d_model,
vocab_size_zh=self.zh_tokenizer.vocab_size_actual,
vocab_size_en=self.en_tokenizer.vocab_size_actual,
)
self.model = create_model(config)
self.switcher = create_switcher(config)
self.diffusion, self.ddim_sampler = get_diffusion(config)
# 加载权重
if checkpoint_path:
self._load_checkpoint(checkpoint_path)
def _load_checkpoint(self, path: str):
"""加载检查点"""
state = torch.load(path, map_location=self.device, weights_only=False)
self.embedding.load_state_dict(state['embedding'])
self.output_proj.load_state_dict(state['output_proj'])
self.model.load_state_dict(state['model'])
self.switcher.load_state_dict(state['switcher'])
print(f"已加载检查点: {path}")
def _encode(self, text: str, lang: str) -> torch.Tensor:
"""编码文本"""
if lang == "zh":
ids = self.zh_tokenizer.encode(text, add_sos=True, add_eos=True)
return torch.tensor(ids, dtype=torch.long).unsqueeze(0)
else:
ids = self.en_tokenizer.encode(text, add_sos=True, add_eos=True)
return torch.tensor(ids, dtype=torch.long).unsqueeze(0)
def _decode(self, ids: torch.Tensor, lang: str) -> str:
"""解码为文本"""
ids = ids[0].tolist()
if lang == "zh":
return self.zh_tokenizer.decode(ids, skip_special=True)
else:
return self.en_tokenizer.decode(ids, skip_special=True)
def _embed_to_tokens(self, x: torch.Tensor, lang: str) -> torch.Tensor:
"""从嵌入空间解码到token"""
logits = self.output_proj(x, lang)
ids = logits.argmax(dim=-1)
return ids
@torch.no_grad()
def translate(
self,
text: str,
source_lang: str,
verbose: bool = True,
ddim: bool = True,
) -> str:
"""翻译文本
Args:
text: 输入文本
source_lang: 源语言 "zh" 或 "en"
verbose: 是否打印扩散过程
ddim: 是否使用DDIM加速
Returns:
翻译结果
"""
self.model.eval()
self.embedding.eval()
self.output_proj.eval()
self.switcher.eval()
target_lang = "en" if source_lang == "zh" else "zh"
if verbose:
print(f"\n翻译模式: {source_lang.upper()}{target_lang.upper()}")
print(f"输入: {text}")
print(f"\n扩散过程:")
# 编码源语言
source_ids = self._encode(text, source_lang)
source_len = torch.tensor([source_ids.size(1)])
# 嵌入源语言
source_emb = self.embedding(source_ids, source_lang, source_len)
# 完整前向扩散到纯噪声
if verbose:
print(f" 前向扩散: {source_lang} → 噪声空间")
batch_size = source_emb.size(0)
t_full = torch.full((batch_size,), self.config.diffusion.timesteps - 1, dtype=torch.long)
noise = torch.randn_like(source_emb)
x_t, _ = self.diffusion.q_sample(source_emb, t_full, noise)
# DDIM反向扩散
if ddim:
result = self._ddim_reverse(
x_t, source_lang, target_lang, verbose
)
else:
result = self._ddpm_reverse(
x_t, source_lang, target_lang, verbose
)
return result
def _ddim_reverse(
self,
x_t: torch.Tensor,
source_lang: str,
target_lang: str,
verbose: bool,
) -> str:
"""DDIM反向扩散"""
ddim_steps = self.config.diffusion.ddim_steps
timesteps = self.ddim_sampler.ddim_timesteps
total_steps = len(timesteps)
switch_point = total_steps // 2 # 在中间切换语言
for i, t in enumerate(timesteps[:-1]):
t_prev = timesteps[i + 1]
# 根据进度决定用哪种语言去噪和显示
# 前半段:源语言,后半段:目标语言
if i < switch_point:
current_lang = source_lang
else:
current_lang = target_lang
# 预测噪声
t_tensor = torch.full((x_t.size(0),), t, dtype=torch.long)
predicted_noise = self.model(x_t, t_tensor, lang=current_lang)
if verbose:
# 显示当前语言的解码结果
current_ids = self._embed_to_tokens(x_t, current_lang)
current_text = self._decode(current_ids, current_lang)
if len(current_text) > 50:
current_text = current_text[:50] + "..."
print(f" Step {t:4d}{current_text}")
# DDIM步骤
x_t = self.ddim_sampler.ddim_step(x_t, t, t_prev, predicted_noise, eta=0.0)
# 最终解码
final_ids = self._embed_to_tokens(x_t, target_lang)
result = self._decode(final_ids, target_lang)
if verbose:
print(f"\n输出: {result}")
return result
def _ddpm_reverse(
self,
x_t: torch.Tensor,
source_lang: str,
target_lang: str,
verbose: bool,
) -> str:
"""DDPM反向扩散(标准方法,较慢)"""
total_steps = self.config.diffusion.timesteps
switch_point = total_steps // 2 # 在中间切换语言
for t in range(total_steps - 1, -1, -1):
# 根据时间步决定用哪种语言
if t > switch_point:
current_lang = source_lang
else:
current_lang = target_lang
t_tensor = torch.full((x_t.size(0),), t, dtype=torch.long)
# 预测噪声
predicted_noise = self.model(x_t, t_tensor, lang=current_lang)
if verbose:
current_ids = self._embed_to_tokens(x_t, current_lang)
current_text = self._decode(current_ids, current_lang)
if len(current_text) > 50:
current_text = current_text[:50] + "..."
print(f" Step {t:4d}{current_text}")
# DDPM步骤
x_t = self.diffusion.p_sample(x_t, t_tensor, predicted_noise)
# 解码
final_ids = self._embed_to_tokens(x_t, target_lang)
result = self._decode(final_ids, target_lang)
if verbose:
print(f"\n输出: {result}")
return result
def interactive(self):
"""交互模式"""
print("\n" + "=" * 50)
print("Diffutslator 交互翻译模式")
print("=" * 50)
print("输入 'zh: 文本' 翻译中文到英文")
print("输入 'en: text' 翻译英文到中文")
print("输入 'quit' 或 'exit' 退出")
print("=" * 50 + "\n")
while True:
try:
user_input = input(">>> ").strip()
if user_input.lower() in ['quit', 'exit', 'q']:
print("再见!")
break
if not user_input:
continue
# 解析输入
if user_input.lower().startswith('zh:'):
text = user_input[3:].strip()
source_lang = "zh"
elif user_input.lower().startswith('en:'):
text = user_input[3:].strip()
source_lang = "en"
else:
# 自动检测(简单判断)
if any('\u4e00' <= c <= '\u9fff' for c in user_input):
text = user_input
source_lang = "zh"
else:
text = user_input
source_lang = "en"
# 翻译
result = self.translate(text, source_lang, verbose=True)
except KeyboardInterrupt:
print("\n再见!")
break
except Exception as e:
print(f"错误: {e}")
def main():
parser = argparse.ArgumentParser(description="Diffutslator 推理脚本")
parser.add_argument("--checkpoint", type=str, default=None, help="检查点路径")
parser.add_argument("--text", type=str, default=None, help="要翻译的文本")
parser.add_argument("--zh", action="store_true", help="输入是中文")
parser.add_argument("--en", action="store_true", help="输入是英文")
parser.add_argument("--interactive", "-i", action="store_true", help="交互模式")
parser.add_argument("--quiet", "-q", action="store_true", help="安静模式,不打印过程")
parser.add_argument("--ddim-steps", type=int, default=50, help="DDIM步数")
args = parser.parse_args()
# 配置
config = Config()
config.diffusion.ddim_steps = args.ddim_steps
# 找检查点
checkpoint_path = args.checkpoint
if checkpoint_path is None:
checkpoint_dir = os.path.join(config.project_dir, config.training.checkpoint_dir)
best_path = os.path.join(checkpoint_dir, "best.pt")
if os.path.exists(best_path):
checkpoint_path = best_path
else:
# 找最新的检查点
checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pt')]
if checkpoints:
checkpoint_path = os.path.join(checkpoint_dir, checkpoints[-1])
if checkpoint_path is None:
print("错误: 未找到检查点,请先训练模型")
return
# 创建翻译器
translator = Translator(config, checkpoint_path)
# 模式
if args.interactive:
translator.interactive()
elif args.text:
if args.zh:
source_lang = "zh"
elif args.en:
source_lang = "en"
else:
# 自动检测
if any('\u4e00' <= c <= '\u9fff' for c in args.text):
source_lang = "zh"
else:
source_lang = "en"
result = translator.translate(args.text, source_lang, verbose=not args.quiet)
if args.quiet:
print(result)
else:
# 默认交互模式
translator.interactive()
if __name__ == "__main__":
main()