system HF staff commited on
Commit
e84b953
1 Parent(s): e4a5b84

Update pytorch-xla-env-setup.py

Browse files
Files changed (1) hide show
  1. pytorch-xla-env-setup.py +161 -0
pytorch-xla-env-setup.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # Sample usage:
3
+ # python env-setup.py --version 1.5 --apt-packages libomp5
4
+ import argparse
5
+ import collections
6
+ from datetime import datetime
7
+ import os
8
+ import platform
9
+ import re
10
+ import requests
11
+ import subprocess
12
+ import threading
13
+ import sys
14
+
15
+ VersionConfig = collections.namedtuple('VersionConfig',
16
+ ['wheels', 'tpu', 'py_version', 'cuda_version'])
17
+ DEFAULT_CUDA_VERSION = '10.2'
18
+ OLDEST_VERSION = datetime.strptime('20200318', '%Y%m%d')
19
+ OLDEST_GPU_VERSION = datetime.strptime('20200707', '%Y%m%d')
20
+ DIST_BUCKET = 'gs://tpu-pytorch/wheels'
21
+ TORCH_WHEEL_TMPL = 'torch-{whl_version}-cp{py_version}-cp{py_version}m-linux_x86_64.whl'
22
+ TORCH_XLA_WHEEL_TMPL = 'torch_xla-{whl_version}-cp{py_version}-cp{py_version}m-linux_x86_64.whl'
23
+ TORCHVISION_WHEEL_TMPL = 'torchvision-{whl_version}-cp{py_version}-cp{py_version}m-linux_x86_64.whl'
24
+
25
+
26
+ def is_gpu_runtime():
27
+ return os.environ.get('COLAB_GPU', 0) == 1
28
+
29
+
30
+ def is_tpu_runtime():
31
+ return 'TPU_NAME' in os.environ
32
+
33
+
34
+ def update_tpu_runtime(tpu_name, version):
35
+ print(f'Updating TPU runtime to {version.tpu} ...')
36
+
37
+ try:
38
+ import cloud_tpu_client
39
+ except ImportError:
40
+ subprocess.call([sys.executable, '-m', 'pip', 'install', 'cloud-tpu-client'])
41
+ import cloud_tpu_client
42
+
43
+ client = cloud_tpu_client.Client(tpu_name)
44
+ client.configure_tpu_version(version.tpu)
45
+ print('Done updating TPU runtime')
46
+
47
+
48
+ def get_py_version():
49
+ version_tuple = platform.python_version_tuple()
50
+ return version_tuple[0] + version_tuple[1] # major_version + minor_version
51
+
52
+
53
+ def get_cuda_version():
54
+ if is_gpu_runtime():
55
+ # cuda available, install cuda wheels
56
+ return DEFAULT_CUDA_VERSION
57
+
58
+
59
+ def get_version(version):
60
+ cuda_version = get_cuda_version()
61
+ if version == 'nightly':
62
+ return VersionConfig(
63
+ 'nightly', 'pytorch-nightly', get_py_version(), cuda_version)
64
+
65
+ version_date = None
66
+ try:
67
+ version_date = datetime.strptime(version, '%Y%m%d')
68
+ except ValueError:
69
+ pass # Not a dated nightly.
70
+
71
+ if version_date:
72
+ if cuda_version and version_date < OLDEST_GPU_VERSION:
73
+ raise ValueError(
74
+ f'Oldest nightly version build with CUDA available is {OLDEST_GPU_VERSION}')
75
+ elif not cuda_version and version_date < OLDEST_VERSION:
76
+ raise ValueError(f'Oldest nightly version available is {OLDEST_VERSION}')
77
+ return VersionConfig(f'nightly+{version}', f'pytorch-dev{version}',
78
+ get_py_version(), cuda_version)
79
+
80
+ version_regex = re.compile('^(\d+\.)+\d+$')
81
+ if not version_regex.match(version):
82
+ raise ValueError(f'{version} is an invalid torch_xla version pattern')
83
+ return VersionConfig(
84
+ version, f'pytorch-{version}', get_py_version(), cuda_version)
85
+
86
+
87
+ def install_vm(version, apt_packages, is_root=False):
88
+ dist_bucket = DIST_BUCKET
89
+ if version.cuda_version:
90
+ dist_bucket = os.path.join(
91
+ DIST_BUCKET, 'cuda/{}'.format(version.cuda_version.replace('.', '')))
92
+ torch_whl = TORCH_WHEEL_TMPL.format(
93
+ whl_version=version.wheels, py_version=version.py_version)
94
+ torch_whl_path = os.path.join(dist_bucket, torch_whl)
95
+ torch_xla_whl = TORCH_XLA_WHEEL_TMPL.format(
96
+ whl_version=version.wheels, py_version=version.py_version)
97
+ torch_xla_whl_path = os.path.join(dist_bucket, torch_xla_whl)
98
+ torchvision_whl = TORCHVISION_WHEEL_TMPL.format(
99
+ whl_version=version.wheels, py_version=version.py_version)
100
+ torchvision_whl_path = os.path.join(dist_bucket, torchvision_whl)
101
+ apt_cmd = ['apt-get', 'install', '-y']
102
+ apt_cmd.extend(apt_packages)
103
+
104
+ if not is_root:
105
+ # Colab/Kaggle run as root, but not GCE VMs so we need privilege
106
+ apt_cmd.insert(0, 'sudo')
107
+
108
+ installation_cmds = [
109
+ [sys.executable, '-m', 'pip', 'uninstall', '-y', 'torch', 'torchvision'],
110
+ ['gsutil', 'cp', torch_whl_path, '.'],
111
+ ['gsutil', 'cp', torch_xla_whl_path, '.'],
112
+ ['gsutil', 'cp', torchvision_whl_path, '.'],
113
+ [sys.executable, '-m', 'pip', 'install', torch_whl],
114
+ [sys.executable, '-m', 'pip', 'install', torch_xla_whl],
115
+ [sys.executable, '-m', 'pip', 'install', torchvision_whl],
116
+ apt_cmd,
117
+ ]
118
+ for cmd in installation_cmds:
119
+ subprocess.call(cmd)
120
+
121
+
122
+ def run_setup(args):
123
+ version = get_version(args.version)
124
+ # Update TPU
125
+ print('Updating... This may take around 2 minutes.')
126
+
127
+ if is_tpu_runtime():
128
+ update = threading.Thread(
129
+ target=update_tpu_runtime, args=(
130
+ args.tpu,
131
+ version,
132
+ ))
133
+ update.start()
134
+
135
+ install_vm(version, args.apt_packages, is_root=not args.tpu)
136
+
137
+ if is_tpu_runtime():
138
+ update.join()
139
+
140
+
141
+ if __name__ == '__main__':
142
+ parser = argparse.ArgumentParser()
143
+ parser.add_argument(
144
+ '--version',
145
+ type=str,
146
+ default='20200515',
147
+ help='Versions to install (nightly, release version, or YYYYMMDD).',
148
+ )
149
+ parser.add_argument(
150
+ '--apt-packages',
151
+ nargs='+',
152
+ default=['libomp5'],
153
+ help='List of apt packages to install',
154
+ )
155
+ parser.add_argument(
156
+ '--tpu',
157
+ type=str,
158
+ help='[GCP] Name of the TPU (same zone, project as VM running script)',
159
+ )
160
+ args = parser.parse_args()
161
+ run_setup(args)