"""Gradio utilities.
Note that the optional `progress` parameter can be both a `tqdm` module or a
`gr.Progress` instance.
"""

import concurrent.futures
import contextlib
import glob
import hashlib
import logging
import os
import tempfile
import time
import urllib.request

import jax
import numpy as np
from tensorflow.io import gfile


@contextlib.contextmanager
def timed(name):
  """Context manager that logs and records elapsed wall-clock time."""
  t0 = time.monotonic()
  timing = dict(secs=None)
  try:
    yield timing
  finally:
    timing['secs'] = time.monotonic() - t0
    logging.info('Timed %s: %.1f secs', name, timing['secs'])
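
# Minimal usage sketch for `timed` (hypothetical caller code, kept in a
# comment so that nothing runs at import time):
#
#   with timed('loading params') as timing:
#     params = load_params()  # `load_params` is an assumed helper.
#   print(f"loading took {timing['secs']:.1f} secs")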


def copy_file(
    src, dst, *, progress=None, block_size=1024 * 1024 * 10, overwrite=False
):
  """Copies a file with a progress bar.

  Args:
    src: Source file (readable by `tf.io.gfile`) or URL.
    dst: Destination file. Path must be writable by `tf.io.gfile`.
    progress: An object with a `.tqdm` attribute, or `None`.
    block_size: Size of individual blocks to be read/written.
    overwrite: If `True`, overwrite `dst` if it exists.
  """
  if os.path.dirname(dst):
    os.makedirs(os.path.dirname(dst), exist_ok=True)
  if os.path.exists(dst) and not overwrite:
    return

  if src.startswith('http://') or src.startswith('https://'):
    opener = urllib.request.urlopen
    # A HEAD request returns the Content-Length without downloading the body.
    request = urllib.request.Request(src, method='HEAD')
    response = urllib.request.urlopen(request)
    content_length = response.headers.get('Content-Length')
    n = int(np.ceil(int(content_length) / block_size))
    logging.info('Content-Length of %s: %s', src, content_length)
  else:
    opener = lambda path: gfile.GFile(path, 'rb')
    stats = gfile.stat(src)
    n = int(np.ceil(stats.length / block_size))

  if progress is None:
    range_or_trange = range
  else:
    range_or_trange = lambda n: progress.tqdm(list(range(n)), desc='download')

  # Write to a temporary name first so readers never observe a partial file.
  with opener(src) as fin:
    with gfile.GFile(f'{dst}-PARTIAL', 'wb') as fout:
      for _ in range_or_trange(n):
        fout.write(fin.read(block_size))
  # Pass `overwrite` through so the final rename succeeds when overwriting.
  gfile.rename(f'{dst}-PARTIAL', dst, overwrite=overwrite)
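
# Usage sketch (hypothetical URL/path; `gr` assumed imported as
# `import gradio as gr` in the calling app):
#
#   copy_file(
#       'https://example.org/model.npz',  # assumed URL
#       '/tmp/model.npz',
#       progress=gr.Progress(),  # or the `tqdm` module outside Gradio
#   )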


# (estimated_secs, actual_secs) pairs from previous cache loads, seeded with a
# neutral 1:1 prior; used to calibrate future progress-bar estimates.
_estimated_real = [(10, 10)]
# Insertion-ordered dict serving as an LRU cache of loaded values.
_memory_cache = {}


def get_with_progress(getter, secs, progress, step=0.1):
  """Returns result from `getter` while showing a progress bar."""
  if progress is None:
    return getter()
  with concurrent.futures.ThreadPoolExecutor() as executor:
    future = executor.submit(getter)
    for _ in progress.tqdm(list(range(int(np.ceil(secs/step)))), desc='read'):
      if not future.done():
        time.sleep(step)
    return future.result()
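
# Usage sketch: run a slow loader in a background thread while the progress
# bar ticks for roughly the estimated duration (`slow_load` is an assumed
# callable; `gr` as above):
#
#   value = get_with_progress(slow_load, secs=30, progress=gr.Progress())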


def _get_array_sizes(tree):
  # `jax.tree_leaves` has been removed in recent JAX releases; use the
  # `jax.tree_util` equivalent.
  return [getattr(x, 'nbytes', 0) for x in jax.tree_util.tree_leaves(tree)]


def get_memory_cache(
    key, getter, max_cache_size_bytes, progress=None, estimated_secs=None
):
  """Keeps cache below specified size by evicting least recently used items."""
  if key in _memory_cache:
    _memory_cache[key] = _memory_cache.pop(key)  # Update "last accessed" order.
    return _memory_cache[key]

  # Calibrate the progress-bar estimate with the ratio of previously observed
  # actual load times to their estimates.
  est, real = zip(*_estimated_real)
  if estimated_secs is None:
    estimated_secs = sum(est) / len(est)
  with timed(f'loading {key}') as timing:
    estimated_secs *= sum(real) / sum(est)
    _memory_cache[key] = get_with_progress(getter, estimated_secs, progress)
  _estimated_real.append((estimated_secs, timing['secs']))

  sz = sum(_get_array_sizes(list(_memory_cache.values())))
  logging.info('New memory cache size=%.1f MB', sz/1e6)
  while sz > max_cache_size_bytes:
    k, v = next(iter(_memory_cache.items()))
    if k == key:
      break
    s = sum(_get_array_sizes(v))
    logging.info('Removing %s from memory cache (%.1f MB)', k, s/1e6)
    _memory_cache.pop(k)
    sz -= s

  return _memory_cache[key]
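
# Usage sketch: keep at most ~2 GB of loaded values in memory, keyed by name
# (the key, `load_params`, and the budget are assumptions for illustration):
#
#   params = get_memory_cache(
#       'model-a', lambda: load_params('model-a'),
#       max_cache_size_bytes=2_000_000_000, progress=gr.Progress(),
#   )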


def get_memory_cache_info():
  """Returns number of items and total size in bytes."""
  sizes = _get_array_sizes(_memory_cache)
  return len(_memory_cache), sum(sizes)


# On-disk cache directory for downloaded files, under the system temp dir.
CACHE_DIR = os.path.join(tempfile.gettempdir(), 'downloads_cache')


def get_disk_cache(path_or_url, max_cache_size_bytes, progress=None):
  """Keeps cache below specified size by evicting least recently used files."""
  fname = os.path.basename(path_or_url)
  # Hash the full path/URL so distinct sources with the same basename do not
  # collide in the cache directory.
  path_hash = hashlib.md5(path_or_url.encode()).hexdigest() + '__' + fname
  dst = os.path.join(CACHE_DIR, path_hash, fname)
  if os.path.exists(dst):
    return dst

  os.makedirs(os.path.dirname(dst), exist_ok=True)
  with timed(f'copying {path_or_url}'):
    copy_file(path_or_url, dst, progress=progress)

  # Evict least recently accessed files until the cache fits the size budget.
  atimes_sizes_paths = sorted([
      (os.path.getatime(p), os.path.getsize(p), p)
      for p in glob.glob(os.path.join(CACHE_DIR, '*', '*'))
      if os.path.isfile(p)
  ])
  sz = sum(sz for _, sz, _ in atimes_sizes_paths)
  logging.info('New disk cache size=%.1f MB', sz/1e6)
  while sz > max_cache_size_bytes:
    _, s, path = atimes_sizes_paths.pop(0)
    if path == dst:
      break
    logging.info('Removing %s from disk cache (%.1f MB)', path, s/1e6)
    os.unlink(path)
    sz -= s

  return dst
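
# Usage sketch: download once into the size-bounded disk cache and reuse the
# returned local path (bucket path and 10 GB budget are assumptions):
#
#   local_path = get_disk_cache(
#       'gs://some-bucket/model.npz',
#       max_cache_size_bytes=10_000_000_000,
#       progress=gr.Progress(),
#   )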


def get_disk_cache_info():
  """Returns number of items and total size in bytes."""
  sizes = [
      os.path.getsize(p)
      for p in glob.glob(os.path.join(CACHE_DIR, '*', '*'))
  ]
  return len(sizes), sum(sizes)