File size: 2,111 Bytes
0b32ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import hashlib
import os
import subprocess

import gdown
import torch
from filelock import FileLock


def _download(filename, url, refresh, agent):
    """
    Download `url` into the s3prl cache directory as `filename`, reusing a cached copy.

    The cache directory is `<torch.hub.get_dir()>/s3prl_cache`. A lock file next to
    the target serializes concurrent downloads of the same file. The payload is
    first written to a temporary sibling file and moved into place only after the
    download succeeded. A failed or interrupted download therefore never leaves a
    truncated file that later calls would mistake for a valid cache hit, and a
    failed `refresh` keeps the previous cached copy intact.

    Args:
        filename (str): Name of the file inside the cache directory.
        url (str): Source URL.
        refresh (bool): Re-download even if the file is already cached.
        agent (str): Download tool to use, 'wget' or 'gdown'.

    Return:
        str: Path to the cached file.

    Raises:
        NotImplementedError: `agent` is not supported (checked only when a
            download is actually needed).
        subprocess.CalledProcessError: wget exited with a non-zero status.
        RuntimeError: the download finished without producing a file.
    """
    dirpath = f'{torch.hub.get_dir()}/s3prl_cache'
    os.makedirs(dirpath, exist_ok=True)
    filepath = f'{dirpath}/{filename}'
    with FileLock(filepath + ".lock"):
        if os.path.isfile(filepath) and not refresh:
            print(f'Using cache found in {filepath}\nfor {url}')
            return filepath

        if agent not in ('wget', 'gdown'):
            print('[Download] - Unknown download agent. Only \'wget\' and \'gdown\' are supported.')
            raise NotImplementedError(f'Unknown download agent: {agent!r}')

        tmp_path = filepath + '.tmp'
        try:
            if agent == 'wget':
                # Argument list, no shell: characters such as '&' in the URL are
                # passed verbatim. check=True surfaces failures instead of
                # silently keeping the (possibly empty) output file.
                subprocess.run(['wget', url, '-O', tmp_path], check=True)
            else:
                gdown.download(url, tmp_path, use_cookies=False)
            if not os.path.isfile(tmp_path):
                raise RuntimeError(f'[Download] - Failed to download {url}')
            # Atomic rename: `filepath` only ever holds a complete download.
            os.replace(tmp_path, filepath)
        finally:
            # Remove leftovers of a failed attempt; after a successful replace
            # the temporary file no longer exists.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
    return filepath


def _urls_to_filepaths(*args, refresh=False, agent='wget'):
    """
    Preprocess the URL specified in *args into local file paths after downloading

    Args:
        Any number of URLs (1 ~ any)

    Return:
        Same number of downloaded file paths
    """

    def url_to_filename(url):
        assert type(url) is str
        m = hashlib.sha256()
        m.update(str.encode(url))
        return str(m.hexdigest())

    def url_to_path(url, refresh):
        if type(url) is str and len(url) > 0:
            return _download(url_to_filename(url), url, refresh, agent=agent)
        else:
            return None

    paths = [url_to_path(url, refresh) for url in args]
    return paths if len(paths) > 1 else paths[0]


def _gdriveids_to_filepaths(*args, refresh=False):
    """
    Resolve Google Drive file ids in *args to local cached file paths.

    Each non-empty `str` id becomes a `https://drive.google.com/uc?id=<id>`
    download URL; any other entry becomes None. The resulting URLs are handed
    to `_urls_to_filepaths` with the 'gdown' agent.

    Args:
        *args: Google Drive ids to resolve (1 ~ any).
        refresh (bool): Re-download even when a cached copy exists.

    Return:
        Whatever `_urls_to_filepaths` returns: a single path when one id was
        given, otherwise a list of paths in the same order as *args.
    """
    urls = []
    for gid in args:
        has_id = type(gid) is str and len(gid) > 0
        urls.append(f'https://drive.google.com/uc?id={gid}' if has_id else None)
    return _urls_to_filepaths(*urls, refresh=refresh, agent='gdown')