Eliot0110 commited on
Commit
502ec94
·
1 Parent(s): ddc2802

improve: re+llm

Browse files
Files changed (1) hide show
  1. modules/info_extractor.py +249 -28
modules/info_extractor.py CHANGED
@@ -1,7 +1,7 @@
1
  import json
2
  import re
3
  from utils.logger import log
4
- from .ai_model import AIModel
5
 
6
  class InfoExtractor:
7
  def __init__(self, ai_model: AIModel):
@@ -15,29 +15,139 @@ class InfoExtractor:
15
 
16
  def extract(self, user_message: str) -> dict:
17
  """从用户消息中提取结构化信息,确保使用确定性解码。"""
 
 
 
 
 
 
 
18
  prompt = self._build_prompt_template(user_message)
19
 
20
  # --- 核心修复:强制使用确定性解码以杜绝幻觉 ---
21
- # 确保调用AI模型时,使用类似 do_sample=False 或 temperature=0 的参数
22
- # 这里我们模拟这个调用,并强调其重要性
23
  log.info("🧠 使用LLM开始提取信息 (模式: 确定性)")
 
 
 
24
  raw_response = self.ai_model.run_inference(
25
- input_type='text',
26
  formatted_input=None,
27
  prompt=prompt,
28
- temperature=0.0
29
  )
30
 
31
  try:
32
- extracted_json = json.loads(raw_response)
 
 
33
  log.info(f"✅ LLM成功提取并解析JSON: {extracted_json}")
 
34
  # 使用新的验证方法
35
  validated_data = self._validate_and_normalize(extracted_json)
36
  log.info(f"📊 LLM最终提取结果 (安全处理后): {validated_data}")
37
  return validated_data
 
38
  except (json.JSONDecodeError, TypeError) as e:
39
  log.error(f"❌ 解析或验证LLM提取的JSON失败: {e}", exc_info=True)
40
- return {} # 返回一个空字典而不是列表
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  def _validate_and_normalize(self, data: dict) -> dict:
43
  """
@@ -48,40 +158,108 @@ class InfoExtractor:
48
  return {}
49
 
50
  validated_output = {}
51
- for key, schema in self.extraction_schema.items():
52
- if key in data and isinstance(data[key], schema["type"]):
53
- # 这里可以添加更深层次的字段类型验证
54
- # 为简化,我们暂时只验证第一层
55
- validated_output[key] = data[key]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  if not validated_output:
58
  log.warning(f"⚠️ 提取的数据 {data} 未通过验证,未发现任何有效字段。")
59
 
60
  return validated_output
61
 
62
- def _build_prompt_template(self,user_message: str) -> str:
63
- # --- 重点更新:使用更严格的指令和结构化示例 ---
 
 
 
 
 
64
  return f"""你的任务是且仅是作为文本解析器。
65
  严格分析用户输入,并以一个纯净、无注释的JSON对象格式返回。
66
 
67
  **核心规则:**
68
- 1. **绝对禁止** 在JSON之外添加任何文本、注释、解释或Markdown标记。你的输出必须从 `{{` 开始,到 `}}` 结束。
69
- 2. **必须严格遵守** 下方定义的嵌套JSON结构。不要创造新的键,也不要改变层级。
70
- 3. 如果信息未提供,对应的键值必须为 `null`,而不是省略该键。
71
- 4. 如果用户输入与旅行无关(如 "你好"),必须返回一个空的JSON对象: `{{}}`。
72
 
73
  **强制JSON输出结构:**
74
  {{
75
  "destination": {{
76
- "name": "string or null"
 
77
  }},
78
  "duration": {{
79
- "days": "integer or null"
 
80
  }},
81
  "budget": {{
82
  "type": "string ('economy', 'comfortable', 'luxury') or null",
83
  "amount": "number or null",
84
- "currency": "string or null"
 
85
  }}
86
  }}
87
 
@@ -106,12 +284,12 @@ class InfoExtractor:
106
  }}
107
 
108
  **示例2:**
109
- 用户输入: "计划去西班牙巴塞罗那旅行一周,预算2万元"
110
  你的输出:
