import sys

import sentencepiece as spm
import torch

from model_optimized import MemoryOptimizedBigramLM
|
| |
|
| |
|
# Pick the compute device: use the GPU when one is visible to torch.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"使用设备: {device}")

# Load the SentencePiece tokenizer that the model was trained with;
# its piece count defines the model's vocabulary size.
sp = spm.SentencePieceProcessor()
sp.load("tokenizer.model")
vocab_size = sp.get_piece_size()
print(f"词汇表大小: {vocab_size}")
|
| |
|
| |
|
# Architecture hyperparameters — these must match the checkpoint that is
# loaded below, otherwise the state dict will not fit.
d_model = 512        # embedding / hidden width
max_seq_len = 2048   # maximum context length
h = 8                # attention heads
Nx = 6               # transformer blocks
dropout_rate = 0.2   # inactive at inference (model.eval() below)

# Instantiate the language model with the configuration above.
model = MemoryOptimizedBigramLM(
    vocab_size=vocab_size,
    d_model=d_model,
    max_seq_len=max_seq_len,
    h=h,
    Nx=Nx,
    dropout_rate=dropout_rate,
)
|
| |
|
| |
|
# Restore weights from the most recent training checkpoint; abort the
# script if anything about the load fails.
CHECKPOINT_PATH = "saved_models/gpt_model_enhanced_stop_20251004_181034.pth"
try:
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # acceptable because this checkpoint is produced by our own training
    # run, never from an untrusted source.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=False)

    # Drop any keys containing 'mask' before loading; presumably these are
    # registered buffers whose saved shape may not match the current
    # max_seq_len — TODO confirm against model_optimized.
    state_dict = checkpoint['model_state_dict']
    filtered_state_dict = {k: v for k, v in state_dict.items() if 'mask' not in k}
    model.load_state_dict(filtered_state_dict, strict=False)

    print("✅ 成功加载最新训练模型权重")
    print(f"训练迭代次数: {checkpoint['iteration']}")
    print(f"最终训练损失: {checkpoint['train_losses'][-1]:.4f}")
    print(f"最终验证损失: {checkpoint['valid_losses'][-1]:.4f}")
    print(f"最终训练PPL: {checkpoint['train_ppls'][-1]:.2f}")
    print(f"最终验证PPL: {checkpoint['valid_ppls'][-1]:.2f}")
except Exception as e:
    # Top-level boundary: report the failure and stop the script.
    print(f" 加载模型失败: {e}")
    # Fix: sys.exit(1) instead of the `exit(1)` site-module helper, which
    # is not guaranteed to exist in every runtime.
    sys.exit(1)

model = model.to(device)
model.eval()  # inference mode: disables dropout
|
| |
|
def calculate_repetition_rate(text):
    """Return the fraction of adjacent word pairs in *text* that are identical.

    The text is split on whitespace; every position i where
    ``words[i] == words[i+1]`` counts as one repetition.

    Args:
        text: Any string; words are whitespace-separated tokens.

    Returns:
        float: repeated pairs / total adjacent pairs, in [0.0, 1.0].
        Returns 0.0 when the text has fewer than two words.
    """
    words = text.split()
    if len(words) < 2:
        return 0.0
    # zip(words, words[1:]) yields each adjacent pair once; booleans sum as
    # 0/1, replacing the original manual index loop. The original's trailing
    # `if total_pairs > 0` guard was dead code — the early return above
    # already guarantees at least one pair.
    repeated = sum(a == b for a, b in zip(words, words[1:]))
    return repeated / (len(words) - 1)
