Spaces:
Build error
Build error
import urllib.request | |
import tarfile | |
from tqdm import tqdm | |
import os | |
import yaml | |
from ruamel.yaml import YAML | |
def read_plainconfig(configname):
    """Load a YAML configuration file and return its parsed contents.

    Args:
        configname: Path of the YAML file to read.

    Returns:
        Whatever ``YAML().load`` produces for the file's contents.

    Raises:
        FileNotFoundError: If ``configname`` does not exist.
    """
    if not os.path.exists(configname):
        raise FileNotFoundError(
            f"Config {configname} is not found. Please make sure that the file exists."
        )
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, which differs between systems (e.g. cp1252 on Windows).
    with open(configname, encoding="utf-8") as file:
        return YAML().load(file)
def DownloadModel(
    modelname, target_dir, urls_config="./model/pretrained_model_urls.yaml"
):
    """Download a DeepLabCut Model Zoo project and unpack it into ``target_dir``.

    The archive URL for ``modelname`` is looked up in the YAML file
    ``urls_config``. If the name is known, the ``.tar.gz`` archive is
    downloaded to a temporary directory, extracted into ``target_dir`` with
    its top-level folder stripped, and the temporary download is removed.
    If the name is unknown, the available model names are printed and
    nothing is downloaded.

    Args:
        modelname: Key of the model in the URL configuration file.
        target_dir: Directory the archive contents are extracted into.
        urls_config: Path of the YAML file mapping model names to URLs.
            The default is relative to the current working directory.

    Returns:
        ``target_dir``, whether or not a model was downloaded.

    Raises:
        FileNotFoundError: If ``urls_config`` does not exist.
    """
    # Local import: only this function needs it.
    import tempfile

    def tarfilenamecutting(tarf):
        """Yield the archive members with the top-level folder stripped.

        i.e. ``xyz-trainsetxyshufflez/file`` becomes ``file``.
        """
        parent = None
        for member in tarf.getmembers():
            if parent is None:
                # The first member is taken to be the top-level folder.
                parent = str(member.path)
            # Match the folder itself or entries below it; a bare prefix test
            # would also catch siblings such as "<parent>X/file" and turn
            # them into absolute "/file" paths.
            if member.path == parent or member.path.startswith(parent + "/"):
                member.path = member.path[len(parent) + 1 :]
            yield member

    neturls = read_plainconfig(urls_config)
    if modelname not in neturls:
        models = [
            fn for fn in neturls if "resnet_" not in fn and "mobilenet_" not in fn
        ]
        print("Model does not exist: ", modelname)
        print("Pick one of the following: ", models)
        return target_dir

    url = neturls[modelname]
    print(url)
    # Open the URL once just to learn the size; close the connection promptly.
    with urllib.request.urlopen(url) as response:
        content_length = response.headers.get("Content-Length")
    print(
        "Downloading the model from the DeepLabCut server @Harvard -> Go Crimson!!! {}....".format(
            url
        )
    )
    # The server may omit Content-Length; tqdm accepts total=None for that.
    total_size = int(content_length) if content_length is not None else None

    downloaded = 0

    def show_progress(count, block_size, total_size):
        """urlretrieve report hook: advance the bar by the bytes received."""
        nonlocal downloaded
        # urlretrieve reports the number of blocks so far (starting with a
        # call at count == 0), and the last block is usually short, so derive
        # the byte count from ``count`` and clamp it to the known total.
        received = count * block_size
        if total_size > 0:
            received = min(received, total_size)
        pbar.update(received - downloaded)
        downloaded = received

    # Download into a directory we own so the archive is always cleaned up.
    with tempfile.TemporaryDirectory() as tmp_dir:
        archive_path = os.path.join(tmp_dir, "model.tar.gz")
        with tqdm(unit="B", total=total_size, position=0) as pbar:
            urllib.request.urlretrieve(url, archive_path, reporthook=show_progress)

        extract_kwargs = {}
        if hasattr(tarfile, "data_filter"):
            # The archive comes from the network: where the interpreter
            # supports extraction filters, refuse members that would escape
            # ``target_dir`` (absolute paths, "..", unsafe links).
            extract_kwargs["filter"] = "data"
        with tarfile.open(archive_path, mode="r:gz") as tar:
            tar.extractall(
                target_dir, members=tarfilenamecutting(tar), **extract_kwargs
            )
    return target_dir