yangdx commited on
Commit
6baf4a8
·
1 Parent(s): 99a3d9e

Translate unit test comments and prompts to English

Browse files
Files changed (1) hide show
  1. test_lightrag_ollama_chat.py +161 -182
test_lightrag_ollama_chat.py CHANGED
@@ -19,7 +19,7 @@ from datetime import datetime
19
  from pathlib import Path
20
 
21
  class OutputControl:
22
- """输出控制类,管理测试输出的详细程度"""
23
  _verbose: bool = False
24
 
25
  @classmethod
@@ -32,7 +32,7 @@ class OutputControl:
32
 
33
  @dataclass
34
  class TestResult:
35
- """测试结果数据类"""
36
  name: str
37
  success: bool
38
  duration: float
@@ -44,7 +44,7 @@ class TestResult:
44
  self.timestamp = datetime.now().isoformat()
45
 
46
  class TestStats:
47
- """测试统计信息"""
48
  def __init__(self):
49
  self.results: List[TestResult] = []
50
  self.start_time = datetime.now()
@@ -53,10 +53,9 @@ class TestStats:
53
  self.results.append(result)
54
 
55
  def export_results(self, path: str = "test_results.json"):
56
- """导出测试结果到 JSON 文件
57
-
58
  Args:
59
- path: 输出文件路径
60
  """
