Andromeda / .github /mcp /mcp_pytest.py
kye's picture
Upload 73 files
ca4fc4d
raw
history blame
4.42 kB
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
"""Run pytest using MCP."""
import argparse
import time
from mcli.sdk import (RunConfig, RunStatus, create_run, follow_run_logs,
stop_run, wait_for_run_status)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--name',
type=str,
default='mcp-pytest',
help='Base name of run')
parser.add_argument('--cluster',
type=str,
default='r1z4',
help='Cluster to use')
parser.add_argument('--gpu_type',
type=str,
default='a100_40gb',
help='Type of GPU to use')
parser.add_argument('--gpu_num',
type=int,
default=2,
help='Number of the GPU to use')
parser.add_argument('--image',
type=str,
default='mosaicml/pytorch:latest',
help='Docker image to use')
parser.add_argument('--git_branch',
type=str,
help='Git branch to check out')
parser.add_argument(
'--git_commit',
type=str,
help='Git commit to check out. Overrides git_branch if specified')
parser.add_argument(
'--pr_number',
type=int,
help=
'PR number to check out. Overrides git_branch/git_commit if specified')
parser.add_argument('--pytest_markers',
type=str,
help='Markers to pass to pytest')
parser.add_argument('--pytest_command',
type=str,
help='Command to run pytest')
parser.add_argument('--timeout',
type=int,
default=1800,
help='Timeout for run (in seconds)')
args = parser.parse_args()
name = args.name
git_integration = {
'integration_type': 'git_repo',
'git_repo': 'mosaicml/llm-foundry',
'ssh_clone': 'False',
}
if args.git_branch is not None and args.git_commit is None:
name += f'-branch-{args.git_branch}'
git_integration['git_branch'] = args.git_branch
if args.git_commit is not None:
name += f'-commit-{args.git_commit}'
git_integration['git_commit'] = args.git_commit
command = 'cd llm-foundry'
# Checkout a specific PR if specified
if args.pr_number is not None:
name += f'-pr-{args.pr_number}'
command += f'''
git fetch origin pull/{args.pr_number}/head:pr_branch
git checkout pr_branch
'''
# Shorten name if too long
if len(name) > 56:
name = name[:56]
command += f'''
pip install --upgrade --user .[all]
export COMMON_ARGS="-v --durations=20 -m '{args.pytest_markers}'"
make test PYTEST='{args.pytest_command}' EXTRA_ARGS="$COMMON_ARGS --codeblocks"
make test-dist PYTEST='{args.pytest_command}' EXTRA_ARGS="$COMMON_ARGS" WORLD_SIZE=2
python -m coverage combine
python -m coverage report
'''
config = RunConfig(
name=name,
cluster=args.cluster,
gpu_type=args.gpu_type,
gpu_num=args.gpu_num,
image=args.image,
integrations=[git_integration],
command=command,
)
# Create run
run = create_run(config)
print(f'[GHA] Run created: {run.name}')
# Wait until run starts before fetching logs
run = wait_for_run_status(run, status='running')
start_time = time.time()
print('[GHA] Run started. Following logs...')
# Print logs
for line in follow_run_logs(run):
print(line, end='')
# Check if args.timeout seconds have elapsed
if time.time() - start_time > args.timeout:
print(
f'[GHA] Run timed out and did not complete in {args.timeout/60} minutes.'
)
run = stop_run(run)
print('[GHA] Run stopped.')
break
print('[GHA] Run completed. Waiting for run to finish...')
run = wait_for_run_status(run, status='completed')
# Fail if command exited with non-zero exit code or timed out
assert run.status == RunStatus.COMPLETED