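"""Helpers for caching Azure Blob Storage data on local disk via azcopy."""
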
import os
import time
import shutil
import logging
import subprocess
import os.path as op
from typing import List
from collections import OrderedDict


logger = logging.getLogger(__name__)

DEFAULT_AZCOPY_PATH = 'azcopy/azcopy'


def disk_usage(path: str) -> float:
    stat = shutil.disk_usage(path)
    return stat.used / stat.total


def is_download_successful(stdout: str) -> bool:
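    """Parse azcopy's stdout and report whether every transfer succeeded.

    Doctest sketch (the line checked below is taken from azcopy's end-of-job
    summary; the surrounding output is illustrative and varies by version):

        >>> is_download_successful("Job summary\\nNumber of Transfers Failed: 0")
        True
        >>> is_download_successful("Number of Transfers Failed: 2")
        False
    """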
    for line in stdout.split('\n'):
        if line == "Number of Transfers Failed: 0":
            return True
    logger.info("Azcopy message:\n %s" % stdout)
    return False


def ensure_directory(path):
    """Check existence of the given directory path. If not, create a new directory.

    Args:
        path (str): path of a given directory.
    """
    if path == '' or path == '.':
        return
    if path is not None and len(path) > 0:
        assert not op.isfile(path), '{} is a file'.format(path)
        if not op.exists(path) and not op.islink(path):
            os.makedirs(path, exist_ok=True)
        # Always verify that the directory now exists.
        assert op.isdir(op.abspath(path)), path


class LRU(OrderedDict):
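    """An ``OrderedDict``-based cache of open file handles, bounded to
    ``maxsize`` entries. When full, the least recently used handle is
    closed and evicted.

    Minimal usage sketch (hypothetical file names):

        cache = LRU(maxsize=2)
        cache['a.tsv'] = open('a.tsv')
        cache['b.tsv'] = open('b.tsv')
        cache['c.tsv'] = open('c.tsv')  # closes and purges 'a.tsv'
    """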
    def __init__(self, maxsize=3):
        super().__init__()
        self.maxsize = maxsize

    def __getitem__(self, key):
        value = super().__getitem__(key)
        self.move_to_end(key)
        return value

    def __setitem__(self, key, value):
        # Close the handle being replaced, if any.
        if key in self:
            old = super().__getitem__(key)
            if old is not None:
                old.close()

        logger.debug('=> Cache {}'.format(key))
        super().__setitem__(key, value)
        self.move_to_end(key)

        if len(self) > self.maxsize:
            oldest = next(iter(self))
            victim = super().__getitem__(oldest)
            if victim is not None:
                victim.close()
            logger.debug('=> Purged {}'.format(oldest))
            del self[oldest]


class BlobStorage(OrderedDict):
    """ Pseudo Blob Storage manager

    The registered blobs are maintained in a LRU cache.
    Limit size, evicting the least recently looked-up key when full.
    https://docs.python.org/3/library/collections.html#collections.OrderedDict

    Input argument:
        sas_token (str): path to SAS token.
    """
    def __init__(self,
                 is_train: bool,
                 sas_token_path: str = None,
                 azcopy_path: str = None,
                 *args, **kwds):
        super().__init__(*args, **kwds)
        self.maxsize = 2 if is_train else 10    # Larger cache for eval so validation data never gets purged.
        self.is_train = is_train

        if sas_token_path:
            self.sas_token = BlobStorage.read_sas_token(sas_token_path)
            self.base_url = self.sas_token[:self.sas_token.index("?")]
            self.query_string = self.sas_token[self.sas_token.index("?"):]
            self.container = BlobStorage.extract_container(self.sas_token)
        else:
            self.sas_token = None
            self.base_url = None
            self.query_string = None
            self.container = None

        logger.debug(
            f"=> [BlobStorage] Base url: {self.base_url}\n"
            f"=> [BlobStorage] Query string: {self.query_string}\n"
            f"=> [BlobStorage] Container name: {self.container}"
        )

        self.azcopy_path = azcopy_path if azcopy_path else DEFAULT_AZCOPY_PATH
        self._cached_files = LRU(3)

    def __getitem__(self, key):
        value = super().__getitem__(key)
        self.move_to_end(key)
        return value

    def __setitem__(self, key, value):
        if key in self:
            self.move_to_end(key)
        super().__setitem__(key, value)
        # NOTE: purge the least recently used entry once over capacity; keep
        # the cache small since ITP restarts GPU clusters at 80% disk usage.
        if len(self) > self.maxsize:
            oldest = next(iter(self))
            del self[oldest]

    @staticmethod
    def read_sas_token(path: str) -> str:
        with open(path, 'r') as f:
            token = f.readline().strip()
        return token

    @staticmethod
    def extract_container(token: str) -> str:
        """
        Input argument:
            token (str): the full URI of Shared Access Signature (SAS) in the following format.
            https://[storage_account].blob.core.windows.net/[container_name][SAS_token]
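
        Example (illustrative names, not a real token):
            'https://acct.blob.core.windows.net/data?sv=...' -> 'data'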
        """
        return os.path.basename(token.split('?')[0])

    def _convert_to_blob_url(self, local_path: str):
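        """Map a locally cached path back to its blob URL.

        The path is expected to contain the literal segment 'azcopy'
        (substituted for the container name by register_local_tsv_paths),
        e.g. (illustrative) 'data/azcopy/train.tsv' ->
        '<base_url>/train.tsv<query_string>'.
        """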
        return self.base_url + local_path.split("azcopy")[1] + self.query_string

    def _convert_to_blob_folder_url(self, local_path: str):
        return self.base_url + local_path.split("azcopy")[1] + "/*" + self.query_string

    def fetch_blob(self, local_path: str) -> None:
        if op.exists(local_path):
            logger.info('=> Try to open {}'.format(local_path))
            fp = open(local_path, 'r')
            self._cached_files[local_path] = fp
            logger.debug("=> %s downloaded. Skip." % local_path)
            return
        blob_url = self._convert_to_blob_url(local_path)
        # Each rank downloads to its own temporary name to avoid clobbering;
        # the first rank to finish renames its copy into place below.
        rank = os.environ.get('RANK', '0')
        cmd = [self.azcopy_path, "copy", blob_url, local_path + rank]
        curr_usage = disk_usage('/')
        logger.info(
            "=> Downloading %s with azcopy ... (disk usage: %.2f%%)"
            % (local_path, curr_usage * 100)
        )
        proc = subprocess.run(cmd, stdout=subprocess.PIPE)
        while not is_download_successful(proc.stdout.decode()):
            logger.info("=> Azcopy failed to download {}. Retrying ...".format(blob_url))
            proc = subprocess.run(cmd, stdout=subprocess.PIPE)
        if not op.exists(local_path):
            os.rename(local_path + rank, local_path)
        else:
            os.remove(local_path + rank)
        logger.info(
            "=> Downloaded %s with azcopy ... (disk usage: %.2f%% => %.2f%%)" %
            (local_path, curr_usage * 100, disk_usage('/') * 100)
        )

    def fetch_blob_folder(self, local_path: str, azcopy_args: List[str] = None) -> None:
        # Avoid a mutable default argument; normalize to a fresh list.
        azcopy_args = azcopy_args if azcopy_args is not None else []
        blob_url = self._convert_to_blob_folder_url(local_path)
        cmd = [self.azcopy_path, "copy", blob_url, local_path] + azcopy_args
        curr_usage = disk_usage('/')
        logger.info(
            "=> Downloading %s with azcopy args %s ... (disk usage: %.2f%%)"
            % (local_path, ' '.join(azcopy_args), curr_usage * 100)
        )
        proc = subprocess.run(cmd, stdout=subprocess.PIPE)
        while not is_download_successful(proc.stdout.decode()):
            logger.info("=> Azcopy failed to download {} with args {}. Retrying ...".format(blob_url, ' '.join(azcopy_args)))
            proc = subprocess.run(cmd, stdout=subprocess.PIPE)
        logger.info(
            "=> Downloaded %s with azcopy args %s ... (disk usage: %.2f%% => %.2f%%)" %
            (local_path, ' '.join(azcopy_args), curr_usage * 100, disk_usage('/') * 100)
        )

    def register_local_tsv_paths(self, local_paths: List[str]) -> List[str]:
        if self.sas_token:
            tsv_paths_new = []
            lineidx_paths = set()
            linelist_paths = set()
            for path in local_paths:
                tsv_path_az = path.replace(self.container, 'azcopy')
                tsv_paths_new.append(tsv_path_az)
                logger.debug("=> Registering {}".format(tsv_path_az))

                if not self.is_train:
                    logger.info('=> Downloading {}...'.format(tsv_path_az))
                    self.fetch_blob(tsv_path_az)
                    logger.info('=> Downloaded {}'.format(tsv_path_az))

                lineidx = op.splitext(path)[0] + '.lineidx'
                lineidx_ = lineidx.replace(self.container, 'azcopy')
                if self.is_train:
                    if not op.isfile(lineidx_) and op.dirname(lineidx_) not in lineidx_paths:
                        lineidx_paths.add(op.dirname(lineidx_))
                else:
                    if not op.isfile(lineidx_):
                        ensure_directory(op.dirname(lineidx_))
                        self.fetch_blob(lineidx_)

                linelist = op.splitext(path)[0] + '.linelist'
                linelist_ = linelist.replace(self.container, 'azcopy')
                # .linelist does not always exist. Check existence before fetch
                if self.is_train:
                    if op.isfile(linelist) and not op.isfile(linelist_) and op.dirname(linelist_) not in linelist_paths:
                        linelist_paths.add(op.dirname(linelist_))
                else:
                    if op.isfile(linelist) and not op.isfile(linelist_):
                        ensure_directory(op.dirname(linelist_))
                        self.fetch_blob(linelist_)

            if self.is_train:
                for path in lineidx_paths:
                    self.fetch_blob_folder(path, azcopy_args=['--include-pattern', '*.lineidx'])

                for path in linelist_paths:
                    self.fetch_blob_folder(path, azcopy_args=['--include-pattern', '*.linelist'])

            return tsv_paths_new
        else:
            return local_paths

    def open(self, local_path: str):
        if self.sas_token and 'azcopy' in local_path:
            # Another rank may still be downloading this file; wait until
            # fetch_blob has renamed it into place.
            while not op.exists(local_path):
                time.sleep(1)
        return open(local_path, 'r')
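

if __name__ == '__main__':
    # Minimal smoke test for the LRU cache (a sketch; no Azure access or
    # azcopy binary needed). Temporary files stand in for cached TSV handles.
    import tempfile

    cache = LRU(maxsize=2)
    handles = [tempfile.TemporaryFile() for _ in range(3)]
    for i, handle in enumerate(handles):
        cache['blob-{}'.format(i)] = handle
    assert len(cache) == 2     # the oldest entry was purged ...
    assert handles[0].closed   # ... and its handle was closed on eviction
    print('LRU eviction OK')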