"""
LightRAG Ollama Compatibility Interface Test Script

This script tests LightRAG's Ollama compatibility interface, including:
1. Basic functionality tests (streaming and non-streaming responses)
2. Query mode tests (local, global, naive, hybrid, mix)
3. Error handling tests (including streaming and non-streaming scenarios)

All responses use the JSON Lines format, complying with the Ollama API specification.
"""
					
						
						|  |  | 
					
						
						|  | import requests | 
					
						
						|  | import json | 
					
						
						|  | import argparse | 
					
						
						|  | import time | 
					
						
						|  | from typing import Dict, Any, Optional, List, Callable | 
					
						
						|  | from dataclasses import dataclass, asdict | 
					
						
						|  | from datetime import datetime | 
					
						
						|  | from pathlib import Path | 
					
						
						|  | from enum import Enum, auto | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class ErrorCode(Enum): | 
					
						
						|  | """Error codes for MCP errors""" | 
					
						
						|  |  | 
					
						
						|  | InvalidRequest = auto() | 
					
						
						|  | InternalError = auto() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class McpError(Exception): | 
					
						
						|  | """Base exception class for MCP errors""" | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, code: ErrorCode, message: str): | 
					
						
						|  | self.code = code | 
					
						
						|  | self.message = message | 
					
						
						|  | super().__init__(message) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | DEFAULT_CONFIG = { | 
					
						
						|  | "server": { | 
					
						
						|  | "host": "localhost", | 
					
						
						|  | "port": 9621, | 
					
						
						|  | "model": "lightrag:latest", | 
					
						
						|  | "timeout": 300, | 
					
						
						|  | "max_retries": 1, | 
					
						
						|  | "retry_delay": 1, | 
					
						
						|  | }, | 
					
						
						|  | "test_cases": { | 
					
						
						|  | "basic": {"query": "唐僧有几个徒弟"}, | 
					
						
						|  | "generate": {"query": "电视剧西游记导演是谁"}, | 
					
						
						|  | }, | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | EXAMPLE_CONVERSATION = [ | 
					
						
						|  | {"role": "user", "content": "你好"}, | 
					
						
						|  | {"role": "assistant", "content": "你好!我是一个AI助手,很高兴为你服务。"}, | 
					
						
						|  | {"role": "user", "content": "Who are you?"}, | 
					
						
						|  | {"role": "assistant", "content": "I'm a Knowledge base query assistant."}, | 
					
						
						|  | ] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class OutputControl: | 
					
						
						|  | """Output control class, manages the verbosity of test output""" | 
					
						
						|  |  | 
					
						
						|  | _verbose: bool = False | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def set_verbose(cls, verbose: bool) -> None: | 
					
						
						|  | cls._verbose = verbose | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def is_verbose(cls) -> bool: | 
					
						
						|  | return cls._verbose | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @dataclass | 
					
						
						|  | class TestResult: | 
					
						
						|  | """Test result data class""" | 
					
						
						|  |  | 
					
						
						|  | name: str | 
					
						
						|  | success: bool | 
					
						
						|  | duration: float | 
					
						
						|  | error: Optional[str] = None | 
					
						
						|  | timestamp: str = "" | 
					
						
						|  |  | 
					
						
						|  | def __post_init__(self): | 
					
						
						|  | if not self.timestamp: | 
					
						
						|  | self.timestamp = datetime.now().isoformat() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class TestStats: | 
					
						
						|  | """Test statistics""" | 
					
						
						|  |  | 
					
						
						|  | def __init__(self): | 
					
						
						|  | self.results: List[TestResult] = [] | 
					
						
						|  | self.start_time = datetime.now() | 
					
						
						|  |  | 
					
						
						|  | def add_result(self, result: TestResult): | 
					
						
						|  | self.results.append(result) | 
					
						
						|  |  | 
					
						
						|  | def export_results(self, path: str = "test_results.json"): | 
					
						
						|  | """Export test results to a JSON file | 
					
						
						|  | Args: | 
					
						
						|  | path: Output file path | 
					
						
						|  | """ | 
					
						
						|  | results_data = { | 
					
						
						|  | "start_time": self.start_time.isoformat(), | 
					
						
						|  | "end_time": datetime.now().isoformat(), | 
					
						
						|  | "results": [asdict(r) for r in self.results], | 
					
						
						|  | "summary": { | 
					
						
						|  | "total": len(self.results), | 
					
						
						|  | "passed": sum(1 for r in self.results if r.success), | 
					
						
						|  | "failed": sum(1 for r in self.results if not r.success), | 
					
						
						|  | "total_duration": sum(r.duration for r in self.results), | 
					
						
						|  | }, | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | with open(path, "w", encoding="utf-8") as f: | 
					
						
						|  | json.dump(results_data, f, ensure_ascii=False, indent=2) | 
					
						
						|  | print(f"\nTest results saved to: {path}") | 
					
						
						|  |  | 
					
						
						|  | def print_summary(self): | 
					
						
						|  | total = len(self.results) | 
					
						
						|  | passed = sum(1 for r in self.results if r.success) | 
					
						
						|  | failed = total - passed | 
					
						
						|  | duration = sum(r.duration for r in self.results) | 
					
						
						|  |  | 
					
						
						|  | print("\n=== Test Summary ===") | 
					
						
						|  | print(f"Start time: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}") | 
					
						
						|  | print(f"Total duration: {duration:.2f} seconds") | 
					
						
						|  | print(f"Total tests: {total}") | 
					
						
						|  | print(f"Passed: {passed}") | 
					
						
						|  | print(f"Failed: {failed}") | 
					
						
						|  |  | 
					
						
						|  | if failed > 0: | 
					
						
						|  | print("\nFailed tests:") | 
					
						
						|  | for result in self.results: | 
					
						
						|  | if not result.success: | 
					
						
						|  | print(f"- {result.name}: {result.error}") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def make_request( | 
					
						
						|  | url: str, data: Dict[str, Any], stream: bool = False, check_status: bool = True | 
					
						
						|  | ) -> requests.Response: | 
					
						
						|  | """Send an HTTP request with retry mechanism | 
					
						
						|  | Args: | 
					
						
						|  | url: Request URL | 
					
						
						|  | data: Request data | 
					
						
						|  | stream: Whether to use streaming response | 
					
						
						|  | check_status: Whether to check HTTP status code (default: True) | 
					
						
						|  | Returns: | 
					
						
						|  | requests.Response: Response object | 
					
						
						|  |  | 
					
						
						|  | Raises: | 
					
						
						|  | requests.exceptions.RequestException: Request failed after all retries | 
					
						
						|  | requests.exceptions.HTTPError: HTTP status code is not 200 (when check_status is True) | 
					
						
						|  | """ | 
					
						
						|  | server_config = CONFIG["server"] | 
					
						
						|  | max_retries = server_config["max_retries"] | 
					
						
						|  | retry_delay = server_config["retry_delay"] | 
					
						
						|  | timeout = server_config["timeout"] | 
					
						
						|  |  | 
					
						
						|  | for attempt in range(max_retries): | 
					
						
						|  | try: | 
					
						
						|  | response = requests.post(url, json=data, stream=stream, timeout=timeout) | 
					
						
						|  | if check_status and response.status_code != 200: | 
					
						
						|  | response.raise_for_status() | 
					
						
						|  | return response | 
					
						
						|  | except requests.exceptions.RequestException as e: | 
					
						
						|  | if attempt == max_retries - 1: | 
					
						
						|  | raise | 
					
						
						|  | print(f"\nRequest failed, retrying in {retry_delay} seconds: {str(e)}") | 
					
						
						|  | time.sleep(retry_delay) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def load_config() -> Dict[str, Any]: | 
					
						
						|  | """Load configuration file | 
					
						
						|  |  | 
					
						
						|  | First try to load from config.json in the current directory, | 
					
						
						|  | if it doesn't exist, use the default configuration | 
					
						
						|  | Returns: | 
					
						
						|  | Configuration dictionary | 
					
						
						|  | """ | 
					
						
						|  | config_path = Path("config.json") | 
					
						
						|  | if config_path.exists(): | 
					
						
						|  | with open(config_path, "r", encoding="utf-8") as f: | 
					
						
						|  | return json.load(f) | 
					
						
						|  | return DEFAULT_CONFIG | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None: | 
					
						
						|  | """Format and print JSON response data | 
					
						
						|  | Args: | 
					
						
						|  | data: Data dictionary to print | 
					
						
						|  | title: Title to print | 
					
						
						|  | indent: Number of spaces for JSON indentation | 
					
						
						|  | """ | 
					
						
						|  | if OutputControl.is_verbose(): | 
					
						
						|  | if title: | 
					
						
						|  | print(f"\n=== {title} ===") | 
					
						
						|  | print(json.dumps(data, ensure_ascii=False, indent=indent)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Active configuration, loaded once when this module is imported; the request
# helpers and tests below read their server settings and queries from it.
CONFIG: Dict[str, Any] = load_config()
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_base_url(endpoint: str = "chat") -> str: | 
					
						
						|  | """Return the base URL for specified endpoint | 
					
						
						|  | Args: | 
					
						
						|  | endpoint: API endpoint name (chat or generate) | 
					
						
						|  | Returns: | 
					
						
						|  | Complete URL for the endpoint | 
					
						
						|  | """ | 
					
						
						|  | server = CONFIG["server"] | 
					
						
						|  | return f"http://{server['host']}:{server['port']}/api/{endpoint}" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def create_chat_request_data( | 
					
						
						|  | content: str, | 
					
						
						|  | stream: bool = False, | 
					
						
						|  | model: str = None, | 
					
						
						|  | conversation_history: List[Dict[str, str]] = None, | 
					
						
						|  | ) -> Dict[str, Any]: | 
					
						
						|  | """Create chat request data | 
					
						
						|  | Args: | 
					
						
						|  | content: User message content | 
					
						
						|  | stream: Whether to use streaming response | 
					
						
						|  | model: Model name | 
					
						
						|  | conversation_history: List of previous conversation messages | 
					
						
						|  | history_turns: Number of history turns to include | 
					
						
						|  | Returns: | 
					
						
						|  | Dictionary containing complete chat request data | 
					
						
						|  | """ | 
					
						
						|  | messages = conversation_history or [] | 
					
						
						|  | messages.append({"role": "user", "content": content}) | 
					
						
						|  |  | 
					
						
						|  | return { | 
					
						
						|  | "model": model or CONFIG["server"]["model"], | 
					
						
						|  | "messages": messages, | 
					
						
						|  | "stream": stream, | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def create_generate_request_data( | 
					
						
						|  | prompt: str, | 
					
						
						|  | system: str = None, | 
					
						
						|  | stream: bool = False, | 
					
						
						|  | model: str = None, | 
					
						
						|  | options: Dict[str, Any] = None, | 
					
						
						|  | ) -> Dict[str, Any]: | 
					
						
						|  | """Create generate request data | 
					
						
						|  | Args: | 
					
						
						|  | prompt: Generation prompt | 
					
						
						|  | system: System prompt | 
					
						
						|  | stream: Whether to use streaming response | 
					
						
						|  | model: Model name | 
					
						
						|  | options: Additional options | 
					
						
						|  | Returns: | 
					
						
						|  | Dictionary containing complete generate request data | 
					
						
						|  | """ | 
					
						
						|  | data = { | 
					
						
						|  | "model": model or CONFIG["server"]["model"], | 
					
						
						|  | "prompt": prompt, | 
					
						
						|  | "stream": stream, | 
					
						
						|  | } | 
					
						
						|  | if system: | 
					
						
						|  | data["system"] = system | 
					
						
						|  | if options: | 
					
						
						|  | data["options"] = options | 
					
						
						|  | return data | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Module-wide result collector; run_test() records every test outcome here.
STATS = TestStats()
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def run_test(func: Callable, name: str) -> None: | 
					
						
						|  | """Run a test and record the results | 
					
						
						|  | Args: | 
					
						
						|  | func: Test function | 
					
						
						|  | name: Test name | 
					
						
						|  | """ | 
					
						
						|  | start_time = time.time() | 
					
						
						|  | try: | 
					
						
						|  | func() | 
					
						
						|  | duration = time.time() - start_time | 
					
						
						|  | STATS.add_result(TestResult(name, True, duration)) | 
					
						
						|  | except Exception as e: | 
					
						
						|  | duration = time.time() - start_time | 
					
						
						|  | STATS.add_result(TestResult(name, False, duration, str(e))) | 
					
						
						|  | raise | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def test_non_stream_chat() -> None: | 
					
						
						|  | """Test non-streaming call to /api/chat endpoint""" | 
					
						
						|  | url = get_base_url() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | data = create_chat_request_data( | 
					
						
						|  | CONFIG["test_cases"]["basic"]["query"], | 
					
						
						|  | stream=False, | 
					
						
						|  | conversation_history=EXAMPLE_CONVERSATION, | 
					
						
						|  | ) | 
					
						
						|  | response = make_request(url, data) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if OutputControl.is_verbose(): | 
					
						
						|  | print("\n=== Non-streaming call response ===") | 
					
						
						|  | response_json = response.json() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | print_json_response( | 
					
						
						|  | {"model": response_json["model"], "message": response_json["message"]}, | 
					
						
						|  | "Response content", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def test_stream_chat() -> None: | 
					
						
						|  | """Test streaming call to /api/chat endpoint | 
					
						
						|  |  | 
					
						
						|  | Use JSON Lines format to process streaming responses, each line is a complete JSON object. | 
					
						
						|  | Response format: | 
					
						
						|  | { | 
					
						
						|  | "model": "lightrag:latest", | 
					
						
						|  | "created_at": "2024-01-15T00:00:00Z", | 
					
						
						|  | "message": { | 
					
						
						|  | "role": "assistant", | 
					
						
						|  | "content": "Partial response content", | 
					
						
						|  | "images": null | 
					
						
						|  | }, | 
					
						
						|  | "done": false | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | The last message will contain performance statistics, with done set to true. | 
					
						
						|  | """ | 
					
						
						|  | url = get_base_url() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | data = create_chat_request_data( | 
					
						
						|  | CONFIG["test_cases"]["basic"]["query"], | 
					
						
						|  | stream=True, | 
					
						
						|  | conversation_history=EXAMPLE_CONVERSATION, | 
					
						
						|  | ) | 
					
						
						|  | response = make_request(url, data, stream=True) | 
					
						
						|  |  | 
					
						
						|  | if OutputControl.is_verbose(): | 
					
						
						|  | print("\n=== Streaming call response ===") | 
					
						
						|  | output_buffer = [] | 
					
						
						|  | try: | 
					
						
						|  | for line in response.iter_lines(): | 
					
						
						|  | if line: | 
					
						
						|  | try: | 
					
						
						|  |  | 
					
						
						|  | data = json.loads(line.decode("utf-8")) | 
					
						
						|  | if data.get("done", True): | 
					
						
						|  | if ( | 
					
						
						|  | "total_duration" in data | 
					
						
						|  | ): | 
					
						
						|  |  | 
					
						
						|  | break | 
					
						
						|  | else: | 
					
						
						|  | message = data.get("message", {}) | 
					
						
						|  | content = message.get("content", "") | 
					
						
						|  | if content: | 
					
						
						|  | output_buffer.append(content) | 
					
						
						|  | print( | 
					
						
						|  | content, end="", flush=True | 
					
						
						|  | ) | 
					
						
						|  | except json.JSONDecodeError: | 
					
						
						|  | print("Error decoding JSON from response line") | 
					
						
						|  | finally: | 
					
						
						|  | response.close() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | print() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def test_query_modes() -> None: | 
					
						
						|  | """Test different query mode prefixes | 
					
						
						|  |  | 
					
						
						|  | Supported query modes: | 
					
						
						|  | - /local: Local retrieval mode, searches only in highly relevant documents | 
					
						
						|  | - /global: Global retrieval mode, searches across all documents | 
					
						
						|  | - /naive: Naive mode, does not use any optimization strategies | 
					
						
						|  | - /hybrid: Hybrid mode (default), combines multiple strategies | 
					
						
						|  | - /mix: Mix mode | 
					
						
						|  |  | 
					
						
						|  | Each mode will return responses in the same format, but with different retrieval strategies. | 
					
						
						|  | """ | 
					
						
						|  | url = get_base_url() | 
					
						
						|  | modes = ["local", "global", "naive", "hybrid", "mix"] | 
					
						
						|  |  | 
					
						
						|  | for mode in modes: | 
					
						
						|  | if OutputControl.is_verbose(): | 
					
						
						|  | print(f"\n=== Testing /{mode} mode ===") | 
					
						
						|  | data = create_chat_request_data( | 
					
						
						|  | f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | response = make_request(url, data) | 
					
						
						|  | response_json = response.json() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | print_json_response( | 
					
						
						|  | {"model": response_json["model"], "message": response_json["message"]} | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def create_error_test_data(error_type: str) -> Dict[str, Any]: | 
					
						
						|  | """Create request data for error testing | 
					
						
						|  | Args: | 
					
						
						|  | error_type: Error type, supported: | 
					
						
						|  | - empty_messages: Empty message list | 
					
						
						|  | - invalid_role: Invalid role field | 
					
						
						|  | - missing_content: Missing content field | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | Request dictionary containing error data | 
					
						
						|  | """ | 
					
						
						|  | error_data = { | 
					
						
						|  | "empty_messages": {"model": "lightrag:latest", "messages": [], "stream": True}, | 
					
						
						|  | "invalid_role": { | 
					
						
						|  | "model": "lightrag:latest", | 
					
						
						|  | "messages": [{"invalid_role": "user", "content": "Test message"}], | 
					
						
						|  | "stream": True, | 
					
						
						|  | }, | 
					
						
						|  | "missing_content": { | 
					
						
						|  | "model": "lightrag:latest", | 
					
						
						|  | "messages": [{"role": "user"}], | 
					
						
						|  | "stream": True, | 
					
						
						|  | }, | 
					
						
						|  | } | 
					
						
						|  | return error_data.get(error_type, error_data["empty_messages"]) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def test_stream_error_handling() -> None: | 
					
						
						|  | """Test error handling for streaming responses | 
					
						
						|  |  | 
					
						
						|  | Test scenarios: | 
					
						
						|  | 1. Empty message list | 
					
						
						|  | 2. Message format error (missing required fields) | 
					
						
						|  |  | 
					
						
						|  | Error responses should be returned immediately without establishing a streaming connection. | 
					
						
						|  | The status code should be 4xx, and detailed error information should be returned. | 
					
						
						|  | """ | 
					
						
						|  | url = get_base_url() | 
					
						
						|  |  | 
					
						
						|  | if OutputControl.is_verbose(): | 
					
						
						|  | print("\n=== Testing streaming response error handling ===") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if OutputControl.is_verbose(): | 
					
						
						|  | print("\n--- Testing empty message list (streaming) ---") | 
					
						
						|  | data = create_error_test_data("empty_messages") | 
					
						
						|  | response = make_request(url, data, stream=True, check_status=False) | 
					
						
						|  | print(f"Status code: {response.status_code}") | 
					
						
						|  | if response.status_code != 200: | 
					
						
						|  | print_json_response(response.json(), "Error message") | 
					
						
						|  | response.close() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if OutputControl.is_verbose(): | 
					
						
						|  | print("\n--- Testing invalid role field (streaming) ---") | 
					
						
						|  | data = create_error_test_data("invalid_role") | 
					
						
						|  | response = make_request(url, data, stream=True, check_status=False) | 
					
						
						|  | print(f"Status code: {response.status_code}") | 
					
						
						|  | if response.status_code != 200: | 
					
						
						|  | print_json_response(response.json(), "Error message") | 
					
						
						|  | response.close() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if OutputControl.is_verbose(): | 
					
						
						|  | print("\n--- Testing missing content field (streaming) ---") | 
					
						
						|  | data = create_error_test_data("missing_content") | 
					
						
						|  | response = make_request(url, data, stream=True, check_status=False) | 
					
						
						|  | print(f"Status code: {response.status_code}") | 
					
						
						|  | if response.status_code != 200: | 
					
						
						|  | print_json_response(response.json(), "Error message") | 
					
						
						|  | response.close() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def test_error_handling() -> None: | 
					
						
						|  | """Test error handling for non-streaming responses | 
					
						
						|  |  | 
					
						
						|  | Test scenarios: | 
					
						
						|  | 1. Empty message list | 
					
						
						|  | 2. Message format error (missing required fields) | 
					
						
						|  |  | 
					
						
						|  | Error response format: | 
					
						
						|  | { | 
					
						
						|  | "detail": "Error description" | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | All errors should return appropriate HTTP status codes and clear error messages. | 
					
						
						|  | """ | 
					
						
						|  | url = get_base_url() | 
					
						
						|  |  | 
					
						
						|  | if OutputControl.is_verbose(): | 
					
						
						|  | print("\n=== Testing error handling ===") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if OutputControl.is_verbose(): | 
					
						
						|  | print("\n--- Testing empty message list ---") | 
					
						
						|  | data = create_error_test_data("empty_messages") | 
					
						
						|  | data["stream"] = False | 
					
						
						|  | response = make_request(url, data, check_status=False) | 
					
						
						|  | print(f"Status code: {response.status_code}") | 
					
						
						|  | print_json_response(response.json(), "Error message") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if OutputControl.is_verbose(): | 
					
						
						|  | print("\n--- Testing invalid role field ---") | 
					
						
						|  | data = create_error_test_data("invalid_role") | 
					
						
						|  | data["stream"] = False | 
					
						
						|  | response = make_request(url, data, check_status=False) | 
					
						
						|  | print(f"Status code: {response.status_code}") | 
					
						
						|  | print_json_response(response.json(), "Error message") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if OutputControl.is_verbose(): | 
					
						
						|  | print("\n--- Testing missing content field ---") | 
					
						
						|  | data = create_error_test_data("missing_content") | 
					
						
						|  | data["stream"] = False | 
					
						
						|  | response = make_request(url, data, check_status=False) | 
					
						
						|  | print(f"Status code: {response.status_code}") | 
					
						
						|  | print_json_response(response.json(), "Error message") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def test_non_stream_generate() -> None: | 
					
						
						|  | """Test non-streaming call to /api/generate endpoint""" | 
					
						
						|  | url = get_base_url("generate") | 
					
						
						|  | data = create_generate_request_data( | 
					
						
						|  | CONFIG["test_cases"]["generate"]["query"], stream=False | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | response = make_request(url, data) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if OutputControl.is_verbose(): | 
					
						
						|  | print("\n=== Non-streaming generate response ===") | 
					
						
						|  | response_json = response.json() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | print(json.dumps(response_json, ensure_ascii=False, indent=2)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def test_stream_generate() -> None: | 
					
						
						|  | """Test streaming call to /api/generate endpoint""" | 
					
						
						|  | url = get_base_url("generate") | 
					
						
						|  | data = create_generate_request_data( | 
					
						
						|  | CONFIG["test_cases"]["generate"]["query"], stream=True | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | response = make_request(url, data, stream=True) | 
					
						
						|  |  | 
					
						
						|  | if OutputControl.is_verbose(): | 
					
						
						|  | print("\n=== Streaming generate response ===") | 
					
						
						|  | output_buffer = [] | 
					
						
						|  | try: | 
					
						
						|  | for line in response.iter_lines(): | 
					
						
						|  | if line: | 
					
						
						|  | try: | 
					
						
						|  |  | 
					
						
						|  | data = json.loads(line.decode("utf-8")) | 
					
						
						|  | if data.get("done", True): | 
					
						
						|  | if ( | 
					
						
						|  | "total_duration" in data | 
					
						
						|  | ): | 
					
						
						|  | break | 
					
						
						|  | else: | 
					
						
						|  | content = data.get("response", "") | 
					
						
						|  | if content: | 
					
						
						|  | output_buffer.append(content) | 
					
						
						|  | print( | 
					
						
						|  | content, end="", flush=True | 
					
						
						|  | ) | 
					
						
						|  | except json.JSONDecodeError: | 
					
						
						|  | print("Error decoding JSON from response line") | 
					
						
						|  | finally: | 
					
						
						|  | response.close() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | print() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def test_generate_with_system() -> None: | 
					
						
						|  | """Test generate with system prompt""" | 
					
						
						|  | url = get_base_url("generate") | 
					
						
						|  | data = create_generate_request_data( | 
					
						
						|  | CONFIG["test_cases"]["generate"]["query"], | 
					
						
						|  | system="你是一个知识渊博的助手", | 
					
						
						|  | stream=False, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | response = make_request(url, data) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if OutputControl.is_verbose(): | 
					
						
						|  | print("\n=== Generate with system prompt response ===") | 
					
						
						|  | response_json = response.json() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | print_json_response( | 
					
						
						|  | { | 
					
						
						|  | "model": response_json["model"], | 
					
						
						|  | "response": response_json["response"], | 
					
						
						|  | "done": response_json["done"], | 
					
						
						|  | }, | 
					
						
						|  | "Response content", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def test_generate_error_handling() -> None: | 
					
						
						|  | """Test error handling for generate endpoint""" | 
					
						
						|  | url = get_base_url("generate") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if OutputControl.is_verbose(): | 
					
						
						|  | print("\n=== Testing empty prompt ===") | 
					
						
						|  | data = create_generate_request_data("", stream=False) | 
					
						
						|  | response = make_request(url, data, check_status=False) | 
					
						
						|  | print(f"Status code: {response.status_code}") | 
					
						
						|  | print_json_response(response.json(), "Error message") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if OutputControl.is_verbose(): | 
					
						
						|  | print("\n=== Testing invalid options ===") | 
					
						
						|  | data = create_generate_request_data( | 
					
						
						|  | CONFIG["test_cases"]["basic"]["query"], | 
					
						
						|  | options={"invalid_option": "value"}, | 
					
						
						|  | stream=False, | 
					
						
						|  | ) | 
					
						
						|  | response = make_request(url, data, check_status=False) | 
					
						
						|  | print(f"Status code: {response.status_code}") | 
					
						
						|  | print_json_response(response.json(), "Error message") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def test_generate_concurrent() -> None: | 
					
						
						|  | """Test concurrent generate requests""" | 
					
						
						|  | import asyncio | 
					
						
						|  | import aiohttp | 
					
						
						|  | from contextlib import asynccontextmanager | 
					
						
						|  |  | 
					
						
						|  | @asynccontextmanager | 
					
						
						|  | async def get_session(): | 
					
						
						|  | async with aiohttp.ClientSession() as session: | 
					
						
						|  | yield session | 
					
						
						|  |  | 
					
						
						|  | async def make_request(session, prompt: str, request_id: int): | 
					
						
						|  | url = get_base_url("generate") | 
					
						
						|  | data = create_generate_request_data(prompt, stream=False) | 
					
						
						|  | try: | 
					
						
						|  | async with session.post(url, json=data) as response: | 
					
						
						|  | if response.status != 200: | 
					
						
						|  | error_msg = ( | 
					
						
						|  | f"Request {request_id} failed with status {response.status}" | 
					
						
						|  | ) | 
					
						
						|  | if OutputControl.is_verbose(): | 
					
						
						|  | print(f"\n{error_msg}") | 
					
						
						|  | raise McpError(ErrorCode.InternalError, error_msg) | 
					
						
						|  | result = await response.json() | 
					
						
						|  | if "error" in result: | 
					
						
						|  | error_msg = ( | 
					
						
						|  | f"Request {request_id} returned error: {result['error']}" | 
					
						
						|  | ) | 
					
						
						|  | if OutputControl.is_verbose(): | 
					
						
						|  | print(f"\n{error_msg}") | 
					
						
						|  | raise McpError(ErrorCode.InternalError, error_msg) | 
					
						
						|  | return result | 
					
						
						|  | except Exception as e: | 
					
						
						|  | error_msg = f"Request {request_id} failed: {str(e)}" | 
					
						
						|  | if OutputControl.is_verbose(): | 
					
						
						|  | print(f"\n{error_msg}") | 
					
						
						|  | raise McpError(ErrorCode.InternalError, error_msg) | 
					
						
						|  |  | 
					
						
						|  | async def run_concurrent_requests(): | 
					
						
						|  | prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"] | 
					
						
						|  |  | 
					
						
						|  | async with get_session() as session: | 
					
						
						|  | tasks = [ | 
					
						
						|  | make_request(session, prompt, i + 1) for i, prompt in enumerate(prompts) | 
					
						
						|  | ] | 
					
						
						|  | results = await asyncio.gather(*tasks, return_exceptions=True) | 
					
						
						|  |  | 
					
						
						|  | success_results = [] | 
					
						
						|  | error_messages = [] | 
					
						
						|  |  | 
					
						
						|  | for i, result in enumerate(results): | 
					
						
						|  | if isinstance(result, Exception): | 
					
						
						|  | error_messages.append(f"Request {i+1} failed: {str(result)}") | 
					
						
						|  | else: | 
					
						
						|  | success_results.append((i + 1, result)) | 
					
						
						|  |  | 
					
						
						|  | if error_messages: | 
					
						
						|  | for req_id, result in success_results: | 
					
						
						|  | if OutputControl.is_verbose(): | 
					
						
						|  | print(f"\nRequest {req_id} succeeded:") | 
					
						
						|  | print_json_response(result) | 
					
						
						|  |  | 
					
						
						|  | error_summary = "\n".join(error_messages) | 
					
						
						|  | raise McpError( | 
					
						
						|  | ErrorCode.InternalError, | 
					
						
						|  | f"Some concurrent requests failed:\n{error_summary}", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | return results | 
					
						
						|  |  | 
					
						
						|  | if OutputControl.is_verbose(): | 
					
						
						|  | print("\n=== Testing concurrent generate requests ===") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | results = asyncio.run(run_concurrent_requests()) | 
					
						
						|  |  | 
					
						
						|  | for i, result in enumerate(results, 1): | 
					
						
						|  | print(f"\nRequest {i} result:") | 
					
						
						|  | print_json_response(result) | 
					
						
						|  | except McpError: | 
					
						
						|  |  | 
					
						
						|  | raise | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_test_cases() -> Dict[str, Callable]: | 
					
						
						|  | """Get all available test cases | 
					
						
						|  | Returns: | 
					
						
						|  | A dictionary mapping test names to test functions | 
					
						
						|  | """ | 
					
						
						|  | return { | 
					
						
						|  | "non_stream": test_non_stream_chat, | 
					
						
						|  | "stream": test_stream_chat, | 
					
						
						|  | "modes": test_query_modes, | 
					
						
						|  | "errors": test_error_handling, | 
					
						
						|  | "stream_errors": test_stream_error_handling, | 
					
						
						|  | "non_stream_generate": test_non_stream_generate, | 
					
						
						|  | "stream_generate": test_stream_generate, | 
					
						
						|  | "generate_with_system": test_generate_with_system, | 
					
						
						|  | "generate_errors": test_generate_error_handling, | 
					
						
						|  | "generate_concurrent": test_generate_concurrent, | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def create_default_config(): | 
					
						
						|  | """Create a default configuration file""" | 
					
						
						|  | config_path = Path("config.json") | 
					
						
						|  | if not config_path.exists(): | 
					
						
						|  | with open(config_path, "w", encoding="utf-8") as f: | 
					
						
						|  | json.dump(DEFAULT_CONFIG, f, ensure_ascii=False, indent=2) | 
					
						
						|  | print(f"Default configuration file created: {config_path}") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def parse_args() -> argparse.Namespace: | 
					
						
						|  | """Parse command line arguments""" | 
					
						
						|  | parser = argparse.ArgumentParser( | 
					
						
						|  | description="LightRAG Ollama Compatibility Interface Testing", | 
					
						
						|  | formatter_class=argparse.RawDescriptionHelpFormatter, | 
					
						
						|  | epilog=""" | 
					
						
						|  | Configuration file (config.json): | 
					
						
						|  | { | 
					
						
						|  | "server": { | 
					
						
						|  | "host": "localhost",      # Server address | 
					
						
						|  | "port": 9621,            # Server port | 
					
						
						|  | "model": "lightrag:latest" # Default model name | 
					
						
						|  | }, | 
					
						
						|  | "test_cases": { | 
					
						
						|  | "basic": { | 
					
						
						|  | "query": "Test query",      # Basic query text | 
					
						
						|  | "stream_query": "Stream query" # Stream query text | 
					
						
						|  | } | 
					
						
						|  | } | 
					
						
						|  | } | 
					
						
						|  | """, | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "-q", | 
					
						
						|  | "--quiet", | 
					
						
						|  | action="store_true", | 
					
						
						|  | help="Silent mode, only display test result summary", | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "-a", | 
					
						
						|  | "--ask", | 
					
						
						|  | type=str, | 
					
						
						|  | help="Specify query content, which will override the query settings in the configuration file", | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "--init-config", action="store_true", help="Create default configuration file" | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "--output", | 
					
						
						|  | type=str, | 
					
						
						|  | default="", | 
					
						
						|  | help="Test result output file path, default is not to output to a file", | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "--tests", | 
					
						
						|  | nargs="+", | 
					
						
						|  | choices=list(get_test_cases().keys()) + ["all"], | 
					
						
						|  | default=["all"], | 
					
						
						|  | help="Test cases to run, options: %(choices)s. Use 'all' to run all tests (except error tests)", | 
					
						
						|  | ) | 
					
						
						|  | return parser.parse_args() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if __name__ == "__main__": | 
					
						
						|  | args = parse_args() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | OutputControl.set_verbose(not args.quiet) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if args.ask: | 
					
						
						|  | CONFIG["test_cases"]["basic"]["query"] = args.ask | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if args.init_config: | 
					
						
						|  | create_default_config() | 
					
						
						|  | exit(0) | 
					
						
						|  |  | 
					
						
						|  | test_cases = get_test_cases() | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | if "all" in args.tests: | 
					
						
						|  |  | 
					
						
						|  | if OutputControl.is_verbose(): | 
					
						
						|  | print("\n【Chat API Tests】") | 
					
						
						|  | run_test(test_non_stream_chat, "Non-streaming Chat Test") | 
					
						
						|  | run_test(test_stream_chat, "Streaming Chat Test") | 
					
						
						|  | run_test(test_query_modes, "Chat Query Mode Test") | 
					
						
						|  |  | 
					
						
						|  | if OutputControl.is_verbose(): | 
					
						
						|  | print("\n【Generate API Tests】") | 
					
						
						|  | run_test(test_non_stream_generate, "Non-streaming Generate Test") | 
					
						
						|  | run_test(test_stream_generate, "Streaming Generate Test") | 
					
						
						|  | run_test(test_generate_with_system, "Generate with System Prompt Test") | 
					
						
						|  | run_test(test_generate_concurrent, "Generate Concurrent Test") | 
					
						
						|  | else: | 
					
						
						|  |  | 
					
						
						|  | for test_name in args.tests: | 
					
						
						|  | if OutputControl.is_verbose(): | 
					
						
						|  | print(f"\n【Running Test: {test_name}】") | 
					
						
						|  | run_test(test_cases[test_name], test_name) | 
					
						
						|  | except Exception as e: | 
					
						
						|  | print(f"\nAn error occurred: {str(e)}") | 
					
						
						|  | finally: | 
					
						
						|  |  | 
					
						
						|  | STATS.print_summary() | 
					
						
						|  |  | 
					
						
						|  | if args.output: | 
					
						
						|  | STATS.export_results(args.output) | 
					
						
						|  |  |