Eliot0110 commited on
Commit
240c11f
·
1 Parent(s): ae5cfb9

improve: re

Browse files
Files changed (1) hide show
  1. modules/info_extractor.py +63 -26
modules/info_extractor.py CHANGED
@@ -1,43 +1,72 @@
1
  import json
 
2
  from utils.logger import log
3
  from .ai_model import AIModel
4
 
5
  class InfoExtractor:
6
  def __init__(self, ai_model):
7
-
8
  self.ai_model = ai_model
9
  self.prompt_template = self._build_prompt_template()
10
 
11
  def _build_prompt_template(self) -> str:
 
 
 
12
 
13
- return """你是一个专业的旅游信息提取AI。
14
- 你的任务是仔细阅读用户的请求,并从中提取出关键的旅行信息。
 
 
 
 
15
 
16
- 请严格按照以下嵌套的JSON格式返回
17
- ---
18
- **重要规则**
19
- 1. 如果某个信息在用户请求中没有明确提及,请将对应的值设为 null。
20
- 2. **如果用户的请求只是简单的问候 (例如 "hi", "你好"),或者完全不包含任何目的地、时间、预算等旅行信息,请必须返回一个空的JSON对象,即 `{{}}`。**
21
- ---
 
 
 
 
 
 
 
 
 
 
 
 
22
 
 
 
 
23
  {{
24
  "destination": {{
25
- "name": "提取出的目的地名称"
26
  }},
27
  "duration": {{
28
- "days": "提取出的天数 (必须是整数)"
29
  }},
30
  "budget": {{
31
- "type": "提取出的预算类型 (从 'economy', 'comfortable', 'luxury' 中选择一个)",
32
- "amount": "提取出的具体预算金额 (必须是数字)",
33
- "currency": "提取出的货币单位 (例如 'EUR', 'USD', 'CNY')"
34
  }}
35
  }}
 
36
 
37
- 用户的输入是:
38
  ---
 
 
 
 
 
39
  {user_message}
40
- ---
 
 
41
  """
42
 
43
  def extract(self, message: str) -> dict:
@@ -50,29 +79,37 @@ class InfoExtractor:
50
  prompt = self.prompt_template.format(user_message=message)
51
 
52
  # 2. 调用AI模型生成结果
53
- # 注意:这里假设你的ai_model有一个 .generate() 方法
54
  raw_response = self.ai_model.generate(prompt)
55
 
56
  if not raw_response:
57
  log.error("❌ LLM模型没有返回任何内容。")
58
  return {}
59
 
60
- # 3. 解析LLM返回的JSON字符串
61
  try:
62
- # 清理可能的Markdown代码块标记
63
- clean_response = raw_response.strip().replace('```json', '').replace('```', '')
64
- extracted_data = json.loads(clean_response)
 
 
 
 
 
 
 
 
 
 
 
65
  log.info(f"✅ LLM成功提取并解析JSON: {extracted_data}")
66
- except json.JSONDecodeError:
67
- log.error(f"❌ 无法解析LLM返回的JSON: '{raw_response}'")
68
- # 在这里可以尝试用正则等方式做最后的补救,但暂时从简
69
  return {}
70
 
71
  # 4. 清理和格式化提取出的数据
72
- # 移除值为null的顶级键
73
  final_info = {
74
  key: value for key, value in extracted_data.items() if value and any(v is not None for v in value.values())
75
  }
76
 
77
- log.info(f"📋 LLM最终提取结果: {list(final_info.keys())}")
78
  return final_info
 
1
  import json
2
+ import re # 导入正则表达式模块
3
  from utils.logger import log
4
  from .ai_model import AIModel
5
 
6
  class InfoExtractor:
7
  def __init__(self, ai_model):
 
8
  self.ai_model = ai_model
9
  self.prompt_template = self._build_prompt_template()
10
 
11
  def _build_prompt_template(self) -> str:
