File size: 7,351 Bytes
8ca3a29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 |
# python3.7
"""Misc utility functions."""
import hashlib
import os
from urllib.parse import urlparse

from torch.hub import download_url_to_file
# Public API of this module (names exported by `from ... import *`).
__all__ = [
    'REPO_NAME',
    'Infix',
    'print_and_execute',
    'check_file_ext',
    'IMAGE_EXTENSIONS',
    'VIDEO_EXTENSIONS',
    'MEDIA_EXTENSIONS',
    'parse_file_format',
    'set_cache_dir',
    'get_cache_dir',
    'download_url',
]

# Name of the repository (project); also used to name the default cache dir.
REPO_NAME = 'Hammer'
class Infix(object):
    """Helper class to create custom infix operators.

    When using it, make sure to put the operator between `<<` and `>>`.
    `<< INFIX_OP_NAME >>` should be considered as a whole operator.

    Examples:

        # Use `Infix` to create infix operators directly.
        add = Infix(lambda a, b: a + b)
        1 << add >> 2  # gives 3
        1 << add >> 2 << add >> 3  # gives 6

        # Use `Infix` as a decorator.
        @Infix
        def mul(a, b):
            return a * b
        2 << mul >> 4  # gives 8
        2 << mul >> 3 << mul >> 7  # gives 42

    NOTE: The left operand is stored on the instance between evaluating `<<`
    and `>>`, so one instance must not be used by two expressions at the same
    time (e.g., concurrently from multiple threads).
    """

    def __init__(self, function):
        self.function = function  # Binary callable `function(left, right)`.
        self.left_value = None  # Left operand captured by `<<`, if pending.

    def __rlshift__(self, left_value):  # override `<<` before `Infix` instance
        # Make sure `<<` is not applied twice without a `>>` in between.
        assert self.left_value is None, (
            'Infix operator used with `<<` twice without a `>>` in between!')
        self.left_value = left_value
        return self

    def __rshift__(self, right_value):  # override `>>` after `Infix` instance
        try:
            return self.function(self.left_value, right_value)
        finally:
            # Always reset, even if `function` raises; otherwise the stale
            # left operand would make every later use fail the assertion.
            self.left_value = None
def print_and_execute(cmd):
    """Prints and executes a system command.

    The command is first echoed to stdout and then run with `os.system()`,
    i.e., it is interpreted by the system shell.

    NOTE: Since the command goes through the shell, never build `cmd` from
    untrusted input.

    Args:
        cmd: Command to be executed.

    Returns:
        The value returned by `os.system()`, i.e., the exit status of the
        command (encoded as a wait status on Unix; `0` means success).
    """
    # Flush so that the echoed command is not held in Python's buffer and
    # printed after the child's own output when stdout is a file or pipe.
    print(cmd, flush=True)
    return os.system(cmd)
def check_file_ext(filename, *ext_list):
    """Checks whether the given filename is with target extension(s).

    Extensions are compared case-insensitively, and the leading dot is
    optional (both `'png'` and `'.png'` are accepted). Extensions can be passed
    one by one, or as a list/tuple/set, e.g.,
    `check_file_ext(name, IMAGE_EXTENSIONS)`.

    Only the last extension of the basename (as given by `os.path.splitext()`)
    is compared, hence a multi-part extension like `.tar.gz` can only be
    matched through its last part (`.gz`).

    NOTE: If no extension is provided, this function will always return `False`.

    Args:
        filename: Filename to check.
        *ext_list: Extensions, or lists/tuples/sets of extensions, to match.

    Returns:
        `True` if the filename is with one of extensions in `ext_list`,
        otherwise `False`.
    """
    candidates = set()
    for item in ext_list:
        # Allow a whole collection of extensions to be passed as one argument.
        if isinstance(item, (list, tuple, set, frozenset)):
            group = item
        else:
            group = (item,)
        for ext in group:
            ext = ext.lower()
            candidates.add(ext if ext.startswith('.') else '.' + ext)
    if not candidates:
        return False
    ext = os.path.splitext(os.path.basename(filename))[1].lower()
    return ext in candidates
# Image file extensions. GIFs are not included here; they are part of
# `MEDIA_EXTENSIONS` below.
IMAGE_EXTENSIONS = (
    '.bmp', '.ppm', '.pgm',
    '.jpeg', '.jpg', '.jpe', '.jp2',
    '.png', '.webp',
    '.tiff', '.tif',
)

# Video file extensions.
VIDEO_EXTENSIONS = (
    '.avi', '.mkv', '.mp4', '.m4v', '.mov',
    '.webm', '.flv', '.rmvb', '.rm', '.3gp',
)

# All media file extensions: GIF first, then all images, then all videos.
MEDIA_EXTENSIONS = ('.gif',) + IMAGE_EXTENSIONS + VIDEO_EXTENSIONS
# Mapping from (lower-case) file extension to the parsed file format, used by
# `parse_file_format()`.
_FORMAT_BY_EXTENSION = {
    '.zip': 'zip',
    '.tar': 'tar',
    '.tgz': 'tar',
    '.txt': 'txt',
    '.text': 'txt',
    '.json': 'json',
    '.jpeg': 'jpg',
    '.jpg': 'jpg',
    '.jpe': 'jpg',
    '.png': 'png',
}