111
  {{
112
  "destination": {{
113
- "name": "巴塞罗那",
114
- "country": "西班牙"
115
  }},
116
  "duration": {{
117
  "days": 7,
@@ -125,11 +303,54 @@ class InfoExtractor:
125
  }}
126
  }}
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  ---
129
  **用户输入:**
130
- `{user_message}`
131
 
132
  **你的输出 (必须是纯JSON):**
133
- """
134
-
135
-
 
1
  import json
2
  import re
3
  from utils.logger import log
4
+ from .ai_model import AIModel
5
 
6
  class InfoExtractor:
7
  def __init__(self, ai_model: AIModel):
 
15
 
16
  def extract(self, user_message: str) -> dict:
17
  """从用户消息中提取结构化信息,确保使用确定性解码。"""
18
+
19
+ # 输入验证
20
+ if not user_message or not isinstance(user_message, str):
21
+ log.warning("⚠️ 收到无效的用户消息")
22
+ return {}
23
+
24
+ # 构建prompt
25
  prompt = self._build_prompt_template(user_message)
26
 
27
  # --- 核心修复:强制使用确定性解码以杜绝幻觉 ---
 
 
28
  log.info("🧠 使用LLM开始提取信息 (模式: 确定性)")
29
+
30
+ # 注意:ai_model.generate() 方法不支持 do_sample 和 temperature 参数
31
+ # 需要通过其他方式确保确定性输出
32
  raw_response = self.ai_model.run_inference(
33
+ input_type="text",
34
  formatted_input=None,
35
  prompt=prompt,
36
+ temperature=0.0 # 使用最低温度确保确定性
37
  )
38
 
39
  try:
40
+ # 清理响应,提取纯JSON部分
41
+ cleaned_response = self._clean_json_response(raw_response)
42
+ extracted_json = json.loads(cleaned_response)
43
  log.info(f"✅ LLM成功提取并解析JSON: {extracted_json}")
44
+
45
  # 使用新的验证方法
46
  validated_data = self._validate_and_normalize(extracted_json)
47
  log.info(f"📊 LLM最终提取结果 (安全处理后): {validated_data}")
48
  return validated_data
49
+
50
  except (json.JSONDecodeError, TypeError) as e:
51
  log.error(f"❌ 解析或验证LLM提取的JSON失败: {e}", exc_info=True)
52
+ log.debug(f"🔍 原始响应: {raw_response}")
53
+ # 尝试备用提取方法
54
+ return self._fallback_extraction(user_message)
55
+
56
+ def _clean_json_response(self, response: str) -> str:
57
+ """清理LLM响应,提取纯JSON部分"""
58
+ if not response:
59
+ return "{}"
60
+
61
+ # 移除可能的markdown代码块标记
62
+ response = re.sub(r'```json\s*', '', response)
63
+ response = re.sub(r'```\s*', '', response)
64
+
65
+ # 移除可能的前导文字
66
+ response = re.sub(r'^[^{]*', '', response)
67
+
68
+ # 查找第一个{和最后一个}
69
+ start_idx = response.find('{')
70
+ end_idx = response.rfind('}')
71
+
72
+ if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
73
+ return response[start_idx:end_idx+1].strip()
74
+
75
+ # 如果找不到有效的JSON结构,返回空对象
76
+ return "{}"
77
+
78
+ def _fallback_extraction(self, user_message: str) -> dict:
79
+ """基于规则的备用信息提取"""
80
+ log.info("🔄 使用基于规则的备用提取方法")
81
+
82
+ result = {}
83
+ message_lower = user_message.lower()
84
+
85
+ # 目的地提取 - 更全面的模式
86
+ city_patterns = [
87
+ r'去(\w+)', r'到(\w+)', r'想去(\w+)', r'前往(\w+)',
88
+ r'旅行(\w+)', r'游(\w+)', r'玩(\w+)', r'访问(\w+)',
89
+ r'目的地[\s是::]*(\w+)', r'地方[\s是::]*(\w+)'
90
+ ]
91
+
92
+ for pattern in city_patterns:
93
+ match = re.search(pattern, user_message)
94
+ if match:
95
+ city_name = match.group(1)
96
+ if len(city_name) >= 2 and not city_name.isdigit():
97
+ result["destination"] = {"name": city_name}
98
+ break
99
+
100
+ # 天数提取 - 更全面的模式
101
+ day_patterns = [
102
+ r'(\d+)天', r'(\d+)日', r'玩(\d+)天', r'住(\d+)天',
103
+ r'(\d+)个天', r'呆(\d+)天', r'待(\d+)天', r'(\d+)天行程'
104
+ ]
105
+
106
+ for pattern in day_patterns:
107
+ match = re.search(pattern, user_message)
108
+ if match:
109
+ days = int(match.group(1))
110
+ if 1 <= days <= 365: # 合理范围检查
111
+ result["duration"] = {"days": days}
112
+ break
113
+
114
+ # 预算提取 - 更全面的模式
115
+ budget_patterns = [
116
+ r'(\d+)元', r'(\d+)块', r'预算(\d+)', r'(\d+)rmb',
117
+ r'(\d+)人民币', r'花(\d+)', r'费用(\d+)', r'(\d+)万'
118
+ ]
119
+
120
+ for pattern in budget_patterns:
121
+ match = re.search(pattern, user_message)
122
+ if match:
123
+ amount = int(match.group(1))
124
+ # 处理"万"的情况
125
+ if '万' in pattern:
126
+ amount *= 10000
127
+ result["budget"] = {
128
+ "type": None,
129
+ "amount": amount,
130
+ "currency": "RMB"
131
+ }
132
+ break
133
+
134
+ # 预算类型识别
135
+ budget_type_keywords = {
136
+ 'economy': ['经济', '便宜', '省钱', '实惠', '节省'],
137
+ 'comfortable': ['舒适', '中等', '适中', '一般'],
138
+ 'luxury': ['豪华', '奢华', '高端', '贵一点', '不差钱']
139
+ }
140
+
141
+ for budget_type, keywords in budget_type_keywords.items():
142
+ if any(keyword in message_lower for keyword in keywords):
143
+ if "budget" not in result:
144
+ result["budget"] = {"type": budget_type, "amount": None, "currency": None}
145
+ else:
146
+ result["budget"]["type"] = budget_type
147
+ break
148
+
149
+ log.info(f"🛠️ 备用提取结果: {result}")
150
+ return result
151
 
152
  def _validate_and_normalize(self, data: dict) -> dict:
153
  """
 
