akhaliq's picture
akhaliq HF staff
add files
81170fd
raw
history blame
No virus
2.1 kB
import logging
import os
import tempfile
import urllib.parse

import flax
import jax
import numpy as np
import requests
from tqdm import tqdm
# Module-level logger; download() logs its start message and download errors through it.
logger = logging.getLogger(__name__)
def download(url, ckpt_dir=None):
    """Download a checkpoint from *url* into a local cache and return its path.

    The file is cached as ``<ckpt_dir>/flaxmodels/<name>``, where ``<name>`` is
    the last component of the URL path. Query string and fragment are ignored,
    so ``.../model.pkl?dl=1`` is cached as ``model.pkl``. If that file already
    exists, its path is returned without any network access.

    Args:
        url: Address of the checkpoint file.
        ckpt_dir: Base cache directory. Defaults to ``tempfile.gettempdir()``.
            A ``flaxmodels`` sub-directory is created inside it if needed.

    Returns:
        The path of the cached file. If the number of bytes received does not
        match the server's ``Content-Length``, an error is logged, the partial
        download is discarded, and the returned path does not exist.

    Raises:
        ValueError: If the URL path has no file-name component.
        requests.HTTPError: If the server responds with an error status.
    """
    # urlsplit separates path, query and fragment. Slicing up to rfind('?')
    # would chop the last character off URLs that have no query string.
    name = urllib.parse.urlsplit(url).path.rsplit('/', 1)[-1]
    if not name:
        raise ValueError(f'Cannot determine a file name from URL: {url!r}')
    if ckpt_dir is None:
        ckpt_dir = tempfile.gettempdir()
    ckpt_dir = os.path.join(ckpt_dir, 'flaxmodels')
    ckpt_file = os.path.join(ckpt_dir, name)
    if os.path.exists(ckpt_file):
        return ckpt_file

    logger.info('Downloading: "%s" to %s', url.partition('?')[0], ckpt_file)
    os.makedirs(ckpt_dir, exist_ok=True)
    # Stream into a temp file first so an interrupted download never leaves a
    # truncated file at ckpt_file, which later calls would treat as cached.
    ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')
    with requests.get(url, stream=True) as response:
        # Otherwise an HTTP error page would be saved (and cached) as the checkpoint.
        response.raise_for_status()
        total_size_in_bytes = int(response.headers.get('content-length', 0))
        received = 0
        progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
        try:
            with open(ckpt_file_temp, 'wb') as file:
                for data in response.iter_content(chunk_size=1024):
                    progress_bar.update(len(data))
                    received += len(data)
                    file.write(data)
        except BaseException:
            # Network/disk error or Ctrl-C: drop the partial temp file, then re-raise.
            if os.path.exists(ckpt_file_temp):
                os.remove(ckpt_file_temp)
            raise
        finally:
            progress_bar.close()

    # total_size_in_bytes is 0 when no Content-Length header was sent; the
    # completeness check is only possible when the size is known.
    if total_size_in_bytes != 0 and received != total_size_in_bytes:
        logger.error('An error occured while downloading, please try again.')
        if os.path.exists(ckpt_file_temp):
            os.remove(ckpt_file_temp)
    else:
        # os.replace (unlike os.rename) also overwrites an existing target on Windows.
        os.replace(ckpt_file_temp, ckpt_file)
    return ckpt_file
def get(dictionary, key):
    """Look up *key* in *dictionary*, tolerating a missing mapping.

    Returns ``dictionary[key]`` when *dictionary* is not ``None`` and contains
    *key*; in every other case returns ``None``.
    """
    present = dictionary is not None and key in dictionary
    return dictionary[key] if present else None
def prefetch(dataset, n_prefetch):
    """Iterate over *dataset*, converting leaves to NumPy and optionally prefetching.

    Each element produced by ``iter(dataset)`` is treated as a pytree. Every
    leaf is converted with ``np.asarray(memoryview(leaf))``, so leaves must
    support the buffer protocol. NumPy arrays do; the original use case was
    presumably TensorFlow eager tensors (TODO confirm for callers).

    Args:
        dataset: Any iterable of pytrees.
        n_prefetch: Buffer size passed to ``flax.jax_utils.prefetch_to_device``.
            A falsy value (0 / None) disables device prefetching. Per the Flax
            docs, that helper expects each leaf's leading axis to equal the
            number of local devices; confirm this holds for callers.

    Returns:
        A lazy iterator over the converted (and possibly prefetched) elements.
    """
    # Taken from: https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py
    def _to_numpy(tree):
        # jax.tree_util.tree_map is the stable spelling; the jax.tree_map alias
        # is deprecated and has been removed from newer JAX releases.
        return jax.tree_util.tree_map(lambda t: np.asarray(memoryview(t)), tree)

    ds_iter = map(_to_numpy, iter(dataset))
    if n_prefetch:
        ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch)
    return ds_iter