Update pytorch-xla-env-setup.py
Browse files- pytorch-xla-env-setup.py +161 -0
pytorch-xla-env-setup.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# Sample usage:
|
3 |
+
# python pytorch-xla-env-setup.py --version 1.5 --apt-packages libomp5
|
4 |
+
import argparse
|
5 |
+
import collections
|
6 |
+
from datetime import datetime
|
7 |
+
import os
|
8 |
+
import platform
|
9 |
+
import re
|
10 |
+
import requests
|
11 |
+
import subprocess
|
12 |
+
import threading
|
13 |
+
import sys
|
14 |
+
|
15 |
+
# Resolved install target: wheel version tag, TPU runtime version name,
# CPython tag (major + minor digits) and CUDA version (None for non-GPU).
VersionConfig = collections.namedtuple(
    'VersionConfig', 'wheels tpu py_version cuda_version')

# CUDA version reported by get_cuda_version() when a GPU runtime is detected.
DEFAULT_CUDA_VERSION = '10.2'

# Earliest dated nightlies accepted by get_version(), without / with CUDA.
OLDEST_VERSION = datetime(2020, 3, 18)
OLDEST_GPU_VERSION = datetime(2020, 7, 7)

# Root location the wheels are copied from, and the wheel file-name templates
# filled in by install_vm().
DIST_BUCKET = 'gs://tpu-pytorch/wheels'
TORCH_WHEEL_TMPL = (
    'torch-{whl_version}-cp{py_version}-cp{py_version}m-linux_x86_64.whl')
TORCH_XLA_WHEEL_TMPL = (
    'torch_xla-{whl_version}-cp{py_version}-cp{py_version}m-linux_x86_64.whl')
TORCHVISION_WHEEL_TMPL = (
    'torchvision-{whl_version}-cp{py_version}-cp{py_version}m-linux_x86_64.whl')
|
24 |
+
|
25 |
+
|
26 |
+
def is_gpu_runtime():
    """Return True when the COLAB_GPU environment variable equals 1.

    Environment values are always strings, so the value is parsed as an
    integer before comparing. A missing or non-numeric value counts as
    "no GPU".
    """
    try:
        return int(os.environ.get('COLAB_GPU', 0)) == 1
    except ValueError:
        # Unparseable value: treat it as "not a GPU runtime" instead of crashing.
        return False
|
28 |
+
|
29 |
+
|
30 |
+
def is_tpu_runtime():
    """Return True when the TPU_NAME environment variable is defined."""
    tpu_name = os.environ.get('TPU_NAME')
    return tpu_name is not None
|
32 |
+
|
33 |
+
|
34 |
+
def update_tpu_runtime(tpu_name, version):
    """Switch the TPU runtime to the version named by ``version.tpu``.

    Installs the ``cloud-tpu-client`` package with pip when it cannot be
    imported, then asks the client to configure the TPU version.

    Args:
        tpu_name: Value handed to ``cloud_tpu_client.Client``.
        version: Object exposing a ``tpu`` attribute (a VersionConfig).
    """
    print(f'Updating TPU runtime to {version.tpu} ...')

    try:
        import cloud_tpu_client
    except ImportError:
        import importlib

        subprocess.call([sys.executable, '-m', 'pip', 'install', 'cloud-tpu-client'])
        # The package appeared on disk after the interpreter started; clear
        # the import finders' caches so the retry below can see it.
        importlib.invalidate_caches()
        import cloud_tpu_client

    client = cloud_tpu_client.Client(tpu_name)
    client.configure_tpu_version(version.tpu)
    print('Done updating TPU runtime')
|
46 |
+
|
47 |
+
|
48 |
+
def get_py_version():
    """Return the interpreter's major and minor version digits concatenated.

    The result is a string of the form used in wheel tags.
    """
    major, minor, _patch = platform.python_version_tuple()
    return major + minor
|
51 |
+
|
52 |
+
|
53 |
+
def get_cuda_version():
    """Return DEFAULT_CUDA_VERSION on a GPU runtime, otherwise None."""
    # A non-None result makes the caller pick the CUDA wheels.
    return DEFAULT_CUDA_VERSION if is_gpu_runtime() else None