158
  return {}
159
 
160
  validated_output = {}
161
+
162
+ # 验证destination
163
+ if "destination" in data:
164
+ dest_data = data["destination"]
165
+ if isinstance(dest_data, dict):
166
+ validated_dest = {}
167
+ if "name" in dest_data and isinstance(dest_data["name"], str):
168
+ name = dest_data["name"].strip()
169
+ if name:
170
+ validated_dest["name"] = name
171
+ if "country" in dest_data and isinstance(dest_data["country"], str):
172
+ country = dest_data["country"].strip()
173
+ if country:
174
+ validated_dest["country"] = country
175
+ if validated_dest:
176
+ validated_output["destination"] = validated_dest
177
+
178
+ # 验证duration
179
+ if "duration" in data:
180
+ duration_data = data["duration"]
181
+ if isinstance(duration_data, dict):
182
+ validated_duration = {}
183
+ if "days" in duration_data:
184
+ days = duration_data["days"]
185
+ if isinstance(days, (int, float)) and 1 <= days <= 365:
186
+ validated_duration["days"] = int(days)
187
+ if "description" in duration_data and isinstance(duration_data["description"], str):
188
+ desc = duration_data["description"].strip()
189
+ if desc:
190
+ validated_duration["description"] = desc
191
+ if validated_duration:
192
+ validated_output["duration"] = validated_duration
193
+
194
+ # 验证budget
195
+ if "budget" in data:
196
+ budget_data = data["budget"]
197
+ if isinstance(budget_data, dict):
198
+ validated_budget = {}
199
+
200
+ # 验证type
201
+ if "type" in budget_data:
202
+ budget_type = budget_data["type"]
203
+ if budget_type in ["economy", "comfortable", "luxury"]:
204
+ validated_budget["type"] = budget_type
205
+
206
+ # 验证amount
207
+ if "amount" in budget_data:
208
+ amount = budget_data["amount"]
209
+ if isinstance(amount, (int, float)) and amount > 0:
210
+ validated_budget["amount"] = int(amount)
211
+
212
+ # 验证currency
213
+ if "currency" in budget_data and isinstance(budget_data["currency"], str):
214
+ currency = budget_data["currency"].strip()
215
+ if currency:
216
+ validated_budget["currency"] = currency
217
+
218
+ # 验证description
219
+ if "description" in budget_data and isinstance(budget_data["description"], str):
220
+ desc = budget_data["description"].strip()
221
+ if desc:
222
+ validated_budget["description"] = desc
223
+
224
+ if validated_budget:
225
+ validated_output["budget"] = validated_budget
226
 
