File size: 2,009 Bytes
bfd77fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6f51cf
 
bfd77fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6f51cf
bfd77fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import tarfile
import tempfile
import urllib.request

import yaml
from ruamel.yaml import YAML
from tqdm import tqdm

def read_plainconfig(configname):
    """Parse the YAML file at ``configname`` with ruamel's ``YAML()`` loader.

    Parameters
    ----------
    configname : str
        Path of the YAML file to read.

    Returns
    -------
    The document produced by ``YAML().load``.

    Raises
    ------
    FileNotFoundError
        When nothing exists at ``configname``.
    """
    config_present = os.path.exists(configname)
    if config_present:
        with open(configname) as handle:
            loader = YAML()
            return loader.load(handle)
    raise FileNotFoundError(
        f"Config {configname} is not found. Please make sure that the file exists."
    )

def DownloadModel(modelname,
                  target_dir,
                  urls_config="DLC_models/pretrained_model_urls.yaml"):
    """
    Downloads a DeepLabCut Model Zoo Project

    Looks ``modelname`` up in the name -> URL mapping read from ``urls_config``,
    downloads the gzipped tarball at that URL and extracts it into
    ``target_dir``. The archive's top-level folder (its first member, ie.
    ``xyz-trainsetxyshufflez/``) is stripped, so that folder's contents land
    directly in ``target_dir``. If ``modelname`` is unknown, the available
    names are printed and nothing is downloaded.

    Parameters
    ----------
    modelname : str
        Key of the model in the URL mapping.
    target_dir : str
        Directory the archive contents are extracted into.
    urls_config : str, optional
        YAML file mapping model names to download URLs. The default is the
        path that used to be hard-coded here; it is resolved relative to the
        current working directory.

    Returns
    -------
    str
        ``target_dir`` unchanged (also when nothing was downloaded).

    Raises
    ------
    FileNotFoundError
        If ``urls_config`` does not exist.
    urllib.error.URLError
        Propagated from the download on network/HTTP failures.
    """

    def tarfilenamecutting(tarf):
        """Yield members of ``tarf`` re-rooted below its top-level folder.

        The first member's path is taken as the top-level folder. Members
        strictly inside that folder are yielded with the folder prefix
        removed; the folder entry itself and anything outside it are skipped.
        """
        prefix = None
        for member in tarf.getmembers():
            if prefix is None:
                # Match "<folder>/" rather than "<folder>" so a sibling entry
                # such as "<folder>_old/..." is not treated as a child.
                prefix = member.path.rstrip("/") + "/"
            if member.path.startswith(prefix):
                member.path = member.path[len(prefix):]
                yield member

    neturls = read_plainconfig(urls_config)

    if modelname not in neturls:
        # NOTE(review): entries containing "resnet_"/"mobilenet_" are hidden
        # from the suggestion list -- presumably backbone weights rather than
        # zoo projects; confirm against the URL config.
        models = [
            fn
            for fn in neturls
            if "resnet_" not in fn and "mobilenet_" not in fn
        ]
        print("Model does not exist: ", modelname)
        print("Pick one of the following: ", models)
        return target_dir

    url = neturls[modelname]
    print(
        "Downloading the model from the DeepLabCut server @Harvard -> Go Crimson!!! {}....".format(
            url
        )
    )

    # Download into a private temp dir so the archive is always removed, even
    # if the download or the extraction fails.
    with tempfile.TemporaryDirectory() as tmpdir:
        archive = os.path.join(tmpdir, "model.tar.gz")

        with tqdm(unit="B", position=0) as pbar:

            def show_progress(count, block_size, total_size):
                # urlretrieve calls this once before the first read (count 0)
                # and after every block; total_size is -1 when the server
                # sends no Content-Length, so the bar then just counts bytes.
                if total_size > 0 and pbar.total != total_size:
                    pbar.total = total_size
                    pbar.refresh()
                done = count * block_size
                if total_size > 0:
                    # The last block is usually partial; don't overshoot.
                    done = min(done, total_size)
                pbar.update(done - pbar.n)

            urllib.request.urlretrieve(url, archive, reporthook=show_progress)

        with tarfile.open(archive, mode="r:gz") as tar:
            members = tarfilenamecutting(tar)
            if hasattr(tarfile, "data_filter"):
                # The archive comes from the network: where the interpreter
                # supports extraction filters (PEP 706), refuse members that
                # would escape target_dir (absolute paths, "..", links out).
                tar.extractall(target_dir, members=members, filter="data")
            else:
                tar.extractall(target_dir, members=members)
    return target_dir