import contextlib
import os
import tempfile
import urllib.parse

import requests
from tqdm import tqdm


def _remove_quietly(path):
    """Delete *path* if it exists; do nothing if it is already gone."""
    with contextlib.suppress(FileNotFoundError):
        os.remove(path)


def download(ckpt_dir, url):
    """Download a checkpoint file from *url* unless it is already cached.

    The file is stored as ``<ckpt_dir>/flaxmodels/<name>``, where ``name`` is
    the last component of the URL's path (query string and fragment excluded).
    If that file already exists, no request is made.

    Args:
        ckpt_dir: Base cache directory. If ``None``, ``tempfile.gettempdir()``
            is used instead.
        url: URL of the file to download.

    Returns:
        The local path of the checkpoint file. If the number of bytes received
        differs from the server's ``Content-Length``, an error message is
        printed, the partial data is discarded, and the returned path does
        not exist.

    Raises:
        ValueError: If the URL's path does not end in a file name.
        requests.HTTPError: If the server answers with an error status.
    """
    # Use the parsed path rather than rfind('?'): a URL without a query string
    # would otherwise lose its last character.
    name = urllib.parse.urlsplit(url).path.rsplit('/', 1)[-1]
    if not name:
        raise ValueError(f'Cannot determine a file name from URL: {url}')

    if ckpt_dir is None:
        ckpt_dir = tempfile.gettempdir()
    ckpt_dir = os.path.join(ckpt_dir, 'flaxmodels')
    ckpt_file = os.path.join(ckpt_dir, name)

    if os.path.exists(ckpt_file):
        return ckpt_file

    # Show the URL without its query string.
    print(f'Downloading: "{url.split("?", 1)[0]}" to {ckpt_file}')
    os.makedirs(ckpt_dir, exist_ok=True)

    # Download into a temporary file first so a failed or interrupted download
    # never leaves a truncated file under the final (cached) name.
    ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')
    try:
        with requests.get(url, stream=True) as response:
            # Without this, an HTTP error page would be saved and cached as
            # the checkpoint.
            response.raise_for_status()
            total_size_in_bytes = int(response.headers.get('content-length', 0))
            with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) as progress_bar:
                with open(ckpt_file_temp, 'wb') as file:
                    for data in response.iter_content(chunk_size=1024):
                        progress_bar.update(len(data))
                        file.write(data)
                downloaded_bytes = progress_bar.n
    except BaseException:
        # Includes KeyboardInterrupt: drop the partial file, then propagate.
        _remove_quietly(ckpt_file_temp)
        raise

    # A Content-Length of 0 means the server did not report a size, so no
    # completeness check is possible.
    if total_size_in_bytes != 0 and downloaded_bytes != total_size_in_bytes:
        print('An error occured while downloading, please try again.')
        _remove_quietly(ckpt_file_temp)
    else:
        # os.replace (unlike os.rename) also succeeds on Windows when the
        # destination already exists.
        os.replace(ckpt_file_temp, ckpt_file)
    return ckpt_file