Spaces:
Running
Running
import hashlib | |
import os | |
import tarfile | |
import urllib.request | |
from tqdm import tqdm | |
def print_arguments(args):
    """Print every attribute of *args* between a header and a footer rule.

    Args:
        args: Any object accepted by ``vars()`` (e.g. an
            ``argparse.Namespace``); each attribute is printed as
            ``name: value`` in ``vars()`` order.
    """
    print("----------- Configuration Arguments -----------")
    rows = ["%s: %s" % (name, setting) for name, setting in vars(args).items()]
    for row in rows:
        print(row)
    print("------------------------------------------------")
def strtobool(val):
    """Convert a textual truth value to ``True`` or ``False``.

    Matching is case-insensitive. Accepted true values are
    ``y, yes, t, true, on, 1``; accepted false values are
    ``n, no, f, false, off, 0``.

    Raises:
        ValueError: If the (lowercased) value is in neither set.
    """
    lowered = val.lower()
    if lowered in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if lowered in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    # The message reports the lowercased value, not the raw input.
    raise ValueError("invalid truth value %r" % (lowered,))
def str_none(val):
    """Map the literal string ``'None'`` to ``None``; return anything else unchanged."""
    return None if val == 'None' else val
def add_arguments(argname, type, default, help, argparser, **kwargs):
    """Register a ``--argname`` option on *argparser*.

    ``bool`` is swapped for :func:`strtobool` so values like ``"false"`` parse
    correctly, and ``str`` is swapped for :func:`str_none` so the literal
    ``"None"`` becomes ``None``. Any other *type* is passed through unchanged.
    The help text is suffixed with the option's default value.

    Args:
        argname: Option name without the leading ``--``.
        type: Conversion callable for the raw string value.
        default: Default value for the option.
        help: Help text (a string; the default-value note is appended to it).
        argparser: The ``argparse.ArgumentParser`` to register on.
        **kwargs: Extra keyword arguments forwarded to ``add_argument``.
    """
    if type == bool:
        converter = strtobool
    elif type == str:
        converter = str_none
    else:
        converter = type
    argparser.add_argument(
        "--" + argname,
        default=default,
        type=converter,
        help=help + ' Default: %(default)s.',
        **kwargs
    )
def md5file(fname, chunk_size=4096):
    """Return the hexadecimal MD5 digest of the file at *fname*.

    The file is read in binary mode in ``chunk_size``-byte pieces, so memory
    use stays bounded regardless of file size.

    Args:
        fname: Path of the file to hash.
        chunk_size: Number of bytes read per iteration (default 4096).

    Returns:
        The 32-character lowercase hex digest.
    """
    hash_md5 = hashlib.md5()
    # ``with`` guarantees the handle is closed even if a read raises.
    with open(fname, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()
def download(url, md5sum, target_dir):
    """Download *url* into *target_dir* and verify its MD5 checksum.

    The local file name is the last ``/``-separated component of *url*. If
    that file already exists and its MD5 matches *md5sum*, nothing is
    downloaded.

    Args:
        url: Source URL.
        md5sum: Expected hex MD5 digest of the file.
        target_dir: Destination directory; created (with parents) if missing.
            An empty string means the current working directory.

    Returns:
        Path of the local file.

    Raises:
        RuntimeError: If the downloaded file's MD5 does not match *md5sum*.
    """
    if target_dir:
        # exist_ok avoids failing if the directory appears between a
        # separate existence check and the makedirs call.
        os.makedirs(target_dir, exist_ok=True)
    filepath = os.path.join(target_dir, url.split("/")[-1])

    if os.path.exists(filepath) and md5file(filepath) == md5sum:
        print(f"File exists, skip downloading. ({filepath})")
        return filepath

    print(f"Downloading {url} to {filepath} ...")
    with urllib.request.urlopen(url) as source, open(filepath, "wb") as output:
        # Content-Length is optional (e.g. chunked transfer encoding). When it
        # is missing, give tqdm total=None (indeterminate bar) instead of
        # crashing on int(None).
        length = source.info().get("Content-Length")
        total = int(length) if length else None
        with tqdm(total=total, ncols=80, unit='iB', unit_scale=True,
                  unit_divisor=1024) as loop:
            for buffer in iter(lambda: source.read(8192), b""):
                output.write(buffer)
                loop.update(len(buffer))

    print(f"\nMD5 Checksum {filepath} ...")
    if md5file(filepath) != md5sum:
        raise RuntimeError("MD5 checksum failed.")
    return filepath
def unpack(filepath, target_dir, rm_tar=False):
    """Extract the tar archive at *filepath* into *target_dir*.

    Args:
        filepath: Path of a tar archive (any compression ``tarfile.open``
            auto-detects).
        target_dir: Directory to extract into.
        rm_tar: If true, delete the archive after extraction succeeds.
    """
    print("Unpacking %s ..." % filepath)
    # ``with`` closes the archive even if extraction raises.
    with tarfile.open(filepath) as tar:
        # The archive may be untrusted. Where the running Python provides
        # extraction filters (3.12+, and security backports), use the "data"
        # filter to reject members that would escape target_dir (absolute
        # paths, "..", unsafe links). This also silences the 3.12+
        # DeprecationWarning about an unspecified filter.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(target_dir, filter="data")
        else:
            tar.extractall(target_dir)
    if rm_tar:
        os.remove(filepath)
def make_inputs_require_grad(module, input, output):
    """Mark *output* as requiring gradients, in place.

    The ``(module, input, output)`` signature matches a PyTorch forward hook.
    Presumably this is registered with ``register_forward_hook`` on an
    embedding-like layer so gradients can flow through it even when its own
    weights are frozen. TODO confirm at the call site.

    Returns ``None``, so when used as a forward hook the module's output
    object is not replaced; only its ``requires_grad`` flag changes.
    """
    # requires_grad_() defaults to requires_grad=True.
    output.requires_grad_()