File size: 5,977 Bytes
88aba71 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import pytest
from unittest import mock
import sys
import os
import shutil
import functools
import subprocess
import time
from typing import Union, Optional, cast
from weclone.utils.log import logger
# Repository root: the parent of the directory that contains this test module.
PROJECT_ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))

# Put the repository root first on the import path.
# NOTE(review): the ``weclone`` import above executes before this insertion, so
# it cannot rely on it -- confirm the package is importable some other way.
sys.path.insert(0, PROJECT_ROOT_DIR)

# Handle of the server started in the background by one test and stopped by a
# later one; ``None`` until a server has been launched.
server_process: Optional[subprocess.Popen] = None

# Dedicated logger for the phase banners: bold yellow, message text only.
# NOTE(review): if ``logger`` is a loguru logger, ``bind()`` shares handlers with
# the original, so ``remove()`` would drop the project-wide handlers too -- confirm.
test_logger = logger.bind()
test_logger.remove()
test_logger.add(
    sys.stderr,
    level="INFO",
    colorize=True,
    format="<yellow><b>{message}</b></yellow>",
)
def print_test_header(test_name: str):
    """Log a three-line banner announcing the test phase about to run.

    The banner is a rule of ``─`` characters (preceded by a newline), the
    phase title centred in a 100-column field, and a closing rule. When the
    padding cannot be split evenly, the extra space goes on the right.

    Args:
        test_name: Name of the phase shown in the banner title.
    """
    width = 100
    rule = "─" * width
    heading = f" Testing Phase: {test_name} "

    # divmod gives the left padding and the odd leftover column in one step.
    left_pad, leftover = divmod(width - len(heading), 2)

    test_logger.info("\n" + rule)
    test_logger.info(" " * left_pad + heading + " " * (left_pad + leftover))
    test_logger.info(rule)
def setup_make_dataset_test_data():
    """Seed ``<project root>/dataset/csv`` with the bundled sample CSV files.

    The project root is the parent of this file's directory, and the samples
    are read from ``tests_data/test_person`` next to this file. The dataset
    directory is created if missing. Samples are copied only when that
    directory holds no real data, i.e. it is empty or contains nothing but
    dot-files and a ``readme.md`` (matched case-insensitively). Only regular
    files whose name ends in ``.csv`` (case-insensitive) are copied.

    Raises:
        FileNotFoundError: If seeding is attempted and the sample directory
            does not exist.
    """
    tests_dir = os.path.dirname(__file__)
    project_root = os.path.abspath(os.path.join(tests_dir, '..'))
    dataset_csv_dir = os.path.join(project_root, "dataset", "csv")
    sample_dir = os.path.join(tests_dir, "tests_data", "test_person")

    os.makedirs(dataset_csv_dir, exist_ok=True)

    # all() is True for an empty listing, so a freshly created (empty)
    # directory is seeded as well; previously it was silently skipped.
    only_placeholders = all(
        entry.startswith('.') or entry.lower() == 'readme.md'
        for entry in os.listdir(dataset_csv_dir)
    )
    if not only_placeholders:
        # Real data is present; leave it untouched.
        return

    for item_name in os.listdir(sample_dir):
        source_item_path = os.path.join(sample_dir, item_name)
        if os.path.isfile(source_item_path) and item_name.lower().endswith('.csv'):
            shutil.copy2(source_item_path, os.path.join(dataset_csv_dir, item_name))
def run_cli_command(command: list[str], timeout: int | None = None, background: bool = False) -> Union[subprocess.CompletedProcess, subprocess.Popen]:
    """Invoke ``python -m weclone.cli <command>`` from the project root.

    The child process receives a copy of the current environment with
    ``WECLONE_CONFIG_PATH`` set to ``tests/full_pipe.jsonc``.

    Args:
        command: CLI arguments appended after ``-m weclone.cli``.
        timeout: Seconds to wait for a foreground run. Not used when
            ``background`` is true.
        background: If true, start the process and return without waiting
            for it to exit.

    Returns:
        With ``background=True``, the ``Popen`` handle (stdout and stderr
        piped, text mode), returned after a fixed two-second pause. Otherwise
        the ``CompletedProcess`` from ``subprocess.run``, with the child's
        output inherited from this process.

    Raises:
        subprocess.TimeoutExpired: If a foreground run exceeds ``timeout``.
    """
    child_env = {**os.environ, "WECLONE_CONFIG_PATH": "tests/full_pipe.jsonc"}
    cli_argv = [sys.executable, "-m", "weclone.cli"] + command

    if not background:
        # Foreground: block until the command exits (or the timeout fires).
        return subprocess.run(
            cli_argv,
            stdout=None,
            stderr=None,
            text=True,
            cwd=PROJECT_ROOT_DIR,
            timeout=timeout,
            env=child_env,
        )

    # NOTE(review): both streams are piped, but nothing in this function reads
    # them; a chatty child could block once the pipe buffer fills -- confirm
    # callers drain or do not need the output.
    handle = subprocess.Popen(
        cli_argv,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        cwd=PROJECT_ROOT_DIR,
        env=child_env,
    )
    # Fixed pause before handing the process back to the caller.
    time.sleep(2)
    return handle
