File size: 7,351 Bytes
8ca3a29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
# python3.7
"""Misc utility functions."""

import os
import hashlib

from torch.hub import download_url_to_file

# Names exported by this module when using `from <module> import *`.
__all__ = [
    'REPO_NAME', 'Infix', 'print_and_execute', 'check_file_ext',
    'IMAGE_EXTENSIONS', 'VIDEO_EXTENSIONS', 'MEDIA_EXTENSIONS',
    'parse_file_format', 'set_cache_dir', 'get_cache_dir', 'download_url'
]

# NOTE: `REPO_NAME` is also used by `get_cache_dir()` as the folder name of the
# default cache directory, i.e., `~/.cache/${REPO_NAME}`.
REPO_NAME = 'Hammer'  # Name of the repository (project).


class Infix(object):
    """Helper class to create custom infix operators.

    When using it, make sure to put the operator between `<<` and `>>`.
    `<< INFIX_OP_NAME >>` should be considered as a whole operator.

    The expression `a << op >> b` is evaluated in two steps: `a << op` stores
    `a` inside the operator (via `__rlshift__`) and returns the operator itself,
    then `op >> b` calls the wrapped function with `(a, b)`.

    NOTE: `__rlshift__` is a reflected operator, hence it is only triggered
    when the left operand does not handle `<<` with an `Infix` instance itself.

    Examples:

    # Use `Infix` to create infix operators directly.
    add = Infix(lambda a, b: a + b)
    1 << add >> 2  # gives 3
    1 << add >> 2 << add >> 3  # gives 6

    # Use `Infix` as a decorator.
    @Infix
    def mul(a, b):
        return a * b
    2 << mul >> 4  # gives 8
    2 << mul >> 3 << mul >> 7  # gives 42

    Args:
        function: A callable taking two positional arguments, i.e., the left
            operand and the right operand.
    """

    def __init__(self, function):
        self.function = function
        self.left_value = None  # Left operand pending for `>>`.

    def __rlshift__(self, left_value):  # override `<<` before `Infix` instance
        assert self.left_value is None  # make sure left is only called once
        self.left_value = left_value
        return self

    def __rshift__(self, right_value):  # override `>>` after `Infix` instance
        # Reset the pending operand BEFORE calling the function. Otherwise, an
        # exception raised by the function would leave a stale left value
        # behind and break every subsequent use of this operator.
        left_value, self.left_value = self.left_value, None
        return self.function(left_value, right_value)


def print_and_execute(cmd):
    """Echoes a system command to stdout, then runs it.

    The command is passed to `os.system()` as is, and the exit status reported
    by `os.system()` is discarded.

    Args:
        cmd: Command to be executed.
    """
    print(cmd)
    os.system(cmd)  # Exit status is intentionally not returned.


def check_file_ext(filename, *ext_list):
    """Tells whether a filename ends with one of the target extensions.

    The comparison is case-insensitive, and each target extension may be given
    either with or without the leading dot (e.g., `'jpg'` and `'.jpg'` are
    treated the same). Only the last extension of the basename is considered.

    NOTE: If `ext_list` is empty, this function will always return `False`.

    Args:
        filename: Filename to check.
        *ext_list: A list of extensions.

    Returns:
        `True` if the filename is with one of extensions in `ext_list`,
        otherwise `False`.
    """
    if not ext_list:
        return False
    # Normalize targets to lower-case extensions with a leading dot.
    targets = set()
    for candidate in ext_list:
        dotted = candidate if candidate.startswith('.') else '.' + candidate
        targets.add(dotted.lower())
    _, file_ext = os.path.splitext(os.path.basename(filename))
    return file_ext.lower() in targets


# Extensions (lower-case, with leading dot) of image files. GIFs are excluded
# here and only appear in `MEDIA_EXTENSIONS`.
IMAGE_EXTENSIONS = (
    '.bmp', '.ppm', '.pgm',
    '.jpeg', '.jpg', '.jpe', '.jp2',
    '.png', '.webp',
    '.tiff', '.tif',
)
# Extensions (lower-case, with leading dot) of video files.
VIDEO_EXTENSIONS = (
    '.avi', '.mkv', '.mp4', '.m4v', '.mov',
    '.webm', '.flv', '.rmvb', '.rm', '.3gp',
)
# Extensions of all media files, i.e., GIFs followed by images and videos.
MEDIA_EXTENSIONS = ('.gif',) + IMAGE_EXTENSIONS + VIDEO_EXTENSIONS


