sofmi's picture
refactoring and small fixes (#27)
b6f51cf
import os
import tarfile
import tempfile
import urllib.request

import yaml
from ruamel.yaml import YAML
from tqdm import tqdm


def read_plainconfig(configname):
    """Load the YAML file at ``configname`` with ruamel's default ``YAML()`` loader.

    Raises:
        FileNotFoundError: if ``configname`` does not exist on disk.
    """
    if os.path.exists(configname):
        loader = YAML()
        with open(configname) as handle:
            parsed = loader.load(handle)
        return parsed
    message = f"Config {configname} is not found. Please make sure that the file exists."
    raise FileNotFoundError(message)
def DownloadModel(modelname,
                  target_dir,
                  urls_config="DLC_models/pretrained_model_urls.yaml"):
    """
    Downloads a DeepLabCut Model Zoo Project

    Looks ``modelname`` up in the name -> URL table loaded from ``urls_config``.
    If found, the ``.tar.gz`` archive at that URL is downloaded (with a progress
    bar) and extracted into ``target_dir``. The archive's leading folder (the
    path of its first member) is stripped, so the contents land directly in
    ``target_dir``. If the name is unknown, the available model names are
    printed instead and nothing is downloaded.

    Parameters
    ----------
    modelname : str
        Key to look up in the URL table.
    target_dir : str
        Directory the archive contents are extracted into.
    urls_config : str, optional
        Path of the YAML name -> URL table. Defaults to the previously
        hard-coded location, which is resolved relative to the current
        working directory.

    Returns
    -------
    str
        ``target_dir`` as passed in, whether or not a download happened.
    """
    chunk_size = 1 << 16  # bytes read from the HTTP response per iteration

    def tarfilenamecutting(tarf):
        """Yield the members of ``tarf`` with the first member's path prefix removed.

        i.e. ``xyz-trainsetxyshufflez/a/b`` is yielded as ``a/b``. Members that do
        not start with that prefix are yielded unchanged.
        """
        parent = None
        cut = 0
        for member in tarf.getmembers():
            if parent is None:
                parent = str(member.path)
                cut = len(parent) + 1  # also drop the path separator
            if member.path.startswith(parent):
                member.path = member.path[cut:]
            yield member

    neturls = read_plainconfig(urls_config)  # FIXME: default path is CWD-relative
    if modelname not in neturls:
        # Entries containing "resnet_"/"mobilenet_" are excluded from the list shown to the user.
        models = [
            fn for fn in neturls if "resnet_" not in fn and "mobilenet_" not in fn
        ]
        print("Model does not exist: ", modelname)
        print("Pick one of the following: ", models)
        return target_dir

    url = neturls[modelname]
    print(url)
    # One request both reports the size and streams the data. The archive is
    # buffered in an anonymous temporary file that is removed when closed.
    with urllib.request.urlopen(url) as response, tempfile.TemporaryFile() as archive:
        print(
            "Downloading the model from the DeepLabCut server @Harvard -> Go Crimson!!! {}....".format(
                url
            )
        )
        length = response.getheader("Content-Length")
        # Servers may omit Content-Length; tqdm then shows a bar without a total.
        total_size = int(length) if length else None
        with tqdm(unit="B", total=total_size, position=0) as pbar:
            while True:
                chunk = response.read(chunk_size)
                if not chunk:
                    break
                archive.write(chunk)
                pbar.update(len(chunk))
        archive.seek(0)
        # The archive comes from the network. Where the running Python supports
        # tarfile extraction filters, use the "data" filter to reject absolute
        # paths, path traversal and special files.
        extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
        with tarfile.open(fileobj=archive, mode="r:gz") as tar:
            tar.extractall(
                target_dir, members=tarfilenamecutting(tar), **extract_kwargs
            )
    return target_dir