def parse_file_format(path):
    """Parses the file format of a given path.

    This function basically parses the file format according to its extension.
    It will also return `dir` if the given path is a directory.

    Parsable file formats:

    - zip: with `.zip` extension.
    - tar: with `.tar` / `.tgz` / `.tar.gz` extension.
    - lmdb: a folder ending with `lmdb`.
    - txt: with `.txt` / `.text` extension, OR an existing file without
        extension (e.g. LICENSE).
    - json: with `.json` extension.
    - jpg: with `.jpeg` / `.jpg` / `.jpe` extension.
    - png: with `.png` extension.

    Extensions are matched case-insensitively.

    Args:
        path: The path to the file to parse format from. Both strings and
            path-like objects (e.g., `pathlib.Path`) are accepted.

    Returns:
        A lower-case string, indicating the file format, or `None` if the format
        cannot be successfully parsed.
    """
    path = os.fspath(path)  # Accept `pathlib.Path` and other path-likes.

    # Handle directory: an existing one, or any path with a trailing slash.
    if os.path.isdir(path) or path.endswith('/'):
        if path.rstrip('/').lower().endswith('lmdb'):
            return 'lmdb'
        return 'dir'

    # Handle an existing file without extension, e.g., `LICENSE`.
    if os.path.isfile(path) and os.path.splitext(path)[1] == '':
        return 'txt'

    path = path.lower()
    # `os.path.splitext()` only yields `.gz` for `.tar.gz`, so check it first.
    if path.endswith('.tar.gz'):
        return 'tar'
    # `None` if the extension is unknown (unparsable).
    return _FORMAT_BY_EXTENSION.get(os.path.splitext(path)[1])
# Module-level override of the cache directory. `None` means the default
# location under the user's home directory is used (see `get_cache_dir()`).
_cache_dir = None


def set_cache_dir(directory=None):
    """Redirects (or resets) the global cache directory.

    Files shared across jobs can be saved to the cache directory, which
    defaults to `~/.cache/${REPO_NAME}/`. Passing a string redirects the cache
    directory there; passing `None` restores the default.

    Args:
        directory: The target directory used to cache files, or `None` to
            reset the cache directory back to default. (default: None)
    """
    global _cache_dir  # pylint: disable=global-statement
    is_valid = directory is None or isinstance(directory, str)
    assert is_valid, 'Invalid directory!'
    _cache_dir = directory
def get_cache_dir():
    """Returns the global cache directory.

    Unless it has been redirected with `set_cache_dir()`, this is
    `~/.cache/${REPO_NAME}/` under the current user's home directory.

    Returns:
        A string, representing the global cache directory.
    """
    if _cache_dir is not None:
        return _cache_dir
    return os.path.join(os.path.expanduser('~'), '.cache', REPO_NAME)
def _file_sha256(file_path, chunk_size=1 << 20):
    """Returns the hex SHA-256 digest of a file, reading it chunk by chunk."""
    hasher = hashlib.sha256()
    with open(file_path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            hasher.update(chunk)
    return hasher.hexdigest()


def download_url(url, path=None, filename=None, sha256=None):
    """Downloads file from URL.

    This function downloads a file from given URL, and executes Hash check if
    needed. If the target file already exists, the download is skipped and the
    existing file is used (and hash-checked if requested).

    Args:
        url: The URL to download file from.
        path: Path (directory) to save the downloaded file. If set as `None`,
            the cache directory will be used. Please see `get_cache_dir()` for
            more details. (default: None)
        filename: The name to save the file. If set as `None`, this name will be
            automatically parsed from the path component of the given URL
            (query string and fragment are ignored). (default: None)
        sha256: The expected sha256 (hex string, case-insensitive) of the
            downloaded file. If set as `None`, the hash check will be skipped.
            Otherwise, this function will check whether the sha256 of the
            downloaded file matches this field.

    Returns:
        A two-element tuple, where the first term is the full path of the
        downloaded file, and the second term indicates the hash check result.
        `True` means hash check passes, `False` means hash check fails,
        while `None` means no hash check is executed.

    Raises:
        ValueError: If `filename` is not given and cannot be parsed from `url`
            (e.g., the URL path ends with `/`).
    """
    # Handle file path.
    if path is None:
        path = get_cache_dir()
    if filename is None:
        # Only use the URL path, so that e.g. `?dl=1` does not end up in the
        # saved file name.
        filename = os.path.basename(urlparse(url).path)
        if not filename:
            raise ValueError(f'Cannot parse a file name from URL `{url}`, '
                             f'please specify `filename` explicitly!')
    save_path = os.path.join(path, filename)

    # Download file if needed.
    if not os.path.exists(save_path):
        print(f'Downloading URL `{url}` to path `{save_path}` ...')
        os.makedirs(path, exist_ok=True)
        download_url_to_file(url, save_path, hash_prefix=None, progress=True)

    # Check hash if needed.
    check_result = None
    if sha256 is not None:
        # Hash in chunks so large files are not loaded into memory at once;
        # `hexdigest()` is lower-case, so compare against lower-cased input.
        check_result = (_file_sha256(save_path) == sha256.lower())
    return save_path, check_result
|