227
  if not validated_output:
228
  log.warning(f"⚠️ 提取的数据 {data} 未通过验证,未发现任何有效字段。")
229
 
230
  return validated_output
231
 
232
+ def _build_prompt_template(self, user_message: str) -> str:
233
+ """构建包含多个示例的提取prompt"""
234
+
235
+ # 输��长度控制
236
+ if len(user_message) > 300:
237
+ user_message = user_message[:300] + "..."
238
+
239
  return f"""你的任务是且仅是作为文本解析器。
240
  严格分析用户输入,并以一个纯净、无注释的JSON对象格式返回。
241
 
242
  **核心规则:**
243
+ 1. **绝对禁止** 在JSON之外添加任何文本、注释、解释或Markdown标记。你的输出必须从 `{{` 开始,到 `}}` 结束。
244
+ 2. **必须严格遵守** 下方定义的嵌套JSON结构。不要创造新的键,也不要改变层级。
245
+ 3. 如果信息未提供,对应的键值必须为 `null`,而不是省略该键。
246
+ 4. 如果用户输入与旅行无关(如 "你好"),必须返回一个空的JSON对象: `{{}}`。
247
 
248
  **强制JSON输出结构:**
249
  {{
250
  "destination": {{
251
+ "name": "string or null",
252
+ "country": "string or null"
253
  }},
254
  "duration": {{
255
+ "days": "integer or null",
256
+ "description": "string or null"
257
  }},
258
  "budget": {{
259
  "type": "string ('economy', 'comfortable', 'luxury') or null",
260
  "amount": "number or null",
261
+ "currency": "string or null",
262
+ "description": "string or null"
263
  }}
264
  }}
265
 
 
284
  }}
285
 
286
  **示例2:**
287
+ 用户输入: "计划去日本东京旅行一周,预算2万元"
288
  你的输出:
289
  {{
290
  "destination": {{
291
+ "name": "东京",
292
+ "country": "日本"
293
  }},
294
  "duration": {{
295
  "days": 7,
 
303
  }}
304
  }}
305
 
306
+ **示例3:**
307
+ 用户输入: "想要一个经济实惠的巴黎5天行程"
308
+ 你的输出:
309
+ {{
310
+ "destination": {{
311
+ "name": "巴黎",
312
+ "country": null
313
+ }},
314
+ "duration": {{
315
+ "days": 5,
316
+ "description": null
317
+ }},
318
+ "budget": {{
319
+ "type": "economy",
320
+ "amount": null,
321
+ "currency": null,
322
+ "description": "经济实惠"
323
+ }}
324
+ }}
325
+
326
+ **示例4:**
327
+ 用户输入: "你好"
328
+ 你的输出:
329
+ {{}}
330
+
331
+ **示例5:**
332
+ 用户输入: "想去泰国普吉岛度蜜月,10天左右,豪华一点不差钱"
333
+ 你的输出:
334
+ {{
335
+ "destination": {{
336
+ "name": "普吉岛",
337
+ "country": "泰国"
338
+ }},
339
+ "duration": {{
340
+ "days": 10,
341
+ "description": "10天左右"
342
+ }},
343
+ "budget": {{
344
+ "type": "luxury",
345
+ "amount": null,
346
+ "currency": null,
347
+ "description": "豪华一点不差钱"
348
+ }}
349
+ }}
350
+
351
  ---
352
  **用户输入:**
353
+ {user_message}
354
 
355
  **你的输出 (必须是纯JSON):**
356
+ """