#!/usr/bin/env python
# Sample usage:
#   python env-setup.py --version 1.5 --apt-packages libomp5
"""Install PyTorch/XLA wheels on Colab, Kaggle or a GCE VM.

Copies the torch, torch_xla and torchvision wheels for the requested version
from the GCS wheel bucket, installs them with pip, installs extra apt packages
and, on a TPU runtime, updates the TPU runtime version in a background thread.
"""

import argparse
import collections
from datetime import datetime
import os
import platform
import re
import subprocess
import sys
import threading

try:
  # NOTE(review): `requests` is not used anywhere in this script. It stays
  # importable for backward compatibility, but a missing package no longer
  # aborts the setup before anything is installed. Consider removing it.
  import requests  # noqa: F401
except ImportError:
  requests = None

# wheels:       version string embedded in the wheel file names.
# tpu:          TPU runtime version passed to cloud_tpu_client.
# py_version:   major + minor Python version without a dot, e.g. '37'.
# cuda_version: CUDA version of the wheels to install, or None.
VersionConfig = collections.namedtuple(
    'VersionConfig', ['wheels', 'tpu', 'py_version', 'cuda_version'])
DEFAULT_CUDA_VERSION = '10.2'
# Oldest dated nightly builds available without / with CUDA.
OLDEST_VERSION = datetime.strptime('20200318', '%Y%m%d')
OLDEST_GPU_VERSION = datetime.strptime('20200707', '%Y%m%d')
DIST_BUCKET = 'gs://tpu-pytorch/wheels'
TORCH_WHEEL_TMPL = 'torch-{whl_version}-cp{py_version}-cp{py_version}m-linux_x86_64.whl'
TORCH_XLA_WHEEL_TMPL = 'torch_xla-{whl_version}-cp{py_version}-cp{py_version}m-linux_x86_64.whl'
TORCHVISION_WHEEL_TMPL = 'torchvision-{whl_version}-cp{py_version}-cp{py_version}m-linux_x86_64.whl'
# Release versions look like '1.5' or '1.5.1'. Raw string: '\d' is not a
# valid escape in a normal string literal.
RELEASE_VERSION_REGEX = re.compile(r'^(\d+\.)+\d+$')


def is_gpu_runtime():
  """Return True when running on a Colab GPU runtime.

  Environment variable values are always strings, so the flag must be
  compared with '1' (comparing with the integer 1 can never match).
  """
  return os.environ.get('COLAB_GPU', '0').strip() == '1'


def is_tpu_runtime():
  """Return True when a TPU is attached (the TPU_NAME variable is set)."""
  return 'TPU_NAME' in os.environ


def update_tpu_runtime(tpu_name, version):
  """Switch the TPU `tpu_name` to runtime version `version.tpu`.

  Installs the cloud-tpu-client package with pip if it is not importable.
  """
  print(f'Updating TPU runtime to {version.tpu} ...')

  try:
    import cloud_tpu_client
  except ImportError:
    subprocess.call([sys.executable, '-m', 'pip', 'install', 'cloud-tpu-client'])
    import cloud_tpu_client

  client = cloud_tpu_client.Client(tpu_name)
  client.configure_tpu_version(version.tpu)
  print('Done updating TPU runtime')


def get_py_version():
  """Return the running interpreter's major+minor version, e.g. '37'."""
  version_tuple = platform.python_version_tuple()
  return version_tuple[0] + version_tuple[1]  # major_version + minor_version


def get_cuda_version():
  """Return the CUDA version of the wheels to install, or None without a GPU."""
  if is_gpu_runtime():
    # cuda available, install cuda wheels
    return DEFAULT_CUDA_VERSION
  return None