|
57 |
+
|
58 |
+
|
59 |
+
def get_version(version):
    """Resolve a user-supplied version string into a VersionConfig.

    Args:
        version: 'nightly', a dated nightly written as YYYYMMDD, or a dotted
            release number such as '1.5'.

    Returns:
        A VersionConfig holding the wheel version tag, the TPU runtime
        version name, the Python tag from get_py_version() and the CUDA
        version from get_cuda_version().

    Raises:
        ValueError: If a dated nightly is older than the oldest available
            build, or if ``version`` matches none of the accepted forms.
    """
    cuda_version = get_cuda_version()
    if version == 'nightly':
        return VersionConfig(
            'nightly', 'pytorch-nightly', get_py_version(), cuda_version)

    version_date = None
    try:
        version_date = datetime.strptime(version, '%Y%m%d')
    except ValueError:
        pass  # Not a dated nightly.

    if version_date is not None:
        if cuda_version and version_date < OLDEST_GPU_VERSION:
            raise ValueError(
                f'Oldest nightly version build with CUDA available is {OLDEST_GPU_VERSION}')
        elif not cuda_version and version_date < OLDEST_VERSION:
            raise ValueError(f'Oldest nightly version available is {OLDEST_VERSION}')
        return VersionConfig(f'nightly+{version}', f'pytorch-dev{version}',
                             get_py_version(), cuda_version)

    # Raw string: '\d' and '\.' are regex escapes, not Python string escapes.
    version_regex = re.compile(r'^(\d+\.)+\d+$')
    if not version_regex.match(version):
        raise ValueError(f'{version} is an invalid torch_xla version pattern')
    return VersionConfig(
        version, f'pytorch-{version}', get_py_version(), cuda_version)
|
85 |
+
|
86 |
+
|
87 |
+
def install_vm(version, apt_packages, is_root=False):
    """Download and install the torch, torch_xla and torchvision wheels.

    Runs, in order: pip uninstall of torch/torchvision, a ``gsutil cp`` for
    each wheel into the current directory, a pip install for each wheel, and
    finally ``apt-get install`` for the requested packages. Every step is
    attempted even if an earlier one fails; failures are reported on stderr.

    Args:
        version: VersionConfig supplying ``wheels``, ``py_version`` and
            ``cuda_version``.
        apt_packages: Apt package names to install; None or empty skips the
            apt step.
        is_root: When False, the apt command is prefixed with ``sudo``.
    """
    dist_bucket = DIST_BUCKET
    if version.cuda_version:
        # CUDA wheels live in a sub-directory named after the version with
        # the dot removed. Bucket URLs always use '/', so join explicitly
        # rather than with the platform-dependent os.path.join.
        dist_bucket = '{}/cuda/{}'.format(
            DIST_BUCKET, version.cuda_version.replace('.', ''))

    wheel_names = [
        tmpl.format(whl_version=version.wheels, py_version=version.py_version)
        for tmpl in (TORCH_WHEEL_TMPL, TORCH_XLA_WHEEL_TMPL, TORCHVISION_WHEEL_TMPL)
    ]

    installation_cmds = [
        [sys.executable, '-m', 'pip', 'uninstall', '-y', 'torch', 'torchvision'],
    ]
    installation_cmds.extend(
        ['gsutil', 'cp', f'{dist_bucket}/{whl}', '.'] for whl in wheel_names)
    installation_cmds.extend(
        [sys.executable, '-m', 'pip', 'install', whl] for whl in wheel_names)

    if apt_packages:
        apt_cmd = ['apt-get', 'install', '-y', *apt_packages]
        if not is_root:
            # Colab/Kaggle run as root, but not GCE VMs so we need privilege
            apt_cmd.insert(0, 'sudo')
        installation_cmds.append(apt_cmd)

    for cmd in installation_cmds:
        returncode = subprocess.call(cmd)
        if returncode != 0:
            # Keep going (best effort, as before) but make the failure visible.
            print(
                f'Warning: command {cmd} exited with status {returncode}',
                file=sys.stderr)
|
120 |
+
|
121 |
+
|
122 |
+
def run_setup(args):
    """Install the requested wheels and, on a TPU runtime, update the TPU.

    The TPU runtime update runs on a background thread while the wheels and
    apt packages are installed, and is joined before returning.

    Args:
        args: Parsed command-line namespace providing ``version``,
            ``apt_packages`` and ``tpu``.
    """
    resolved = get_version(args.version)
    # Update TPU
    print('Updating... This may take around 2 minutes.')

    if is_tpu_runtime():
        update = threading.Thread(
            target=update_tpu_runtime,
            args=(args.tpu, resolved),
        )
        update.start()

    # Without an explicit --tpu name the install is run as root (no sudo).
    run_as_root = not args.tpu
    install_vm(resolved, args.apt_packages, is_root=run_as_root)

    if is_tpu_runtime():
        update.join()
|
139 |
+
|
140 |
+
|
141 |
+
if __name__ == '__main__':
    # Command-line entry point: parse the options and hand them to run_setup().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--version', type=str, default='20200515',
        help='Versions to install (nightly, release version, or YYYYMMDD).')
    parser.add_argument(
        '--apt-packages', nargs='+', default=['libomp5'],
        help='List of apt packages to install')
    parser.add_argument(
        '--tpu', type=str,
        help='[GCP] Name of the TPU (same zone, project as VM running script)')
    args = parser.parse_args()
    run_setup(args)
|