@pytest.mark.order(1)
def test_cli_make_dataset():
    """Run ``make-dataset`` on the seeded sample data and expect exit code 0."""
    print_test_header("make-dataset")

    # Make sure sample CSV files are in place before the command runs.
    setup_make_dataset_test_data()

    completed = run_cli_command(["make-dataset"])
    exit_code = completed.returncode
    assert exit_code == 0, "make-dataset command execution failed"
@pytest.mark.order(2)
def test_cli_train_sft():
    """Run the ``train-sft`` command and expect exit code 0.

    A ``subprocess.TimeoutExpired`` raised while running the command is
    tolerated and only logged. Any other exception raised while running the
    command fails the test with an explanatory message.
    """
    print_test_header("train-sft")
    try:
        result = run_cli_command(["train-sft"])
    except subprocess.TimeoutExpired:
        test_logger.info("train-sft command terminated due to timeout, which is acceptable in testing, indicating the command has started execution.")
    except Exception as e:
        pytest.fail(f"An unexpected error occurred during train-sft command execution: {e}")
    else:
        # Checked outside the ``try`` so a non-zero exit code is reported as an
        # assertion failure rather than being caught by the broad handler above
        # and re-labelled as an "unexpected error".
        assert result.returncode == 0, "train-sft command failed or did not fail fast as expected"
@pytest.mark.order(3)
def test_cli_webchat_demo():
    """Smoke-test the ``webchat-demo`` command with a five-second limit.

    The test passes if the command exits with code 0 within five seconds, or
    if it is still running when the timeout fires.
    """
    print_test_header("webchat-demo")
    # NOTE(review): this patch replaces ``main`` in the current process only,
    # whereas run_cli_command launches the CLI in a child process -- confirm
    # the patch has the intended effect.
    with mock.patch("weclone.eval.web_demo.main") as patched_main:
        patched_main.return_value = None
        try:
            completed = run_cli_command(["webchat-demo"], timeout=5)
        except subprocess.TimeoutExpired:
            # Still running after the limit: treated as a successful start.
            pass
        else:
            assert completed.returncode == 0, "webchat-demo command execution failed"
@pytest.mark.order(4)
def test_cli_server():
    """Launch the ``server`` command in the background.

    The process handle is stored in the module-level ``server_process`` so a
    later test can use the running server and shut it down. The test passes
    if the process is still alive once ``run_cli_command`` returns.
    """
    global server_process
    print_test_header("server (background)")

    launched = cast(subprocess.Popen, run_cli_command(["server"], background=True))
    server_process = launched

    # poll() returns None while the process is still running.
    still_running = launched.poll() is None
    assert still_running, "Server startup failed"
    test_logger.info("服务器已在后台启动")
@pytest.mark.order(5)
def test_cli_test_model():
    """Run the ``test-model`` command, then stop the background server.

    The command must exit with code 0. Whether or not it does, a server held
    in the module-level ``server_process`` that is still running is
    terminated afterwards, and force-killed if it has not exited within five
    seconds.
    """
    global server_process
    print_test_header("test-model")
    try:
        result = run_cli_command(["test-model"])
        assert result.returncode == 0, "test-model command execution failed"
    finally:
        if server_process is not None and server_process.poll() is None:
            test_logger.info("测试完成,正在关闭服务器...")
            server_process.terminate()
            try:
                server_process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                # wait() raises when the deadline passes, so the force-kill
                # has to live here; otherwise the exception would escape the
                # ``finally`` block and leave the server running.
                server_process.kill()
                server_process.wait()  # reap the killed process
            test_logger.info("服务器已关闭")
|