def get_version(version):
  """Resolve a --version argument into a VersionConfig.

  Args:
    version: 'nightly', a dated nightly in 'YYYYMMDD' form, or a release
      version such as '1.5'.

  Returns:
    A VersionConfig for the running Python and (if any) CUDA version.

  Raises:
    ValueError: if a dated nightly is older than the oldest build available
      for this runtime, or if `version` matches none of the accepted forms.
  """
  cuda_version = get_cuda_version()
  py_version = get_py_version()

  if version == 'nightly':
    return VersionConfig('nightly', 'pytorch-nightly', py_version, cuda_version)

  try:
    version_date = datetime.strptime(version, '%Y%m%d')
  except ValueError:
    version_date = None  # Not a dated nightly.

  if version_date is not None:
    if cuda_version and version_date < OLDEST_GPU_VERSION:
      raise ValueError(
          f'Oldest nightly version build with CUDA available is {OLDEST_GPU_VERSION}')
    elif not cuda_version and version_date < OLDEST_VERSION:
      raise ValueError(f'Oldest nightly version available is {OLDEST_VERSION}')
    return VersionConfig(f'nightly+{version}', f'pytorch-dev{version}',
                         py_version, cuda_version)

  if not RELEASE_VERSION_REGEX.match(version):
    raise ValueError(f'{version} is an invalid torch_xla version pattern')

  return VersionConfig(version, f'pytorch-{version}', py_version, cuda_version)


def get_wheel_name(template, version):
  """Format a wheel file-name template for the given VersionConfig.

  The templates carry the pre-3.8 'm' (pymalloc) ABI flag, e.g. 'cp37-cp37m'.
  CPython 3.8 removed that flag, so for 3.8+ the ABI tag is e.g. 'cp38'.
  """
  py_version = version.py_version
  name = template.format(whl_version=version.wheels, py_version=py_version)
  major, minor = int(py_version[0]), int(py_version[1:])
  if (major, minor) >= (3, 8):
    name = name.replace(f'-cp{py_version}m-', f'-cp{py_version}-')
  return name


def install_vm(version, apt_packages, is_root=False):
  """Download and install the wheels for `version`, then the apt packages.

  Commands are run best-effort in order: uninstall the existing torch and
  torchvision, `gsutil cp` each wheel, `pip install` each wheel, then
  `apt-get install`. A failing command is reported but does not stop the
  remaining ones.

  Args:
    version: VersionConfig selecting the wheels.
    apt_packages: names passed to `apt-get install -y`.
    is_root: if False, the apt-get command is prefixed with `sudo`.
  """
  dist_bucket = DIST_BUCKET
  if version.cuda_version:
    dist_bucket = os.path.join(
        DIST_BUCKET, 'cuda/{}'.format(version.cuda_version.replace('.', '')))

  wheels = [
      get_wheel_name(template, version) for template in
      (TORCH_WHEEL_TMPL, TORCH_XLA_WHEEL_TMPL, TORCHVISION_WHEEL_TMPL)
  ]

  apt_cmd = ['apt-get', 'install', '-y']
  apt_cmd.extend(apt_packages or [])
  if not is_root:
    # Colab/Kaggle run as root, but not GCE VMs so we need privilege
    apt_cmd.insert(0, 'sudo')

  installation_cmds = [
      [sys.executable, '-m', 'pip', 'uninstall', '-y', 'torch', 'torchvision'],
  ]
  installation_cmds += [
      ['gsutil', 'cp', os.path.join(dist_bucket, whl), '.'] for whl in wheels
  ]
  installation_cmds += [
      [sys.executable, '-m', 'pip', 'install', whl] for whl in wheels
  ]
  installation_cmds.append(apt_cmd)

  for cmd in installation_cmds:
    returncode = subprocess.call(cmd)
    if returncode != 0:
      # Best-effort: keep going, but do not fail silently.
      print(f'WARNING: command exited with status {returncode}: {" ".join(cmd)}')


def run_setup(args):
  """Install the requested version, updating the TPU runtime concurrently."""
  version = get_version(args.version)
  # Update TPU
  print('Updating... This may take around 2 minutes.')

  update = None
  if is_tpu_runtime():
    # The TPU runtime update does not depend on the local install, so the two
    # run concurrently.
    update = threading.Thread(
        target=update_tpu_runtime, args=(args.tpu, version))
    update.start()

  # --tpu is documented as a GCP option; the original design assumes such
  # VMs are not root (see install_vm), while Colab/Kaggle runs omit it.
  install_vm(version, args.apt_packages, is_root=not args.tpu)

  if update is not None:
    update.join()


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--version',
      type=str,
      default='20200515',
      help='Versions to install (nightly, release version, or YYYYMMDD).',
  )
  parser.add_argument(
      '--apt-packages',
      nargs='+',
      default=['libomp5'],
      help='List of apt packages to install',
  )
  parser.add_argument(
      '--tpu',
      type=str,
      help='[GCP] Name of the TPU (same zone, project as VM running script)',
  )
  args = parser.parse_args()
  run_setup(args)