"""Gradio utilities.

Note that the optional `progress` parameter can be either the `tqdm` module or
a `gr.Progress` instance.
"""

import concurrent.futures
import contextlib
import glob
import hashlib
import logging
import os
import tempfile
import time
import urllib.request

import jax
import numpy as np
from tensorflow.io import gfile


@contextlib.contextmanager
def timed(name):
  """Yields a dict whose 'secs' entry is set to the block's duration on exit."""
  t0 = time.monotonic()
  timing = dict(secs=None)
  try:
    yield timing
  finally:
    timing['secs'] = time.monotonic() - t0
    logging.info('Timed %s: %.1f secs', name, timing['secs'])

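# A minimal usage sketch (hypothetical name 'sleep'); `timing['secs']` is only
# filled in once the block exits:
#
#   with timed('sleep') as timing:
#     time.sleep(0.5)
#   print(timing['secs'])  # ~0.5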

def copy_file(
    src, dst, *, progress=None, block_size=1024 * 1024 * 10, overwrite=False
):
  """Copies a file with progress bar.

  Args:
    src: Source file (readable by `tf.io.gfile`) or URL.
    dst: Destination file. Path must be writable by `tf.io.gfile`.
    progress: An object with a `.tqdm` attribute, or `None`.
    block_size: Size of individual blocks to be read/written.
    overwrite: If `True`, overwrite `dst` if it exists.
  """
  if os.path.dirname(dst):
    os.makedirs(os.path.dirname(dst), exist_ok=True)
  if os.path.exists(dst) and not overwrite:
    return

  if src.startswith('http://') or src.startswith('https://'):
    opener = urllib.request.urlopen
    # A HEAD request reveals the total size up front, so the progress bar can
    # be sized without downloading the body twice.
    request = urllib.request.Request(src, method='HEAD')
    response = urllib.request.urlopen(request)
    content_length = response.headers.get('Content-Length')
    n = int(np.ceil(int(content_length) / block_size))
    logging.info('content_length=%s', content_length)
  else:
    opener = lambda path: gfile.GFile(path, 'rb')
    stats = gfile.stat(src)
    n = int(np.ceil(stats.length / block_size))

  if progress is None:
    range_or_trange = range
  else:
    range_or_trange = lambda n: progress.tqdm(list(range(n)), desc='download')

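  # Stream into a '-PARTIAL' file and rename only on success, so that an
  # interrupted copy never leaves a truncated file at `dst`.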
  with opener(src) as fin:
    with gfile.GFile(f'{dst}-PARTIAL', 'wb') as fout:
      for _ in range_or_trange(n):
        fout.write(fin.read(block_size))
  gfile.rename(f'{dst}-PARTIAL', dst)

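# A minimal usage sketch (hypothetical URL and destination; the default
# `progress=None` simply skips the progress bar):
#
#   copy_file('https://example.com/weights.npz', '/tmp/weights.npz')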

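# Running list of (estimated_secs, real_secs) load durations, seeded with a
# neutral (10, 10) prior; used by `get_memory_cache()` to calibrate estimates.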
_estimated_real = [(10, 10)]
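# Insertion-ordered dict used as an LRU cache (most recently used entry last).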
_memory_cache = {}


def get_with_progress(getter, secs, progress, step=0.1):
  """Returns result from `getter` while showing a progress bar."""
  with concurrent.futures.ThreadPoolExecutor() as executor:
    future = executor.submit(getter)
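    # Tick roughly once per `step` seconds; once the future completes, the
    # remaining ticks run instantly so the bar always reaches 100%.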
    for _ in progress.tqdm(list(range(int(np.ceil(secs/step)))), desc='read'):
      if not future.done():
        time.sleep(step)
  return future.result()

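# A minimal usage sketch (uses the `tqdm` module as the progress object, as
# allowed by the module docstring; the getter is an arbitrary callable):
#
#   import tqdm
#   result = get_with_progress(lambda: sum(range(10**7)), 2.0, tqdm)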

def _get_array_sizes(tree):
  return [getattr(x, 'nbytes', 0) for x in jax.tree_util.tree_leaves(tree)]


def get_memory_cache(
    key, getter, max_cache_size_bytes, progress=None, estimated_secs=None
):
  """Keeps cache below specified size by removing elements not last accessed."""
  if key in _memory_cache:
    _memory_cache[key] = _memory_cache.pop(key)  # Move to most-recent position.
    return _memory_cache[key]

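  # Calibrate: scale the estimate by the overall ratio of real to estimated
  # durations observed so far, then record this (estimate, real) pair.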
  est, real = zip(*_estimated_real)
  if estimated_secs is None:
    estimated_secs = sum(est) / len(est)
  with timed(f'loading {key}') as timing:
    estimated_secs *= sum(real) / sum(est)
    _memory_cache[key] = get_with_progress(getter, estimated_secs, progress)
  _estimated_real.append((estimated_secs, timing['secs']))

  sz = sum(_get_array_sizes(list(_memory_cache.values())))
  logging.info('New memory cache size=%.1f MB', sz/1e6)

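  # Evict in insertion order (least recently used first), but never evict the
  # entry that was just added.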
  while sz > max_cache_size_bytes:
    k, v = next(iter(_memory_cache.items()))
    if k == key:
      break
    s = sum(_get_array_sizes(v))
    logging.info('Removing %s from memory cache (%.1f MB)', k, s/1e6)
    _memory_cache.pop(k)
    sz -= s

  return _memory_cache[key]

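# A minimal usage sketch (hypothetical key and getter; relies on the default
# `progress=None`, so no progress bar is shown):
#
#   params = get_memory_cache(
#       'params', lambda: np.zeros(1000), max_cache_size_bytes=10**9)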

def get_memory_cache_info():
  """Returns number of items and total size in bytes."""
  sizes = _get_array_sizes(_memory_cache)
  return len(_memory_cache), sum(sizes)


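# Cached downloads land in CACHE_DIR/<md5(path_or_url)>__<basename>/<basename>.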
CACHE_DIR = os.path.join(tempfile.gettempdir(), 'downloads_cache')


def get_disk_cache(path_or_url, max_cache_size_bytes, progress=None):
  """Keeps cache below specified size by removing elements not last accessed."""
  fname = os.path.basename(path_or_url)
  path_hash = hashlib.md5(path_or_url.encode()).hexdigest() + '__' + fname
  dst = os.path.join(CACHE_DIR, path_hash, fname)
  if os.path.exists(dst):
    return dst

  os.makedirs(os.path.dirname(dst), exist_ok=True)
  with timed(f'copying {path_or_url}'):
    copy_file(path_or_url, dst, progress=progress)

  atimes_sizes_paths = sorted([
      (os.path.getatime(p), os.path.getsize(p), p)
      for p in glob.glob(os.path.join(CACHE_DIR, '*', '*'))
      if os.path.isfile(p)
  ])
  sz = sum(sz for _, sz, _ in atimes_sizes_paths)
  logging.info('New disk cache size=%.1f MB', sz/1e6)

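  # Evict files with the oldest access time first, but never the file that
  # was just downloaded.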
  while sz > max_cache_size_bytes:
    _, s, path = atimes_sizes_paths.pop(0)
    if path == dst:
      break
    logging.info('Removing %s from disk cache (%.1f MB)', path, s/1e6)
    os.unlink(path)
    sz -= s

  return dst

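# A minimal usage sketch (hypothetical URL; returns a local path inside
# CACHE_DIR that remains valid until evicted):
#
#   local = get_disk_cache('https://example.com/weights.npz', 20 * 10**9)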

def get_disk_cache_info():
  """Returns number of items and total size in bytes."""
  sizes = [
      os.path.getsize(p)
      for p in glob.glob(os.path.join(CACHE_DIR, '*', '*'))
  ]
  return len(sizes), sum(sizes)