61
  results_data = {
62
  "start_time": self.start_time.isoformat(),
@@ -72,7 +71,7 @@ class TestStats:
72
 
73
  with open(path, "w", encoding="utf-8") as f:
74
  json.dump(results_data, f, ensure_ascii=False, indent=2)
75
- print(f"\n测试结果已保存到: {path}")
76
 
77
  def print_summary(self):
78
  total = len(self.results)
@@ -80,28 +79,27 @@ class TestStats:
80
  failed = total - passed
81
  duration = sum(r.duration for r in self.results)
82
 
83
- print("\n=== 测试结果摘要 ===")
84
- print(f"开始时间: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
85
- print(f"总用时: {duration:.2f}")
86
- print(f"总计: {total} 个测试")
87
- print(f"通过: {passed}")
88
- print(f"失败: {failed}")
89
 
90
  if failed > 0:
91
- print("\n失败的测试:")
92
  for result in self.results:
93
  if not result.success:
94
  print(f"- {result.name}: {result.error}")
95
 
96
- # 默认配置
97
  DEFAULT_CONFIG = {
98
  "server": {
99
  "host": "localhost",
100
  "port": 9621,
101
  "model": "lightrag:latest",
102
- "timeout": 30, # 请求超时时间(秒)
103
- "max_retries": 3, # 最大重试次数
104
- "retry_delay": 1 # 重试间隔(秒)
105
  },
106
  "test_cases": {
107
  "basic": {
@@ -111,18 +109,16 @@ DEFAULT_CONFIG = {
111
  }
112
 
113
  def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response:
114
- """发送 HTTP 请求,支持重试机制
115
-
116
  Args:
117
- url: 请求 URL
118
- data: 请求数据
119
- stream: 是否使用流式响应
120
-
121
  Returns:
122
- requests.Response: 对象
123
 
124
  Raises:
125
- requests.exceptions.RequestException: 请求失败且重试次数用完
126
  """
127
  server_config = CONFIG["server"]
128
  max_retries = server_config["max_retries"]
@@ -139,19 +135,18 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> reques
139
  )
140
  return response
141
  except requests.exceptions.RequestException as e:
142
- if attempt == max_retries - 1: # 最后一次重试
143
  raise
144
- print(f"\n请求失败,{retry_delay}秒后重试: {str(e)}")
145
  time.sleep(retry_delay)
146
 
147
  def load_config() -> Dict[str, Any]:
148
- """加载配置文件
149
-
150
- 首先尝试从当前目录的 config.json 加载,
151
- 如果不存在则使用默认配置
152
 
 
 
153
  Returns:
154
- 配置字典
155
  """
156
  config_path = Path("config.json")
157
  if config_path.exists():
@@ -160,23 +155,22 @@ def load_config() -> Dict[str, Any]:
160
  return DEFAULT_CONFIG
161
 
162
  def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
163
- """格式化打印 JSON 响应数据
164
-
165
  Args:
166
- data: 要打印的数据字典
167
- title: 打印的标题
168
- indent: JSON 缩进空格数
169
  """
170
  if OutputControl.is_verbose():
171
  if title:
172
  print(f"\n=== {title} ===")
173
  print(json.dumps(data, ensure_ascii=False, indent=indent))
174
 
175
- # 全局配置
176
  CONFIG = load_config()
177
 
178
  def get_base_url() -> str:
179
- """返回基础 URL"""
180
  server = CONFIG["server"]
181
  return f"http://{server['host']}:{server['port']}/api/chat"
182
 
@@ -185,15 +179,13 @@ def create_request_data(
185
  stream: bool = False,
186
  model: str = None
187
  ) -> Dict[str, Any]:
188
- """创建基本的请求数据
189
-
190
  Args:
191
- content: 用户消息内容
192
- stream: 是否使用流式响应
193
- model: 模型名称
194
-
195
  Returns:
196
- 包含完整请求数据的字典
197
  """
198
  return {
199
  "model": model or CONFIG["server"]["model"],
@@ -206,15 +198,14 @@ def create_request_data(
206
  "stream": stream
207
  }
208
 
209
- # 全局测试统计
210
  STATS = TestStats()
211
 
212
  def run_test(func: Callable, name: str) -> None:
213
- """运行测试并记录结果
214
-
215
  Args:
216
- func: 测试函数
217
- name: 测试名称
218
  """
219
  start_time = time.time()
220
  try:
@@ -227,54 +218,43 @@ def run_test(func: Callable, name: str) -> None:
227
  raise
228
 
229
  def test_non_stream_chat():
230
- """测试非流式调用 /api/chat 接口"""
231
  url = get_base_url()
232
  data = create_request_data(
233
  CONFIG["test_cases"]["basic"]["query"],
234
  stream=False
235
  )
236
 
237
- # 发送请求
238
  response = make_request(url, data)
239
 
240
- # 打印响应
241
  if OutputControl.is_verbose():
242
- print("\n=== 非流式调用响应 ===")
243
  response_json = response.json()
244
 
245
- # 打印响应内容
246
  print_json_response({
247
  "model": response_json["model"],
248
  "message": response_json["message"]
249
- }, "响应内容")
250
-
251
- # # 打印性能统计
252
- # print_json_response({
253
- # "total_duration": response_json["total_duration"],
254
- # "load_duration": response_json["load_duration"],
255
- # "prompt_eval_count": response_json["prompt_eval_count"],
256
- # "prompt_eval_duration": response_json["prompt_eval_duration"],
257
- # "eval_count": response_json["eval_count"],
258
- # "eval_duration": response_json["eval_duration"]
259
- # }, "性能统计")
260
-
261
  def test_stream_chat():
262
- """测试流式调用 /api/chat 接口
263
 
264
- 使用 JSON Lines 格式处理流式响应,每行是一个完整的 JSON 对象。
265
- 响应格式:
266
  {
267
  "model": "lightrag:latest",
268
  "created_at": "2024-01-15T00:00:00Z",
269
  "message": {
270
  "role": "assistant",
271
- "content": "部分响应内容",
272
  "images": null
273
  },
274
  "done": false
275
  }
276
 
277
- 最后一条消息会包含性能统计信息,done true
278
  """
279
  url = get_base_url()
280
  data = create_request_data(
@@ -282,79 +262,79 @@ def test_stream_chat():
282
  stream=True
283
  )
284
 
285
- # 发送请求并获取流式响应
286
  response = make_request(url, data, stream=True)
287
 
288
  if OutputControl.is_verbose():
289
- print("\n=== 流式调用响应 ===")
290
  output_buffer = []
291
  try:
292
  for line in response.iter_lines():
293
- if line: # 跳过空行
294
  try:
295
- # 解码并解析 JSON
296
  data = json.loads(line.decode('utf-8'))
297
- if data.get("done", True): # 如果是完成标记
298
- if "total_duration" in data: # 最终的性能统计消息
299
- # print_json_response(data, "性能统计")
300
  break
301
- else: # 正常的内容消息
302
  message = data.get("message", {})
303
  content = message.get("content", "")
304
- if content: # 只收集非空内容
305
  output_buffer.append(content)
306
- print(content, end="", flush=True) # 实时打印内容
307
  except json.JSONDecodeError:
308
  print("Error decoding JSON from response line")
309
  finally:
310
- response.close() # 确保关闭响应连接
311
 
312
- # 打印一个换行
313
  print()
314
 
315
  def test_query_modes():
316
- """测试不同的查询模式前缀
317
 
318
- 支持的查询模式:
319
- - /local: 本地检索模式,只在相关度高的文档中搜索
320
- - /global: 全局检索模式,在所有文档中搜索
321
- - /naive: 朴素模式,不使用任何优化策略
322
- - /hybrid: 混合模式(默认),结合多种策略
 
323
 
324
- 每个模式都会返回相同格式的响应,但检索策略不同。
325
  """
326
  url = get_base_url()
327
- modes = ["local", "global", "naive", "hybrid", "mix"] # 支持的查询模式
328
 
329
  for mode in modes:
330
  if OutputControl.is_verbose():
331
- print(f"\n=== 测试 /{mode} 模式 ===")
332
  data = create_request_data(
333
  f"/{mode} {CONFIG['test_cases']['basic']['query']}",
334
  stream=False
335
  )
336
 
337
- # 发送请求
338
  response = make_request(url, data)
339
  response_json = response.json()
340
 
341
- # 打印响应内容
342
  print_json_response({
343
  "model": response_json["model"],
344
  "message": response_json["message"]
345
  })
346
 
347
  def create_error_test_data(error_type: str) -> Dict[str, Any]:
348
- """创建用于错误测试的请求数据
349
-
350
  Args:
351
- error_type: 错误类型,支持:
352
- - empty_messages: 空消息列表
353
- - invalid_role: 无效的角色字段
354
- - missing_content: 缺少内容字段
355
 
356
  Returns:
357
- 包含错误数据的请求字典
358
  """
359
  error_data = {
360
  "empty_messages": {
@@ -367,7 +347,7 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]:
367
  "messages": [
368
  {
369
  "invalid_role": "user",
370
- "content": "测试消息"
371
  }
372
  ],
373
  "stream": True
@@ -385,101 +365,100 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]:
385
  return error_data.get(error_type, error_data["empty_messages"])
386
 
387
  def test_stream_error_handling():
388
- """测试流式响应的错误处理
389
 
390
- 测试场景:
391
- 1. 空消息列表
392
- 2. 消息格式错误(缺少必需字段)
393
 
394
- 错误响应会立即返回,不会建立流式连接。
395
- 状态码应该是 4xx,并返回详细的错误信息。
396
  """
397
  url = get_base_url()
398
 
399
  if OutputControl.is_verbose():
400
- print("\n=== 测试流式响应错误处理 ===")
401
 
402
- # 测试空消息列表
403
  if OutputControl.is_verbose():
404
- print("\n--- 测试空消息列表(流式)---")
405
  data = create_error_test_data("empty_messages")
406
  response = make_request(url, data, stream=True)
407
- print(f"状态码: {response.status_code}")
408
  if response.status_code != 200:
409
- print_json_response(response.json(), "错误信息")
410
  response.close()
411
 
412
- # 测试无效角色字段
413
  if OutputControl.is_verbose():
414
- print("\n--- 测试无效角色字段(流式)---")
415
  data = create_error_test_data("invalid_role")
416
  response = make_request(url, data, stream=True)
417
- print(f"状态码: {response.status_code}")
418
  if response.status_code != 200:
419
- print_json_response(response.json(), "错误信息")
420
  response.close()
421
 
422
- # 测试缺少内容字段
423
  if OutputControl.is_verbose():
424
- print("\n--- 测试缺少内容字段(流式)---")
425
  data = create_error_test_data("missing_content")
426
  response = make_request(url, data, stream=True)
427
- print(f"状态码: {response.status_code}")
428
  if response.status_code != 200:
429
- print_json_response(response.json(), "错误信息")
430
  response.close()
431
 
432
  def test_error_handling():
433
- """测试非流式响应的错误处理
434
 
435
- 测试场景:
436
- 1. 空消息列表
437
- 2. 消息格式错误(缺少必需字段)
438
 
439
- 错误响应格式:
440
  {
441
- "detail": "错误描述"
442
  }
443
 
444
- 所有错误都应该返回合适的 HTTP 状态码和清晰的错误信息。
445
  """
446
  url = get_base_url()
447
 
448
  if OutputControl.is_verbose():
449
- print("\n=== 测试错误处理 ===")
450
 
451
- # 测试空消息列表
452
  if OutputControl.is_verbose():
453
- print("\n--- 测试空消息列表 ---")
454
  data = create_error_test_data("empty_messages")
455
- data["stream"] = False # 修改为非流式模式
456
  response = make_request(url, data)
457
- print(f"状态码: {response.status_code}")
458
- print_json_response(response.json(), "错误信息")
459
 
460
- # 测试无效角色字段
461
  if OutputControl.is_verbose():
462
- print("\n--- 测试无效角色字段 ---")
463
  data = create_error_test_data("invalid_role")
464
- data["stream"] = False # 修改为非流式模式
465
  response = make_request(url, data)
466
- print(f"状态码: {response.status_code}")
467
- print_json_response(response.json(), "错误信息")
468
 
469
- # 测试缺少内容字段
470
  if OutputControl.is_verbose():
471
- print("\n--- 测试缺少内容字段 ---")
472
  data = create_error_test_data("missing_content")
473
- data["stream"] = False # 修改为非流式模式
474
  response = make_request(url, data)
475
- print(f"状态码: {response.status_code}")
476
- print_json_response(response.json(), "错误信息")
477
 
478
  def get_test_cases() -> Dict[str, Callable]:
479
- """获取所有可用的测试用例
480
-
481
  Returns:
482
- 测试名称到测试函数的映射字典
483
  """
484
  return {
485
  "non_stream": test_non_stream_chat,
@@ -490,30 +469,30 @@ def get_test_cases() -> Dict[str, Callable]:
490
  }
491
 
492
  def create_default_config():
493
- """创建默认配置文件"""
494
  config_path = Path("config.json")
495
  if not config_path.exists():
496
  with open(config_path, "w", encoding="utf-8") as f:
497
  json.dump(DEFAULT_CONFIG, f, ensure_ascii=False, indent=2)
498
- print(f"已创建默认配置文件: {config_path}")
499
 
500
  def parse_args() -> argparse.Namespace:
501
- """解析命令行参数"""
502
  parser = argparse.ArgumentParser(
503
- description="LightRAG Ollama 兼容接口测试",
504
  formatter_class=argparse.RawDescriptionHelpFormatter,
505
  epilog="""
506
- 配置文件 (config.json):
507
  {
508
  "server": {
509
- "host": "localhost", # 服务器地址
510
- "port": 9621, # 服务器端口
511
- "model": "lightrag:latest" # 默认模型名称
512
  },
513
  "test_cases": {
514
  "basic": {
515
- "query": "测试查询", # 基本查询文本
516
- "stream_query": "流式查询" # 流式查询文本
517
  }
518
  }
519
  }
@@ -522,44 +501,44 @@ def parse_args() -> argparse.Namespace:
522
  parser.add_argument(
523
  "-q", "--quiet",
524
  action="store_true",
525
- help="静默模式,只显示测试结果摘要"
526
  )
527
  parser.add_argument(
528
  "-a", "--ask",
529
  type=str,
530
- help="指定查询内容,会覆盖配置文件中的查询设置"
531
  )
532
  parser.add_argument(
533
  "--init-config",
534
  action="store_true",
535
- help="创建默认配置文件"
536
  )
537
  parser.add_argument(
538
  "--output",
539
  type=str,
540
  default="",
541
- help="测试结果输出文件路径,默认不输出到文件"
542
  )
543
  parser.add_argument(
544
  "--tests",
545
  nargs="+",
546
  choices=list(get_test_cases().keys()) + ["all"],
547
  default=["all"],
548
- help="要运行的测试用例,可选: %(choices)s。使用 all 运行所有测试"
549
  )
550
  return parser.parse_args()
551
 
552
  if __name__ == "__main__":
553
  args = parse_args()
554
 
555
- # 设置输出模式
556
  OutputControl.set_verbose(not args.quiet)
557
 
558
- # 如果指定了查询内容,更新配置
559
  if args.ask:
560
  CONFIG["test_cases"]["basic"]["query"] = args.ask
561
 
562
- # 如果指定了创建配置文件
563
  if args.init_config:
564
  create_default_config()
565
  exit(0)
@@ -568,31 +547,31 @@ if __name__ == "__main__":
568
 
569
  try:
570
  if "all" in args.tests:
571
- # 运行所有测试
572
  if OutputControl.is_verbose():
573
- print("\n【基本功能测试】")
574
- run_test(test_non_stream_chat, "非流式调用测试")
575
- run_test(test_stream_chat, "流式调用测试")
576
 
577
  if OutputControl.is_verbose():
578
- print("\n【查询模式测试】")
579
- run_test(test_query_modes, "查询模式测试")
580
 
581
  if OutputControl.is_verbose():
582
- print("\n【错误处理测试】")
583
- run_test(test_error_handling, "错误处理测试")
584
- run_test(test_stream_error_handling, "流式错误处理测试")
585
  else:
586
- # 运行指定的测试
587
  for test_name in args.tests:
588
  if OutputControl.is_verbose():
589
- print(f"\n【运行测试: {test_name}】")
590
  run_test(test_cases[test_name], test_name)
591
  except Exception as e:
592
- print(f"\n发生错误: {str(e)}")
593
  finally:
594
- # 打印测试统计
595
  STATS.print_summary()
596
- # 如果指定了输出文件路径,则导出结果
597
  if args.output:
598
  STATS.export_results(args.output)
 
19
  from pathlib import Path
20
 
21
  class OutputControl:
22
+ """Output control class, manages the verbosity of test output"""
23
  _verbose: bool = False
24
 
25
  @classmethod
 
32
 
33
  @dataclass
34
  class TestResult:
35
+ """Test result data class"""
36
  name: str
37
  success: bool
38
  duration: float
 
44
  self.timestamp = datetime.now().isoformat()
45
 
46
  class TestStats:
47
+ """Test statistics"""
48
  def __init__(self):
49
  self.results: List[TestResult] = []
50
  self.start_time = datetime.now()
 
53
  self.results.append(result)
54
 
55
  def export_results(self, path: str = "test_results.json"):
56
+ """Export test results to a JSON file
 
57
  Args:
58
+ path: Output file path
59
  """
60
  results_data = {
61
  "start_time": self.start_time.isoformat(),
 
71
 
72
  with open(path, "w", encoding="utf-8") as f:
73
  json.dump(results_data, f, ensure_ascii=False, indent=2)
74
+ print(f"\nTest results saved to: {path}")
75
 
76
  def print_summary(self):
77
  total = len(self.results)
 
79
  failed = total - passed
80
  duration = sum(r.duration for r in self.results)
81
 
82
+ print("\n=== Test Summary ===")
83
+ print(f"Start time: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
84
+ print(f"Total duration: {duration:.2f} seconds")
85
+ print(f"Total tests: {total}")
86
+ print(f"Passed: {passed}")
87
+ print(f"Failed: {failed}")
88
 
89
  if failed > 0:
90
+ print("\nFailed tests:")
91
  for result in self.results:
92
  if not result.success:
93
  print(f"- {result.name}: {result.error}")
94
 
 
95
  DEFAULT_CONFIG = {
96
  "server": {
97
  "host": "localhost",
98
  "port": 9621,
99
  "model": "lightrag:latest",
100
+ "timeout": 30,
101
+ "max_retries": 3,
102
+ "retry_delay": 1
103
  },
104
  "test_cases": {
105
  "basic": {
 
109
  }
110
 
111
  def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response:
112
+ """Send an HTTP request with retry mechanism
 
113
  Args:
114
+ url: Request URL
115
+ data: Request data
116
+ stream: Whether to use streaming response
 
117
  Returns:
118
+ requests.Response: Response object
119
 
120
  Raises:
121
+ requests.exceptions.RequestException: Request failed after all retries
122
  """
123
  server_config = CONFIG["server"]
124
  max_retries = server_config["max_retries"]
 
135
  )
136
  return response
137
  except requests.exceptions.RequestException as e:
138
+ if attempt == max_retries - 1: # Last retry
139
  raise
140
+ print(f"\nRequest failed, retrying in {retry_delay} seconds: {str(e)}")
141
  time.sleep(retry_delay)
142
 
143
  def load_config() -> Dict[str, Any]:
144
+ """Load configuration file
 
 
 
145
 
146
+ First try to load from config.json in the current directory,
147
+ if it doesn't exist, use the default configuration
148
  Returns:
149
+ Configuration dictionary
150
  """
151
  config_path = Path("config.json")
152
  if config_path.exists():
 
155
  return DEFAULT_CONFIG
156
 
157
  def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
158
+ """Format and print JSON response data
 
159
  Args:
160
+ data: Data dictionary to print
161
+ title: Title to print
162
+ indent: Number of spaces for JSON indentation
163
  """
164
  if OutputControl.is_verbose():
165
  if title:
166
  print(f"\n=== {title} ===")
167
  print(json.dumps(data, ensure_ascii=False, indent=indent))
168
 
169
+ # Global configuration
170
  CONFIG = load_config()
171
 
172
  def get_base_url() -> str:
173
+ """Return the base URL"""
174
  server = CONFIG["server"]
175
  return f"http://{server['host']}:{server['port']}/api/chat"
176
 
 
179
  stream: bool = False,
180
  model: str = None
181
  ) -> Dict[str, Any]:
182
+ """Create basic request data
 
183
  Args:
184
+ content: User message content
185
+ stream: Whether to use streaming response
186
+ model: Model name
 
187
  Returns:
188
+ Dictionary containing complete request data
189
  """
190
  return {
191
  "model": model or CONFIG["server"]["model"],
 
198
  "stream": stream
199
  }
200
 
201
+ # Global test statistics
202
  STATS = TestStats()
203
 
204
  def run_test(func: Callable, name: str) -> None:
205
+ """Run a test and record the results
 
206
  Args:
207
+ func: Test function
208
+ name: Test name
209
  """
210
  start_time = time.time()
211
  try:
 
218
  raise
219
 
220
  def test_non_stream_chat():
221
+ """Test non-streaming call to /api/chat endpoint"""
222
  url = get_base_url()
223
  data = create_request_data(
224
  CONFIG["test_cases"]["basic"]["query"],
225
  stream=False
226
  )
227
 
228
+ # Send request
229
  response = make_request(url, data)
230
 
231
+ # Print response
232
  if OutputControl.is_verbose():
233
+ print("\n=== Non-streaming call response ===")
234
  response_json = response.json()
235
 
236
+ # Print response content
237
  print_json_response({
238
  "model": response_json["model"],
239
  "message": response_json["message"]
240
+ }, "Response content")
 
 
 
 
 
 
 
 
 
 
 
241
  def test_stream_chat():
242
+ """Test streaming call to /api/chat endpoint
243
 
244
+ Use JSON Lines format to process streaming responses, each line is a complete JSON object.
245
+ Response format:
246
  {
247
  "model": "lightrag:latest",
248
  "created_at": "2024-01-15T00:00:00Z",
249
  "message": {
250
  "role": "assistant",
251
+ "content": "Partial response content",
252
  "images": null
253
  },
254
  "done": false
255
  }
256
 
257
+ The last message will contain performance statistics, with done set to true.
258
  """
259
  url = get_base_url()
260
  data = create_request_data(
 
262
  stream=True
263
  )
264
 
265
+ # Send request and get streaming response
266
  response = make_request(url, data, stream=True)
267
 
268
  if OutputControl.is_verbose():
269
+ print("\n=== Streaming call response ===")
270
  output_buffer = []
271
  try:
272
  for line in response.iter_lines():
273
+ if line: # Skip empty lines
274
  try:
275
+ # Decode and parse JSON
276
  data = json.loads(line.decode('utf-8'))
277
+ if data.get("done", True): # If it's the completion marker
278
+ if "total_duration" in data: # Final performance statistics message
279
+ # print_json_response(data, "Performance statistics")
280
  break
281
+ else: # Normal content message
282
  message = data.get("message", {})
283
  content = message.get("content", "")
284
+ if content: # Only collect non-empty content
285
  output_buffer.append(content)
286
+ print(content, end="", flush=True) # Print content in real-time
287
  except json.JSONDecodeError:
288
  print("Error decoding JSON from response line")
289
  finally:
290
+ response.close() # Ensure the response connection is closed
291
 
292
+ # Print a newline
293
  print()
294
 
295
  def test_query_modes():
296
+ """Test different query mode prefixes
297
 
298
+ Supported query modes:
299
+ - /local: Local retrieval mode, searches only in highly relevant documents
300
+ - /global: Global retrieval mode, searches across all documents
301
+ - /naive: Naive mode, does not use any optimization strategies
302
+ - /hybrid: Hybrid mode (default), combines multiple strategies
303
+ - /mix: Mix mode
304
 
305
+ Each mode will return responses in the same format, but with different retrieval strategies.
306
  """
307
  url = get_base_url()
308
+ modes = ["local", "global", "naive", "hybrid", "mix"]
309
 
310
  for mode in modes:
311
  if OutputControl.is_verbose():
312
+ print(f"\n=== Testing /{mode} mode ===")
313
  data = create_request_data(
314
  f"/{mode} {CONFIG['test_cases']['basic']['query']}",
315
  stream=False
316
  )
317
 
318
+ # Send request
319
  response = make_request(url, data)
320
  response_json = response.json()
321
 
322
+ # Print response content
323
  print_json_response({
324
  "model": response_json["model"],
325
  "message": response_json["message"]
326
  })
327
 
328
  def create_error_test_data(error_type: str) -> Dict[str, Any]:
329
+ """Create request data for error testing
 
330
  Args:
331
+ error_type: Error type, supported:
332
+ - empty_messages: Empty message list
333
+ - invalid_role: Invalid role field
334
+ - missing_content: Missing content field
335
 
336
  Returns:
337
+ Request dictionary containing error data
338
  """
339
  error_data = {
340
  "empty_messages": {
 
347
  "messages": [
348
  {
349
  "invalid_role": "user",
350
+ "content": "Test message"
351
  }
352
  ],
353
  "stream": True
 
365
  return error_data.get(error_type, error_data["empty_messages"])
366
 
367
  def test_stream_error_handling():
368
+ """Test error handling for streaming responses
369
 
370
+ Test scenarios:
371
+ 1. Empty message list
372
+ 2. Message format error (missing required fields)
373
 
374
+ Error responses should be returned immediately without establishing a streaming connection.
375
+ The status code should be 4xx, and detailed error information should be returned.
376
  """
377
  url = get_base_url()
378
 
379
  if OutputControl.is_verbose():
380
+ print("\n=== Testing streaming response error handling ===")
381
 
382
+ # Test empty message list
383
  if OutputControl.is_verbose():
384
+ print("\n--- Testing empty message list (streaming) ---")
385
  data = create_error_test_data("empty_messages")
386
  response = make_request(url, data, stream=True)
387
+ print(f"Status code: {response.status_code}")
388
  if response.status_code != 200:
389
+ print_json_response(response.json(), "Error message")
390
  response.close()
391
 
392
+ # Test invalid role field
393
  if OutputControl.is_verbose():
394
+ print("\n--- Testing invalid role field (streaming) ---")
395
  data = create_error_test_data("invalid_role")
396
  response = make_request(url, data, stream=True)
397
+ print(f"Status code: {response.status_code}")
398
  if response.status_code != 200:
399
+ print_json_response(response.json(), "Error message")
400
  response.close()
401
 
402
+ # Test missing content field
403
  if OutputControl.is_verbose():
404
+ print("\n--- Testing missing content field (streaming) ---")
405
  data = create_error_test_data("missing_content")
406
  response = make_request(url, data, stream=True)
407
+ print(f"Status code: {response.status_code}")
408
  if response.status_code != 200:
409
+ print_json_response(response.json(), "Error message")
410
  response.close()
411
 
412
  def test_error_handling():
413
+ """Test error handling for non-streaming responses
414
 
415
+ Test scenarios:
416
+ 1. Empty message list
417
+ 2. Message format error (missing required fields)
418
 
419
+ Error response format:
420
  {
421
+ "detail": "Error description"
422
  }
423
 
424
+ All errors should return appropriate HTTP status codes and clear error messages.
425
  """
426
  url = get_base_url()
427
 
428
  if OutputControl.is_verbose():
429
+ print("\n=== Testing error handling ===")
430
 
431
+ # Test empty message list
432
  if OutputControl.is_verbose():
433
+ print("\n--- Testing empty message list ---")
434
  data = create_error_test_data("empty_messages")
435
+ data["stream"] = False # Change to non-streaming mode
436
  response = make_request(url, data)
437
+ print(f"Status code: {response.status_code}")
438
+ print_json_response(response.json(), "Error message")
439
 
440
+ # Test invalid role field
441
  if OutputControl.is_verbose():
442
+ print("\n--- Testing invalid role field ---")
443
  data = create_error_test_data("invalid_role")
444
+ data["stream"] = False # Change to non-streaming mode
445
  response = make_request(url, data)
446
+ print(f"Status code: {response.status_code}")
447
+ print_json_response(response.json(), "Error message")
448
 
449
+ # Test missing content field
450
  if OutputControl.is_verbose():
451
+ print("\n--- Testing missing content field ---")
452
  data = create_error_test_data("missing_content")
453
+ data["stream"] = False # Change to non-streaming mode
454
  response = make_request(url, data)
455
+ print(f"Status code: {response.status_code}")
456
+ print_json_response(response.json(), "Error message")
457
 
458
  def get_test_cases() -> Dict[str, Callable]:
459
+ """Get all available test cases
 
460
  Returns:
461
+ A dictionary mapping test names to test functions
462
  """
463
  return {
464
  "non_stream": test_non_stream_chat,
 
469
  }
470
 
471
  def create_default_config():
472
+ """Create a default configuration file"""
473
  config_path = Path("config.json")
474
  if not config_path.exists():
475
  with open(config_path, "w", encoding="utf-8") as f:
476
  json.dump(DEFAULT_CONFIG, f, ensure_ascii=False, indent=2)
477
+ print(f"Default configuration file created: {config_path}")
478
 
479
  def parse_args() -> argparse.Namespace:
480
+ """Parse command line arguments"""
481
  parser = argparse.ArgumentParser(
482
+ description="LightRAG Ollama Compatibility Interface Testing",
483
  formatter_class=argparse.RawDescriptionHelpFormatter,
484
  epilog="""
485
+ Configuration file (config.json):
486
  {
487
  "server": {
488
+ "host": "localhost", # Server address
489
+ "port": 9621, # Server port
490
+ "model": "lightrag:latest" # Default model name
491
  },
492
  "test_cases": {
493
  "basic": {
494
+ "query": "Test query", # Basic query text
495
+ "stream_query": "Stream query" # Stream query text
496
  }
497
  }
498
  }
 
501
  parser.add_argument(
502
  "-q", "--quiet",
503
  action="store_true",
504
+ help="Silent mode, only display test result summary"
505
  )
506
  parser.add_argument(
507
  "-a", "--ask",
508
  type=str,
509
+ help="Specify query content, which will override the query settings in the configuration file"
510
  )
511
  parser.add_argument(
512
  "--init-config",
513
  action="store_true",
514
+ help="Create default configuration file"
515
  )
516
  parser.add_argument(
517
  "--output",
518
  type=str,
519
  default="",
520
+ help="Test result output file path, default is not to output to a file"
521
  )
522
  parser.add_argument(
523
  "--tests",
524
  nargs="+",
525
  choices=list(get_test_cases().keys()) + ["all"],
526
  default=["all"],
527
+ help="Test cases to run, options: %(choices)s. Use 'all' to run all tests"
528
  )
529
  return parser.parse_args()
530
 
531
  if __name__ == "__main__":
532
  args = parse_args()
533
 
534
+ # Set output mode
535
  OutputControl.set_verbose(not args.quiet)
536
 
537
+ # If query content is specified, update the configuration
538
  if args.ask:
539
  CONFIG["test_cases"]["basic"]["query"] = args.ask
540
 
541
+ # If specified to create a configuration file
542
  if args.init_config:
543
  create_default_config()
544
  exit(0)
 
547
 
548
  try:
549
  if "all" in args.tests:
550
+ # Run all tests
551
  if OutputControl.is_verbose():
552
+ print("\n【Basic Functionality Tests】")
553
+ run_test(test_non_stream_chat, "Non-streaming Call Test")
554
+ run_test(test_stream_chat, "Streaming Call Test")
555
 
556
  if OutputControl.is_verbose():
557
+ print("\n【Query Mode Tests】")
558
+ run_test(test_query_modes, "Query Mode Test")
559
 
560
  if OutputControl.is_verbose():
561
+ print("\n【Error Handling Tests】")
562
+ run_test(test_error_handling, "Error Handling Test")
563
+ run_test(test_stream_error_handling, "Streaming Error Handling Test")
564
  else:
565
+ # Run specified tests
566
  for test_name in args.tests:
567
  if OutputControl.is_verbose():
568
+ print(f"\n【Running Test: {test_name}】")
569
  run_test(test_cases[test_name], test_name)
570
  except Exception as e:
571
+ print(f"\nAn error occurred: {str(e)}")
572
  finally:
573
+ # Print test statistics
574
  STATS.print_summary()
575
+ # If an output file path is specified, export the results
576
  if args.output:
577
  STATS.export_results(args.output)