def parse_file_format(path):
    """Parses the file format of a given path.

    This function basically parses the file format according to its extension.
    It will also return `dir` if the given path is a directory.

    Parsable file formats:

    - zip: with `.zip` extension.
    - tar: with `.tar` / `.tgz` / `.tar.gz` extension.
    - lmdb: a folder ending with `lmdb`.
    - txt: with `.txt` / `.text` extension, OR without extension (e.g. LICENSE).
    - json: with `.json` extension.
    - jpg: with `.jpeg` / `jpg` / `jpe` extension.
    - png: with `.png` extension.

    NOTE: A path is treated as a directory if it exists as a directory on disk
    OR ends with `/`. A path without extension is only parsed as `txt` if it
    exists as a file on disk. All extension checks are case-insensitive.

    Args:
        path: The path to the file to parse format from. Both strings and
            `os.PathLike` objects (e.g., `pathlib.Path`) are supported.

    Returns:
        A lower-case string, indicating the file format, or `None` if the format
            cannot be successfully parsed.
    """
    # Convert path-like objects to plain strings, since they do not provide
    # string methods such as `endswith()` and `lower()`.
    if isinstance(path, os.PathLike):
        path = os.fspath(path)
    # Handle directory.
    if os.path.isdir(path) or path.endswith('/'):
        if path.rstrip('/').lower().endswith('lmdb'):
            return 'lmdb'
        return 'dir'
    # Handle file.
    if os.path.isfile(path) and os.path.splitext(path)[1] == '':
        return 'txt'
    path = path.lower()
    # `os.path.splitext()` only gives `.gz` for `.tar.gz`, hence this compound
    # extension is checked separately.
    if path.endswith('.tar.gz'):
        return 'tar'
    ext_to_format = {
        '.zip': 'zip',
        '.tar': 'tar',
        '.tgz': 'tar',
        '.txt': 'txt',
        '.text': 'txt',
        '.json': 'json',
        '.jpeg': 'jpg',
        '.jpg': 'jpg',
        '.jpe': 'jpg',
        '.png': 'png',
    }
    # Unparsable extensions give `None`.
    return ext_to_format.get(os.path.splitext(path)[1])


# User-specified cache directory. `None` means the default one is used.
_cache_dir = None


def set_cache_dir(directory=None):
    """Redirects (or resets) the global cache directory.

    The cache directory is meant for files shared across jobs. By default, it
    is `~/.cache/${REPO_NAME}/`. Calling this function with a string redirects
    the cache directory, while calling it with `None` resets the cache directory
    back to default.

    Args:
        directory: The target directory used to cache files. If set as `None`,
            the cache directory will be reset back to default. (default: None)
    """
    global _cache_dir  # pylint: disable=global-statement
    is_valid = directory is None or isinstance(directory, str)
    assert is_valid, 'Invalid directory!'
    _cache_dir = directory


def get_cache_dir():
    """Returns the global cache directory currently in use.

    If a directory has been specified through `set_cache_dir()`, that directory
    is returned as is. Otherwise, the default `~/.cache/${REPO_NAME}` (with `~`
    expanded to the home directory of the current user) is returned.

    Returns:
        A string, representing the global cache directory.
    """
    if _cache_dir is not None:
        return _cache_dir
    # Fall back to the default location under the user's home directory.
    return os.path.join(os.path.expanduser('~'), '.cache', REPO_NAME)


def download_url(url, path=None, filename=None, sha256=None):
    """Downloads file from URL.

    This function downloads a file from given URL, and executes Hash check if
    needed.

    NOTE: If the target file already exists, the download is skipped and the
    existing file is used (and hash-checked if required).

    Args:
        url: The URL to download file from.
        path: Path (directory) to save the downloaded file. If set as `None`,
            the cache directory will be used. Please see `get_cache_dir()` for
            more details. (default: None)
        filename: The name to save the file. If set as `None`, this name will be
            automatically parsed from the given URL. (default: None)
        sha256: The expected sha256 (hexadecimal digest) of the downloaded file.
            If set as `None`, the hash check will be skipped. Otherwise, this
            function will check whether the sha256 of the downloaded file
            matches this field. The comparison is case-insensitive.
            (default: None)

    Returns:
        A two-element tuple, where the first term is the full path of the
            downloaded file, and the second term indicate the hash check result.
            `True` means hash check passes, `False` means hash check fails,
            while `None` means no hash check is executed.
    """
    # Handle file path.
    if path is None:
        path = get_cache_dir()
    if filename is None:
        filename = os.path.basename(url)
    save_path = os.path.join(path, filename)
    # Download file if needed.
    if not os.path.exists(save_path):
        print(f'Downloading URL `{url}` to path `{save_path}` ...')
        os.makedirs(path, exist_ok=True)
        download_url_to_file(url, save_path, hash_prefix=None, progress=True)
    # Check hash if needed.
    if sha256 is None:
        return save_path, None
    # Hash the file chunk by chunk, such that large files (e.g., model weights)
    # do not need to be loaded into memory entirely.
    file_hash = hashlib.sha256()
    with open(save_path, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            file_hash.update(chunk)
    # `hexdigest()` is always lower-case, hence normalize the expected value to
    # avoid reporting a failure for a correct but upper-case digest.
    expected = sha256.lower() if isinstance(sha256, str) else sha256
    return save_path, file_hash.hexdigest() == expected