sabrinabenas
add crop in orig img
3c2de81
raw
history blame
No virus
1.99 kB
import urllib.request
import tarfile
from tqdm import tqdm
import os
import yaml
from ruamel.yaml import YAML
def read_plainconfig(configname):
    """Parse the YAML file at *configname* with ruamel.yaml and return the result.

    Raises:
        FileNotFoundError: if *configname* does not exist on disk.
    """
    config_exists = os.path.exists(configname)
    if config_exists:
        with open(configname) as stream:
            parser = YAML()
            return parser.load(stream)
    missing_msg = (
        f"Config {configname} is not found. Please make sure that the file exists."
    )
    raise FileNotFoundError(missing_msg)
def _strip_top_level_dir(tar):
    """Yield members of *tar* with the archive's top-level folder prefix removed.

    The folder is the (slash-trimmed) path of the first member, e.g.
    ``xyz-trainsetxyshufflez``. The folder entry itself is yielded with an empty
    path, and members inside it have ``<folder>/`` stripped. Members that are not
    inside that folder are skipped. ``member.path`` is modified in place.
    """
    members = tar.getmembers()
    if not members:
        return
    parent = members[0].path.rstrip("/")
    prefix = parent + "/"
    for member in members:
        if member.path == parent:
            member.path = ""
        elif member.path.startswith(prefix):
            member.path = member.path[len(prefix):]
        else:
            # Directory-aware check: a bare startswith(parent) would also match a
            # sibling such as "<parent>2/file" and turn it into "/file".
            continue
        yield member


def DownloadModel(modelname, target_dir):
    """Download a DeepLabCut Model Zoo project and extract it into *target_dir*.

    The URL for *modelname* is looked up in ``model/pretrained_model_urls.yaml``
    (a path relative to the current working directory). The gzip-compressed tar
    archive is streamed once into a temporary file while a tqdm progress bar
    tracks the bytes received. It is then extracted with its top-level folder
    stripped, so the project's files land directly in *target_dir*. The
    temporary file is deleted afterwards.

    If *modelname* is not in the table, nothing is downloaded. Instead, the
    known names that contain neither ``resnet_`` nor ``mobilenet_`` are printed.

    Args:
        modelname: key into the model URL table.
        target_dir: directory the archive contents are extracted into.

    Returns:
        *target_dir*, in both the success and the unknown-model case.
    """
    import tempfile  # only needed for the download buffer

    neturls = read_plainconfig("model/pretrained_model_urls.yaml")  # FIXME: CWD-relative path
    if modelname not in neturls:
        models = [
            fn
            for fn in neturls
            if "resnet_" not in fn and "mobilenet_" not in fn
        ]
        print("Model does not exist: ", modelname)
        print("Pick one of the following: ", models)
        return target_dir

    url = neturls[modelname]
    print(url)
    chunk_size = 1 << 16
    # One request only: the response body is streamed into a temp file that is
    # removed automatically when the with-block closes it.
    with urllib.request.urlopen(url) as response, tempfile.TemporaryFile() as tmp:
        print(
            "Downloading the model from the DeepLabCut server @Harvard -> Go Crimson!!! {}....".format(
                url
            )
        )
        # Content-Length can be missing (e.g. chunked transfer). In that case
        # tqdm shows a running byte count instead of crashing on int(None).
        length = response.getheader("Content-Length")
        total = int(length) if length else None
        with tqdm(unit="B", total=total, position=0) as pbar:
            for chunk in iter(lambda: response.read(chunk_size), b""):
                tmp.write(chunk)
                pbar.update(len(chunk))  # count real bytes, so no overshoot
        tmp.seek(0)
        # The archive comes from the network. Use the "data" extraction filter
        # (blocks absolute paths, links escaping target_dir, device files) on
        # interpreters that provide it.
        extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
        with tarfile.open(fileobj=tmp, mode="r:gz") as tar:
            tar.extractall(
                target_dir, members=_strip_top_level_dir(tar), **extract_kwargs
            )
    return target_dir