|
import os |
|
import time |
|
import shutil |
|
import logging |
|
import subprocess |
|
import os.path as op |
|
from typing import List |
|
from collections import OrderedDict |
|
|
|
import torch.distributed as distributed |
|
|
|
# Module-level logger; handler/level configuration is left to the application.
logger = logging.getLogger(__name__)

# Fallback location of the azcopy binary (relative to the working directory),
# used when no explicit azcopy_path is given to BlobStorage.
DEFAULT_AZCOPY_PATH = 'azcopy/azcopy'
|
|
|
|
|
def disk_usage(path: str) -> float:
    """Return the used fraction (0.0-1.0) of the filesystem containing *path*."""
    total, used, _free = shutil.disk_usage(path)
    return used / total
|
|
|
|
|
def is_download_successful(stdout: str) -> bool:
    """Return True iff azcopy's stdout reports zero failed transfers.

    On failure the full azcopy output is logged for diagnosis.
    """
    succeeded = any(
        line == "Number of Transfers Failed: 0"
        for line in stdout.split('\n')
    )
    if not succeeded:
        logger.info("Azcopy message:\n %s" % stdout)
    return succeeded
|
|
|
|
|
def ensure_directory(path):
    """Check existence of the given directory path. If not, create a new directory.

    Args:
        path (str): path of a given directory. None, '' and '.' are treated
            as "nothing to create" and return immediately.

    Raises:
        AssertionError: if *path* names an existing regular file, or if the
            directory does not exist after creation.
    """
    # Bug fix: the original returned early only for '' and '.', so a None
    # path slipped past the `path is not None` guard and crashed with
    # TypeError in op.abspath(None) below.  Treat falsy paths as a no-op.
    if not path or path == '.':
        return
    assert not op.isfile(path), '{} is a file'.format(path)
    if not op.exists(path) and not op.islink(path):
        # exist_ok tolerates a concurrent process creating the directory
        # between the exists() check and makedirs().
        os.makedirs(path, exist_ok=True)

    assert op.isdir(op.abspath(path)), path
|
|
|
|
|
class LRU(OrderedDict):
    """An OrderedDict-based LRU cache of open file-like objects.

    Values are expected to be file-like (with a ``close()`` method) or None;
    a value being replaced or evicted is close()d before it is dropped.
    """

    def __init__(self, maxsize=3):
        # Bug fix: initialize the OrderedDict base explicitly.  The original
        # skipped super().__init__(), relying on implementation details of
        # OrderedDict.__new__ to set up the internal ordering structure.
        super().__init__()
        self.maxsize = maxsize  # maximum number of cached entries

    def __getitem__(self, key):
        # A lookup counts as a use: promote the key to most-recently-used.
        value = super().__getitem__(key)
        self.move_to_end(key)
        return value

    def __setitem__(self, key, value):
        if key in self:
            # Replacing an existing entry: release the old handle first.
            if self[key] is not None:
                self[key].close()
            self.move_to_end(key)

        logger.debug('=> Cache {}'.format(key))
        super().__setitem__(key, value)

        if len(self) > self.maxsize:
            # Evict the least-recently-used entry, closing its handle.
            oldest = next(iter(self))
            if self[oldest] is not None:
                self[oldest].close()
            logger.debug('=> Purged {}'.format(oldest))
            del self[oldest]
