|
|
|
|
|
|
|
"""Run pytest using MCP.""" |
|
|
|
import argparse |
|
import time |
|
|
|
from mcli.sdk import (RunConfig, RunStatus, create_run, follow_run_logs, |
|
stop_run, wait_for_run_status) |
|
|
|
if __name__ == '__main__': |
|
|
|
parser = argparse.ArgumentParser() |
|
parser.add_argument('--name', |
|
type=str, |
|
default='mcp-pytest', |
|
help='Base name of run') |
|
parser.add_argument('--cluster', |
|
type=str, |
|
default='r1z4', |
|
help='Cluster to use') |
|
parser.add_argument('--gpu_type', |
|
type=str, |
|
default='a100_40gb', |
|
help='Type of GPU to use') |
|
parser.add_argument('--gpu_num', |
|
type=int, |
|
default=2, |
|
help='Number of the GPU to use') |
|
parser.add_argument('--image', |
|
type=str, |
|
default='mosaicml/pytorch:latest', |
|
help='Docker image to use') |
|
parser.add_argument('--git_branch', |
|
type=str, |
|
help='Git branch to check out') |
|
parser.add_argument( |
|
'--git_commit', |
|
type=str, |
|
help='Git commit to check out. Overrides git_branch if specified') |
|
parser.add_argument( |
|
'--pr_number', |
|
type=int, |
|
help= |
|
'PR number to check out. Overrides git_branch/git_commit if specified') |
|
parser.add_argument('--pytest_markers', |
|
type=str, |
|
help='Markers to pass to pytest') |
|
parser.add_argument('--pytest_command', |
|
type=str, |
|
help='Command to run pytest') |
|
parser.add_argument('--timeout', |
|
type=int, |
|
default=1800, |
|
help='Timeout for run (in seconds)') |
|
args = parser.parse_args() |
|
|
|
name = args.name |
|
git_integration = { |
|
'integration_type': 'git_repo', |
|
'git_repo': 'mosaicml/llm-foundry', |
|
'ssh_clone': 'False', |
|
} |
|
if args.git_branch is not None and args.git_commit is None: |
|
name += f'-branch-{args.git_branch}' |
|
git_integration['git_branch'] = args.git_branch |
|
if args.git_commit is not None: |
|
name += f'-commit-{args.git_commit}' |
|
git_integration['git_commit'] = args.git_commit |
|
|
|
command = 'cd llm-foundry' |
|
|
|
|
|
if args.pr_number is not None: |
|
name += f'-pr-{args.pr_number}' |
|
command += f''' |
|
|
|
git fetch origin pull/{args.pr_number}/head:pr_branch |
|
|
|
git checkout pr_branch |
|
|
|
''' |
|
|
|
|
|
if len(name) > 56: |
|
name = name[:56] |
|
|
|
command += f''' |
|
|
|
pip install --upgrade --user .[all] |
|
|
|
export COMMON_ARGS="-v --durations=20 -m '{args.pytest_markers}'" |
|
|
|
make test PYTEST='{args.pytest_command}' EXTRA_ARGS="$COMMON_ARGS --codeblocks" |
|
|
|
make test-dist PYTEST='{args.pytest_command}' EXTRA_ARGS="$COMMON_ARGS" WORLD_SIZE=2 |
|
|
|
python -m coverage combine |
|
|
|
python -m coverage report |
|
''' |
|
|
|
config = RunConfig( |
|
name=name, |
|
cluster=args.cluster, |
|
gpu_type=args.gpu_type, |
|
gpu_num=args.gpu_num, |
|
image=args.image, |
|
integrations=[git_integration], |
|
command=command, |
|
) |
|
|
|
|
|
run = create_run(config) |
|
print(f'[GHA] Run created: {run.name}') |
|
|
|
|
|
run = wait_for_run_status(run, status='running') |
|
start_time = time.time() |
|
print('[GHA] Run started. Following logs...') |
|
|
|
|
|
for line in follow_run_logs(run): |
|
print(line, end='') |
|
|
|
if time.time() - start_time > args.timeout: |
|
print( |
|
f'[GHA] Run timed out and did not complete in {args.timeout/60} minutes.' |
|
) |
|
run = stop_run(run) |
|
print('[GHA] Run stopped.') |
|
break |
|
|
|
print('[GHA] Run completed. Waiting for run to finish...') |
|
run = wait_for_run_status(run, status='completed') |
|
|
|
|
|
assert run.status == RunStatus.COMPLETED |
|
|