12
+ # --- 重点更新:使用更强大、更明确的Prompt ---
13
+ return """你是一个专门用于从文本中提取结构化旅行信息的AI助理。
14
+ 你的唯一任务是分析用户提供的文本,并严格按照指定的JSON格式输出提取的信息。
15
 
16
+ **输出要求:**
17
+ 1. **严格的JSON格式**: 输出必须是一个单一、完整、有效的JSON对象。
18
+ 2. **禁止任何额外文本**: 不要在JSON对象前后添加任何解释、注释、Markdown标记或任何其他文字。
19
+ 3. **遵循指定结构**: JSON的键和层级结构必须与下方定义的格式完全一致。
20
+ 4. **处理缺失信息**: 如果用户输入中没有提到某个字段,请将该字段的值设为 null。
21
+ 5. **处理无关输入**: 如果用户输入是简单的问候或与旅行无关,请返回一个空的JSON对象 `{{}}`。
22
 
23
+ **JSON输出格式定义:**
24
+ ```json
25
+ {{
26
+ "destination": {{
27
+ "name": "string or null"
28
+ }},
29
+ "duration": {{
30
+ "days": "integer or null"
31
+ }},
32
+ "budget": {{
33
+ "type": "string ('economy', 'comfortable', 'luxury') or null",
34
+ "amount": "number or null",
35
+ "currency": "string or null"
36
+ }}
37
+ }}
38
+ ```
39
+
40
+ **示例:**
41
 
42
+ 用户输入: "我想去巴黎玩一个星期,预算大概是经济型的"
43
+ 你的输出:
44
+ ```json
45
  {{
46
  "destination": {{
47
+ "name": "巴黎"
48
  }},
49
  "duration": {{
50
+ "days": 7
51
  }},
52
  "budget": {{
53
+ "type": "economy",
54
+ "amount": null,
55
+ "currency": null
56
  }}
57
  }}
58
+ ```
59
 
 
60
  ---
61
+
62
+ 现在,请处理以下用户输入。
63
+
64
+ **用户输入:**
65
+ ```
66
  {user_message}
67
+ ```
68
+
69
+ **你的输出:**
70
  """
71
 
72
  def extract(self, message: str) -> dict:
 
79
  prompt = self.prompt_template.format(user_message=message)
80
 
81
  # 2. 调用AI模型生成结果
 
82
  raw_response = self.ai_model.generate(prompt)
83
 
84
  if not raw_response:
85
  log.error("❌ LLM模型没有返回任何内容。")
86
  return {}
87
 
88
+ # --- 重点更新:使用更稳健的JSON解析逻辑 ---
89
  try:
90
+ # 优先使用正则表达式从 ```json ... ``` 代码块中提取
91
+ match = re.search(r'```json\s*(\{.*?\})\s*```', raw_response, re.DOTALL)
92
+ if match:
93
+ json_str = match.group(1)
94
+ else:
95
+ # 如果正则没匹配到,就粗暴地寻找第一个'{'和最后一个'}'
96
+ start_index = raw_response.find('{')
97
+ end_index = raw_response.rfind('}')
98
+ if start_index != -1 and end_index != -1 and end_index > start_index:
99
+ json_str = raw_response[start_index:end_index + 1]
100
+ else:
101
+ raise json.JSONDecodeError("在LLM的返回中未找到有效的JSON对象。", raw_response, 0)
102
+
103
+ extracted_data = json.loads(json_str)
104
  log.info(f"✅ LLM成功提取并解析JSON: {extracted_data}")
105
+ except json.JSONDecodeError as e:
106
+ log.error(f"❌ 无法解析LLM返回的JSON: '{raw_response}'. 错误: {e}")
 
107
  return {}
108
 
109
  # 4. 清理和格式化提取出的数据
 
110
  final_info = {
111
  key: value for key, value in extracted_data.items() if value and any(v is not None for v in value.values())
112
  }
113
 
114
+ log.info(f"📊 LLM最终提取结果: {list(final_info.keys())}")
115
  return final_info