|
|
|
|
|
class BlobStorage(OrderedDict):
    """ Pseudo Blob Storage manager

    The registered blobs are maintained in a LRU cache.
    Limit size, evicting the least recently looked-up key when full.
    https://docs.python.org/3/library/collections.html#collections.OrderedDict

    Input argument:
        is_train (bool): training mode; shrinks the cache and switches
            sidecar prefetching from per-file to per-folder.
        sas_token_path (str): path to a file whose first line is a full SAS
            URI (base URL + query string).  If falsy, all blob features are
            disabled and paths pass through unchanged.
        azcopy_path (str): path to the azcopy binary; defaults to
            DEFAULT_AZCOPY_PATH.
    """
    def __init__(self,
                 is_train: bool,
                 sas_token_path: str = None,
                 azcopy_path: str = None,
                 *args, **kwds):
        super().__init__(*args, **kwds)
        # Training streams through many shards, so keep the cache small;
        # evaluation re-reads a fixed set of files and can cache more.
        self.maxsize = 2 if is_train else 10
        self.is_train = is_train

        if sas_token_path:
            # The token is a full SAS URI: <base_url><query_string>.
            self.sas_token = BlobStorage.read_sas_token(sas_token_path)
            self.base_url = self.sas_token[:self.sas_token.index("?")]
            self.query_string = self.sas_token[self.sas_token.index("?"):]
            self.container = BlobStorage.extract_container(self.sas_token)
        else:
            self.sas_token = None
            self.base_url = None
            self.query_string = None
            self.container = None

        # Bug fix: the original f-string fragments were concatenated with no
        # separator, producing one run-together log line; add newlines.
        logger.debug(
            f"=> [BlobStorage] Base url: {self.base_url}\n"
            f"=> [BlobStorage] Query string: {self.query_string}\n"
            f"=> [BlobStorage] Container name: {self.container}"
        )

        self.azcopy_path = azcopy_path if azcopy_path else DEFAULT_AZCOPY_PATH
        self._cached_files = LRU(3)  # open handles of downloaded blobs

    def __getitem__(self, key):
        # A lookup counts as a use: promote the key to most-recently-used.
        value = super().__getitem__(key)
        self.move_to_end(key)
        return value

    def __setitem__(self, key, value):
        if key in self:
            self.move_to_end(key)
        super().__setitem__(key, value)

        # Evict the least-recently-used entry once over capacity.
        # (Unlike LRU above, values here are not close()d on eviction.)
        if len(self) > self.maxsize:
            oldest = next(iter(self))
            del self[oldest]

    @staticmethod
    def read_sas_token(path: str) -> str:
        """Return the SAS token: the first line of the file at *path*, stripped."""
        with open(path, 'r') as f:
            token = f.readline().strip()
        return token

    @staticmethod
    def extract_container(token: str) -> str:
        """
        Input argument:
            token (str): the full URI of Shared Access Signature (SAS) in the following format.
                https://[storage_account].blob.core.windows.net/[container_name][SAS_token]
        """
        return os.path.basename(token.split('?')[0])

    def _convert_to_blob_url(self, local_path: str):
        # Map a local ".../azcopy/<suffix>" path back to its blob URL by
        # splicing the suffix between the base URL and the SAS query string.
        return self.base_url + local_path.split("azcopy")[1] + self.query_string

    def _convert_to_blob_folder_url(self, local_path: str):
        # Same as _convert_to_blob_url, but targets every blob in the folder.
        return self.base_url + local_path.split("azcopy")[1] + "/*" + self.query_string

    def fetch_blob(self, local_path: str) -> None:
        """Download one blob to *local_path*; no-op if it is already on disk.

        The open handle of an already-present file is kept in an LRU cache
        (which closes evicted handles).
        """
        if op.exists(local_path):
            logger.info('=> Try to open {}'.format(local_path))
            fp = open(local_path, 'r')
            self._cached_files[local_path] = fp
            logger.debug("=> %s downloaded. Skip." % local_path)
            return
        blob_url = self._convert_to_blob_url(local_path)
        # Download under a rank-suffixed name so concurrent distributed
        # ranks do not clobber each other's partial downloads.
        rank = '0' if 'RANK' not in os.environ else os.environ['RANK']
        cmd = [self.azcopy_path, "copy", blob_url, local_path + rank]
        curr_usage = disk_usage('/')
        logger.info(
            "=> Downloading %s with azcopy ... (disk usage: %.2f%%)"
            % (local_path, curr_usage * 100)
        )
        proc = subprocess.run(cmd, stdout=subprocess.PIPE)
        # NOTE(review): retries forever if the blob is permanently unreachable.
        while not is_download_successful(proc.stdout.decode()):
            logger.info("=> Azcopy failed to download {}. Retrying ...".format(blob_url))
            proc = subprocess.run(cmd, stdout=subprocess.PIPE)
        if not op.exists(local_path):
            os.rename(local_path + rank, local_path)
        else:
            # Another rank won the rename race; discard our copy.
            os.remove(local_path + rank)
        logger.info(
            "=> Downloaded %s with azcopy ... (disk usage: %.2f%% => %.2f%%)" %
            (local_path, curr_usage * 100, disk_usage('/') * 100)
        )

    def fetch_blob_folder(self, local_path: str, azcopy_args: List[str] = None) -> None:
        """Download a folder of blobs into *local_path*.

        Args:
            local_path (str): destination folder (an ".../azcopy/..." path).
            azcopy_args (list): extra azcopy CLI arguments, e.g.
                ['--include-pattern', '*.lineidx'].  Defaults to none.
        """
        # Bug fix: the original used a mutable `[]` default argument; use
        # None and normalize here (same behavior, no shared-default pitfall).
        if azcopy_args is None:
            azcopy_args = []
        blob_url = self._convert_to_blob_folder_url(local_path)
        cmd = [self.azcopy_path, "copy", blob_url, local_path] + azcopy_args
        curr_usage = disk_usage('/')
        logger.info(
            "=> Downloading %s with azcopy args %s ... (disk usage: %.2f%%)"
            % (local_path, ' '.join(azcopy_args), curr_usage * 100)
        )
        proc = subprocess.run(cmd, stdout=subprocess.PIPE)
        # NOTE(review): retries forever if the folder is permanently unreachable.
        while not is_download_successful(proc.stdout.decode()):
            logger.info("=> Azcopy failed to download {} with args {}. Retrying ...".format(blob_url, ' '.join(azcopy_args)))
            proc = subprocess.run(cmd, stdout=subprocess.PIPE)
        logger.info(
            "=> Downloaded %s with azcopy args %s ... (disk usage: %.2f%% => %.2f%%)" %
            (local_path, ' '.join(azcopy_args), curr_usage * 100, disk_usage('/') * 100)
        )

    def register_local_tsv_paths(self, local_paths: List[str]) -> List[str]:
        """Rewrite TSV paths onto the local 'azcopy' mirror and prefetch sidecars.

        For each TSV, the companion .lineidx / .linelist files are fetched:
        eagerly per file in eval mode, batched per containing folder in
        training mode.  Eval mode also downloads the TSVs themselves.

        Args:
            local_paths (list): TSV paths containing the container name.

        Returns:
            list: rewritten paths, or *local_paths* unchanged when no SAS
            token is configured.
        """
        if self.sas_token:
            tsv_paths_new = []
            lineidx_paths = set()
            linelist_paths = set()
            for path in local_paths:
                tsv_path_az = path.replace(self.container, 'azcopy')
                tsv_paths_new.append(tsv_path_az)
                logger.debug("=> Registering {}".format(tsv_path_az))

                if not self.is_train:
                    logger.info('=> Downloading {}...'.format(tsv_path_az))
                    self.fetch_blob(tsv_path_az)
                    logger.info('=> Downloaded {}'.format(tsv_path_az))

                lineidx = op.splitext(path)[0] + '.lineidx'
                lineidx_ = lineidx.replace(self.container, 'azcopy')
                if self.is_train:
                    # Collect folders now; fetch each in one azcopy call below.
                    if not op.isfile(lineidx_) and op.dirname(lineidx_) not in lineidx_paths:
                        lineidx_paths.add(op.dirname(lineidx_))
                else:
                    if not op.isfile(lineidx_):
                        ensure_directory(op.dirname(lineidx_))
                        self.fetch_blob(lineidx_)

                linelist = op.splitext(path)[0] + '.linelist'
                linelist_ = linelist.replace(self.container, 'azcopy')

                if self.is_train:
                    # .linelist is optional: only fetch if the source copy exists.
                    if op.isfile(linelist) and not op.isfile(linelist_) and op.dirname(linelist_) not in linelist_paths:
                        linelist_paths.add(op.dirname(linelist_))
                else:
                    if op.isfile(linelist) and not op.isfile(linelist_):
                        ensure_directory(op.dirname(linelist_))
                        self.fetch_blob(linelist_)

            if self.is_train:
                for path in lineidx_paths:
                    self.fetch_blob_folder(path, azcopy_args=['--include-pattern', '*.lineidx'])

                for path in linelist_paths:
                    self.fetch_blob_folder(path, azcopy_args=['--include-pattern', '*.linelist'])

            return tsv_paths_new
        else:
            return local_paths

    def open(self, local_path: str):
        """Open *local_path* for reading, waiting for an in-flight download.

        NOTE(review): blocks forever if the file never appears; presumably
        another rank is responsible for downloading it.
        """
        if self.sas_token and 'azcopy' in local_path:
            while not op.exists(local_path):
                time.sleep(1)
        fid = open(local_path, 'r')
        return fid
|
|