|
| |
|
def test_output_optimized(prompt, max_new_tokens=300):
    """Generate a continuation of *prompt* with tuned sampling and report quality.

    Uses the module-level ``model``, ``sp`` (SentencePiece tokenizer) and
    ``device``. Prints the full generation, the extracted continuation, and
    its adjacent-word repetition rate.

    Args:
        prompt: Input prompt string.
        max_new_tokens: Number of tokens to sample after the prompt.

    Returns:
        tuple: (response text, repetition rate as float).
    """
    # Sampling parameters tuned to suppress repetition.
    temperature = 0.8
    top_k = 50
    repetition_penalty = 1.3

    print(f"\n{'='*80}")
    print(f"优化参数: temperature={temperature}, top_k={top_k}, repetition_penalty={repetition_penalty}")
    print(f"输入提示: {prompt}")
    print(f"{'='*80}")

    # Encode the prompt and shape it as a (1, seq_len) batch on the device.
    prompt_tokens = sp.encode(prompt, out_type=int)
    context = torch.tensor([prompt_tokens], dtype=torch.long, device=device)

    # Autoregressive sampling; no gradients are needed at inference.
    with torch.no_grad():
        generated_tokens = model.generate(
            context,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            repetition_penalty=repetition_penalty
        )[0].tolist()

    generated_text = sp.decode(generated_tokens)

    # Strip the echoed prompt to isolate the model's continuation.
    # Fix: the original computed `find(prompt) + len(prompt)` unconditionally,
    # so when decode does not reproduce the prompt verbatim (find() == -1)
    # the slice started at an arbitrary offset. Fall back to the full text.
    prompt_pos = generated_text.find(prompt)
    if prompt_pos >= 0:
        response = generated_text[prompt_pos + len(prompt):].strip()
    else:
        response = generated_text.strip()

    repetition_rate = calculate_repetition_rate(response)

    print(f"完整输出:")
    print(f"{generated_text}")
    print(f"\n提取的响应:")
    print(f"{response}")
    print(f"\n评估指标:")
    print(f" 输出长度: {len(response)} 字符")
    print(f" 重复率: {repetition_rate:.4f}")

    return response, repetition_rate
|
| |
|
| |
|
# Batch evaluation: run every prompt through the tuned generator and
# aggregate the per-sample repetition statistics.
print("开始使用优化参数测试模型输出...")

test_prompts = [
    "关键词: 信 天涯 晚风",
    "关键词: 风 雾 寂寞",
    "关键词: 贴心 改变 自信",
    "关键词: 午夜 寒冬 心动",
    "关键词: 思考 推理 分析",
    "关键词: 月光 思念 远方",
    "关键词: 梦想 坚持 成功",
    "关键词: 春天 希望 新生",
]

total_responses = len(test_prompts)
repetition_rates = []

for i, prompt in enumerate(test_prompts, 1):
    print(f"\n🔬 测试 {i}/{total_responses}")
    response, repetition_rate = test_output_optimized(prompt)
    repetition_rates.append(repetition_rate)

    # Per-sample verdict based on the adjacent-word repetition rate.
    if repetition_rate == 0.0:
        print(f"✅ 输出质量优秀 - 无重复")
    elif repetition_rate < 0.05:
        print(f"✅ 输出质量良好 - 轻微重复")
    elif repetition_rate < 0.1:
        print(f"⚠️ 输出质量一般 - 中等重复")
    else:
        print(f"❌ 输出质量较差 - 严重重复")

# Summary over the whole prompt set.
avg_repetition_rate = sum(repetition_rates) / total_responses

print(f"\n{'='*80}")
print("🎯 最终测试结果总结")
print(f"{'='*80}")
print(f"测试提示数量: {total_responses}")
print(f"平均重复率: {avg_repetition_rate:.4f}")
print(f"最佳参数组合: temperature=0.8, top_k=50, repetition_penalty=1.3")
print(f"生成长度: 300 tokens")

# Overall verdict on the averaged repetition rate.
if avg_repetition_rate == 0.0:
    print(f"🎉 优化成功!所有输出均无重复")
elif avg_repetition_rate < 0.05:
    print(f"✅ 优化效果良好!平均重复率很低")
elif avg_repetition_rate < 0.1:
    print(f"⚠️ 优化效果一般,仍有改进空间")
else:
    print(f"❌ 需要进一步优化")

print(f"\n优化前问题: 大量重复词汇(如'兄弟'、'兄弟姐妹'等)")
print(f"优化后效果: 重复率显著降低,输出多样性提高")
|
| |
|