import base64
import collections
import io
import itertools
import logging
import math
import os
from functools import lru_cache
from typing import TYPE_CHECKING
import fsspec.core
try:
import ujson as json
except ImportError:
if not TYPE_CHECKING:
import json
from ..asyn import AsyncFileSystem
from ..callbacks import _DEFAULT_CALLBACK
from ..core import filesystem, open, split_protocol
from ..utils import isfilelike, merge_offset_ranges, other_paths
logger = logging.getLogger("fsspec.reference")
class ReferenceNotReachable(RuntimeError):
def __init__(self, reference, target, *args):
super().__init__(*args)
self.reference = reference
self.target = target
def __str__(self):
return f'Reference "{self.reference}" failed to fetch target {self.target}'
def _first(d):
return list(d.values())[0]
def _prot_in_references(path, references):
ref = references.get(path)
if isinstance(ref, (list, tuple)):
return split_protocol(ref[0])[0] if ref[0] else ref[0]
def _protocol_groups(paths, references):
if isinstance(paths, str):
return {_prot_in_references(paths, references): [paths]}
out = {}
for path in paths:
protocol = _prot_in_references(path, references)
out.setdefault(protocol, []).append(path)
return out
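# A minimal sketch of the grouping above (the reference values are hypothetical):
# given references = {"a": ["s3://bucket/x", 0, 10], "b": b"inline"},
# _protocol_groups(["a", "b"], references) returns {"s3": ["a"], None: ["b"]},
# so each batch of paths can be fetched with the filesystem for its protocol.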
class RefsValuesView(collections.abc.ValuesView):
def __iter__(self):
for val in self._mapping.zmetadata.values():
yield json.dumps(val).encode()
yield from self._mapping._items.values()
for field in self._mapping.listdir():
chunk_sizes = self._mapping._get_chunk_sizes(field)
if len(chunk_sizes) == 0:
yield self._mapping[field + "/0"]
continue
yield from self._mapping._generate_all_records(field)
class RefsItemsView(collections.abc.ItemsView):
def __iter__(self):
return zip(self._mapping.keys(), self._mapping.values())
def ravel_multi_index(idx, sizes):
val = 0
mult = 1
for i, s in zip(idx[::-1], sizes[::-1]):
val += i * mult
mult *= s
return val
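# Worked example (values chosen for illustration): for a chunk grid of
# sizes (3, 4), index (1, 2) ravels row-major to 1 * 4 + 2 == 6, i.e.
# ravel_multi_index((1, 2), (3, 4)) == 6.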
class LazyReferenceMapper(collections.abc.MutableMapping):
"""Interface to read parquet store as if it were a standard kerchunk
references dict."""
# import is class level to prevent numpy dep requirement for fsspec
@property
def np(self):
import numpy as np
return np
@property
def pd(self):
import pandas as pd
return pd
def __init__(
self, root, fs=None, out_root=None, cache_size=128, categorical_threshold=10
):
"""
Parameters
----------
root : str
Root of parquet store
fs : fsspec.AbstractFileSystem
fsspec filesystem object, default is local filesystem.
cache_size : int
Maximum size of LRU cache, where cache_size*record_size denotes
the total number of references that can be loaded in memory at once.
"""
self.root = root
self.chunk_sizes = {}
self._items = {}
self.dirs = None
self.fs = fsspec.filesystem("file") if fs is None else fs
with self.fs.open("/".join([self.root, ".zmetadata"]), "rb") as f:
self._items[".zmetadata"] = f.read()
met = json.loads(self._items[".zmetadata"])
self.record_size = met["record_size"]
self.zmetadata = met["metadata"]
self.url = self.root + "/{field}/refs.{record}.parq"
self.out_root = out_root or self.root
self.cat_thresh = categorical_threshold
# Define function to open and decompress refs
@lru_cache(maxsize=cache_size)
def open_refs(field, record):
"""cached parquet file loader"""
path = self.url.format(field=field, record=record)
with self.fs.open(path) as f:
# TODO: since all we do is iterate, is arrow without pandas
# better here?
df = self.pd.read_parquet(f, engine="fastparquet")
refs = {c: df[c].values for c in df.columns}
return refs
self.open_refs = open_refs
@staticmethod
def create(record_size, root, fs, **kwargs):
met = {"metadata": {}, "record_size": record_size}
fs.pipe("/".join([root, ".zmetadata"]), json.dumps(met).encode())
return LazyReferenceMapper(root, fs, **kwargs)
def listdir(self, basename=True):
"""List top-level directories"""
if self.dirs is None:
dirs = [p.split("/", 1)[0] for p in self.zmetadata]
self.dirs = set(sorted(p for p in dirs if p and not p.startswith(".")))
listing = self.dirs
if basename:
listing = [os.path.basename(path) for path in listing]
return listing
def ls(self, path="", detail=True):
"""Shortcut file listings"""
if not path:
dirnames = self.listdir()
others = set(
[".zmetadata"]
+ [name for name in self.zmetadata if "/" not in name]
+ [name for name in self._items if "/" not in name]
)
if detail is False:
others.update(dirnames)
return sorted(others)
dirinfo = [
{"name": name, "type": "directory", "size": 0} for name in dirnames
]
fileinfo = [
{
"name": name,
"type": "file",
"size": len(
json.dumps(self.zmetadata[name])
if name in self.zmetadata
else self._items[name]
),
}
for name in others
]
return sorted(dirinfo + fileinfo, key=lambda s: s["name"])
parts = path.split("/", 1)
if len(parts) > 1:
raise FileNotFoundError("Cannot list within directories right now")
field = parts[0]
others = set(
[name for name in self.zmetadata if name.startswith(f"{path}/")]
+ [name for name in self._items if name.startswith(f"{path}/")]
)
fileinfo = [
{
"name": name,
"type": "file",
"size": len(
json.dumps(self.zmetadata[name])
if name in self.zmetadata
else self._items[name]
),
}
for name in others
]
keys = self._keys_in_field(field)
if detail is False:
return list(others) + list(keys)
recs = self._generate_all_records(field)
recinfo = [
{"name": name, "type": "file", "size": rec[-1]}
for name, rec in zip(keys, recs)
if rec[0] # filters out path==None, deleted/missing
]
return fileinfo + recinfo
def _load_one_key(self, key):
"""Get the reference for one key
Returns bytes, one-element list or three-element list.
"""
if key in self._items:
return self._items[key]
elif key in self.zmetadata:
return json.dumps(self.zmetadata[key]).encode()
elif "/" not in key or self._is_meta(key):
raise KeyError(key)
field, sub_key = key.split("/")
record, _, _ = self._key_to_record(key)
        maybe = self._items.get((field, record), {}).get(sub_key, False)
if maybe is None:
# explicitly deleted
raise KeyError
elif maybe:
return maybe
# Chunk keys can be loaded from row group and cached in LRU cache
try:
record, ri, chunk_size = self._key_to_record(key)
if chunk_size == 0:
return b""
refs = self.open_refs(field, record)
except (ValueError, TypeError, FileNotFoundError):
raise KeyError(key)
columns = ["path", "offset", "size", "raw"]
selection = [refs[c][ri] if c in refs else None for c in columns]
raw = selection[-1]
if raw is not None:
return raw
if selection[0] is None:
raise KeyError("This reference has been deleted")
if selection[1:3] == [0, 0]:
# URL only
return selection[:1]
# URL, offset, size
return selection[:3]
@lru_cache(4096)
def _key_to_record(self, key):
"""Details needed to construct a reference for one key"""
field, chunk = key.split("/")
chunk_sizes = self._get_chunk_sizes(field)
if len(chunk_sizes) == 0:
return 0, 0, 0
chunk_idx = [int(c) for c in chunk.split(".")]
chunk_number = ravel_multi_index(chunk_idx, chunk_sizes)
record = chunk_number // self.record_size
ri = chunk_number % self.record_size
return record, ri, len(chunk_sizes)
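    # Illustrative mapping (chunk grid and record_size are hypothetical): with
    # a chunk grid of [10, 10] and record_size=1000, key "var/3.4" ravels to
    # chunk number 34 and so lives in record 0 at row 34; with record_size=20
    # it would land in record 1, row 14.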
def _get_chunk_sizes(self, field):
"""The number of chunks along each axis for a given field"""
if field not in self.chunk_sizes:
zarray = self.zmetadata[f"{field}/.zarray"]
size_ratio = [
math.ceil(s / c) for s, c in zip(zarray["shape"], zarray["chunks"])
]
self.chunk_sizes[field] = size_ratio
return self.chunk_sizes[field]
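    # For example (array metadata hypothetical): a field with shape [100, 60]
    # and chunks [10, 25] has ceil(100 / 10) == 10 and ceil(60 / 25) == 3
    # chunks per axis, so its chunk grid is [10, 3].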
def _generate_record(self, field, record):
"""The references for a given parquet file of a given field"""
refs = self.open_refs(field, record)
        it = iter(zip(*refs.values()))
if len(refs) == 3:
# All urls
return (list(t) for t in it)
elif len(refs) == 1:
# All raws
return refs["raw"]
else:
# Mix of urls and raws
return (list(t[:3]) if not t[3] else t[3] for t in it)
def _generate_all_records(self, field):
"""Load all the references within a field by iterating over the parquet files"""
nrec = 1
for ch in self._get_chunk_sizes(field):
nrec *= ch
nrec = math.ceil(nrec / self.record_size)
for record in range(nrec):
yield from self._generate_record(field, record)
def values(self):
return RefsValuesView(self)
def items(self):
return RefsItemsView(self)
def __hash__(self):
return id(self)
@lru_cache(20)
def __getitem__(self, key):
return self._load_one_key(key)
def __setitem__(self, key, value):
if "/" in key and not self._is_meta(key):
field, chunk = key.split("/")
record, i, _ = self._key_to_record(key)
subdict = self._items.setdefault((field, record), {})
subdict[i] = value
if len(subdict) == self.record_size:
self.write(field, record)
else:
# metadata or top-level
self._items[key] = value
self.zmetadata[key] = json.loads(
value.decode() if isinstance(value, bytes) else value
)
@staticmethod
def _is_meta(key):
return key.startswith(".z") or "/.z" in key
def __delitem__(self, key):
if key in self._items:
del self._items[key]
elif key in self.zmetadata:
del self.zmetadata[key]
else:
if "/" in key and not self._is_meta(key):
field, chunk = key.split("/")
record, _, _ = self._key_to_record(key)
subdict = self._items.setdefault((field, record), {})
subdict[chunk] = None
if len(subdict) == self.record_size:
self.write(field, record)
else:
# metadata or top-level
self._items[key] = None
def write(self, field, record, base_url=None, storage_options=None):
# extra requirements if writing
import kerchunk.df
import numpy as np
import pandas as pd
# TODO: if the dict is incomplete, also load records and merge in
partition = self._items[(field, record)]
fn = f"{base_url or self.out_root}/{field}/refs.{record}.parq"
####
paths = np.full(self.record_size, np.nan, dtype="O")
offsets = np.zeros(self.record_size, dtype="int64")
sizes = np.zeros(self.record_size, dtype="int64")
raws = np.full(self.record_size, np.nan, dtype="O")
nraw = 0
npath = 0
for j, data in partition.items():
if isinstance(data, list):
npath += 1
paths[j] = data[0]
if len(data) > 1:
offsets[j] = data[1]
sizes[j] = data[2]
else:
nraw += 1
raws[j] = kerchunk.df._proc_raw(data)
# TODO: only save needed columns
df = pd.DataFrame(
dict(
path=paths,
offset=offsets,
size=sizes,
raw=raws,
),
copy=False,
)
if df.path.count() / (df.path.nunique() or 1) > self.cat_thresh:
df["path"] = df["path"].astype("category")
object_encoding = dict(raw="bytes", path="utf8")
has_nulls = ["path", "raw"]
self.fs.mkdirs(f"{base_url or self.out_root}/{field}", exist_ok=True)
df.to_parquet(
fn,
engine="fastparquet",
storage_options=storage_options
or getattr(self.fs, "storage_options", None),
compression="zstd",
index=False,
stats=False,
object_encoding=object_encoding,
has_nulls=has_nulls,
# **kwargs,
)
partition.clear()
self._items.pop((field, record))
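    # Each record written above is one parquet file with columns
    # path/offset/size/raw; e.g. (values hypothetical) a URL reference row
    # looks like ("s3://bucket/x", 100, 50, <null>) and an inline chunk row
    # like (<null>, 0, 0, b"raw bytes").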
def flush(self, base_url=None, storage_options=None):
"""Output any modified or deleted keys
Parameters
----------
base_url: str
Location of the output
"""
# write what we have so far and clear sub chunks
for thing in list(self._items):
if isinstance(thing, tuple):
field, record = thing
self.write(
field,
record,
base_url=base_url,
storage_options=storage_options,
)
# gather .zmetadata from self._items and write that too
for k in list(self._items):
if k != ".zmetadata" and ".z" in k:
self.zmetadata[k] = json.loads(self._items.pop(k))
met = {"metadata": self.zmetadata, "record_size": self.record_size}
self._items[".zmetadata"] = json.dumps(met).encode()
self.fs.pipe(
"/".join([base_url or self.out_root, ".zmetadata"]),
self._items[".zmetadata"],
)
# TODO: only clear those that we wrote to?
self.open_refs.cache_clear()
def __len__(self):
# Caveat: This counts expected references, not actual
count = 0
for field in self.listdir():
if field.startswith("."):
count += 1
else:
chunk_sizes = self._get_chunk_sizes(field)
                nchunks = self.np.prod(chunk_sizes)
count += nchunks
count += len(self.zmetadata) # all metadata keys
count += len(self._items) # the metadata file itself
return count
def __iter__(self):
# Caveat: Note that this generates all expected keys, but does not
# account for reference keys that are missing.
metas = set(self.zmetadata)
metas.update(self._items)
for bit in metas:
if isinstance(bit, str):
yield bit
for field in self.listdir():
yield from self._keys_in_field(field)
def __contains__(self, item):
try:
self._load_one_key(item)
return True
except KeyError:
return False
def _keys_in_field(self, field):
"""List key names in given field
        Produces strings like "field/x.y" appropriate to the chunking of the array
"""
chunk_sizes = self._get_chunk_sizes(field)
if len(chunk_sizes) == 0:
yield field + "/0"
return
inds = itertools.product(*(range(i) for i in chunk_sizes))
for ind in inds:
yield field + "/" + ".".join([str(c) for c in ind])
class ReferenceFileSystem(AsyncFileSystem):
"""View byte ranges of some other file as a file system
Initial version: single file system target, which must support
async, and must allow start and end args in _cat_file. Later versions
may allow multiple arbitrary URLs for the targets.
This FileSystem is read-only. It is designed to be used with async
targets (for now). This FileSystem only allows whole-file access, no
``open``. We do not get original file details from the target FS.
Configuration is by passing a dict of references at init, or a URL to
a JSON file containing the same; this dict
can also contain concrete data for some set of paths.
Reference dict format:
{path0: bytes_data, path1: (target_url, offset, size)}
https://github.com/fsspec/kerchunk/blob/main/README.md
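
    Examples
    --------
    A minimal in-memory sketch; the reference names, target path and byte
    range below are made up for illustration:

    >>> import fsspec
    >>> fsspec.filesystem("memory").pipe_file("/data.bin", b"0123456789")
    >>> refs = {"a": b"inline bytes", "b": ["memory://data.bin", 2, 4]}
    >>> fs = fsspec.filesystem("reference", fo=refs)
    >>> fs.cat_file("a")
    b'inline bytes'
    >>> fs.cat_file("b")
    b'2345'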
"""
protocol = "reference"
def __init__(
self,
fo,
target=None,
ref_storage_args=None,
target_protocol=None,
target_options=None,
remote_protocol=None,
remote_options=None,
fs=None,
template_overrides=None,
simple_templates=True,
max_gap=64_000,
max_block=256_000_000,
cache_size=128,
**kwargs,
):
"""
Parameters
----------
fo : dict or str
The set of references to use for this instance, with a structure as above.
If str referencing a JSON file, will use fsspec.open, in conjunction
with target_options and target_protocol to open and parse JSON at this
location. If a directory, then assume references are a set of parquet
files to be loaded lazily.
target : str
For any references having target_url as None, this is the default file
target to use
ref_storage_args : dict
If references is a str, use these kwargs for loading the JSON file.
Deprecated: use target_options instead.
target_protocol : str
Used for loading the reference file, if it is a path. If None, protocol
will be derived from the given path
target_options : dict
Extra FS options for loading the reference file ``fo``, if given as a path
remote_protocol : str
The protocol of the filesystem on which the references will be evaluated
(unless fs is provided). If not given, will be derived from the first
URL that has a protocol in the templates or in the references, in that
order.
remote_options : dict
kwargs to go with remote_protocol
fs : AbstractFileSystem | dict(str, (AbstractFileSystem | dict))
Directly provide a file system(s):
- a single filesystem instance
- a dict of protocol:filesystem, where each value is either a filesystem
              instance, or a dict of kwargs that can be used to create an
instance for the given protocol
If this is given, remote_options and remote_protocol are ignored.
template_overrides : dict
Swap out any templates in the references file with these - useful for
testing.
simple_templates: bool
Whether templates can be processed with simple replace (True) or if
jinja is needed (False, much slower). All reference sets produced by
``kerchunk`` are simple in this sense, but the spec allows for complex.
max_gap, max_block: int
For merging multiple concurrent requests to the same remote file.
Neighboring byte ranges will only be merged when their
inter-range gap is <= ``max_gap``. Default is 64KB. Set to 0
to only merge when it requires no extra bytes. Pass a negative
number to disable merging, appropriate for local target files.
Neighboring byte ranges will only be merged when the size of
the aggregated range is <= ``max_block``. Default is 256MB.
cache_size : int
Maximum size of LRU cache, where cache_size*record_size denotes
the total number of references that can be loaded in memory at once.
Only used for lazily loaded references.
kwargs : passed to parent class
"""
super().__init__(**kwargs)
self.target = target
self.template_overrides = template_overrides
self.simple_templates = simple_templates
self.templates = {}
self.fss = {}
self._dircache = {}
self.max_gap = max_gap
self.max_block = max_block
if isinstance(fo, str):
dic = dict(
**(ref_storage_args or target_options or {}), protocol=target_protocol
)
ref_fs, fo2 = fsspec.core.url_to_fs(fo, **dic)
if ref_fs.isfile(fo):
# text JSON
with fsspec.open(fo, "rb", **dic) as f:
logger.info("Read reference from URL %s", fo)
text = json.load(f)
self._process_references(text, template_overrides)
else:
# Lazy parquet refs
logger.info("Open lazy reference dict from URL %s", fo)
self.references = LazyReferenceMapper(
fo2,
fs=ref_fs,
cache_size=cache_size,
)
else:
# dictionaries
self._process_references(fo, template_overrides)
if isinstance(fs, dict):
self.fss = {
k: (
fsspec.filesystem(k.split(":", 1)[0], **opts)
if isinstance(opts, dict)
else opts
)
for k, opts in fs.items()
}
if None not in self.fss:
self.fss[None] = filesystem("file")
return
if fs is not None:
# single remote FS
remote_protocol = (
fs.protocol[0] if isinstance(fs.protocol, tuple) else fs.protocol
)
self.fss[remote_protocol] = fs
if remote_protocol is None:
# get single protocol from any templates
for ref in self.templates.values():
if callable(ref):
ref = ref()
protocol, _ = fsspec.core.split_protocol(ref)
if protocol and protocol not in self.fss:
fs = filesystem(protocol, **(remote_options or {}))
self.fss[protocol] = fs
if remote_protocol is None:
# get single protocol from references
for ref in self.references.values():
if callable(ref):
ref = ref()
if isinstance(ref, list) and ref[0]:
protocol, _ = fsspec.core.split_protocol(ref[0])
if protocol not in self.fss:
fs = filesystem(protocol, **(remote_options or {}))
self.fss[protocol] = fs
# only use first remote URL
break
if remote_protocol and remote_protocol not in self.fss:
fs = filesystem(remote_protocol, **(remote_options or {}))
self.fss[remote_protocol] = fs
self.fss[None] = fs or filesystem("file") # default one
def _cat_common(self, path, start=None, end=None):
path = self._strip_protocol(path)
logger.debug(f"cat: {path}")
try:
part = self.references[path]
except KeyError:
raise FileNotFoundError(path)
if isinstance(part, str):
part = part.encode()
if isinstance(part, bytes):
logger.debug(f"Reference: {path}, type bytes")
if part.startswith(b"base64:"):
part = base64.b64decode(part[7:])
return part, None, None
if len(part) == 1:
logger.debug(f"Reference: {path}, whole file => {part}")
url = part[0]
start1, end1 = start, end
else:
url, start0, size = part
logger.debug(f"Reference: {path} => {url}, offset {start0}, size {size}")
end0 = start0 + size
if start is not None:
if start >= 0:
start1 = start0 + start
else:
start1 = end0 + start
else:
start1 = start0
if end is not None:
if end >= 0:
end1 = start0 + end
else:
end1 = end0 + end
else:
end1 = end0
if url is None:
url = self.target
return url, start1, end1
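    # Worked example (reference and requested range are hypothetical): for a
    # reference ["s3://bucket/x", 100, 50] the stored range is bytes [100, 150);
    # a caller asking for start=10, end=20 therefore gets ("s3://bucket/x", 110, 120),
    # and start=-5 with no end gets ("s3://bucket/x", 145, 150).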
async def _cat_file(self, path, start=None, end=None, **kwargs):
part_or_url, start0, end0 = self._cat_common(path, start=start, end=end)
if isinstance(part_or_url, bytes):
return part_or_url[start:end]
protocol, _ = split_protocol(part_or_url)
try:
            return await self.fss[protocol]._cat_file(
                part_or_url, start=start0, end=end0
            )
except Exception as e:
raise ReferenceNotReachable(path, part_or_url) from e
def cat_file(self, path, start=None, end=None, **kwargs):
part_or_url, start0, end0 = self._cat_common(path, start=start, end=end)
if isinstance(part_or_url, bytes):
return part_or_url[start:end]
protocol, _ = split_protocol(part_or_url)
try:
return self.fss[protocol].cat_file(part_or_url, start=start0, end=end0)
except Exception as e:
raise ReferenceNotReachable(path, part_or_url) from e
def pipe_file(self, path, value, **_):
"""Temporarily add binary data or reference as a file"""
self.references[path] = value
async def _get_file(self, rpath, lpath, **kwargs):
if self.isdir(rpath):
return os.makedirs(lpath, exist_ok=True)
data = await self._cat_file(rpath)
with open(lpath, "wb") as f:
f.write(data)
def get_file(self, rpath, lpath, callback=_DEFAULT_CALLBACK, **kwargs):
if self.isdir(rpath):
return os.makedirs(lpath, exist_ok=True)
data = self.cat_file(rpath, **kwargs)
callback.set_size(len(data))
if isfilelike(lpath):
lpath.write(data)
else:
with open(lpath, "wb") as f:
f.write(data)
callback.absolute_update(len(data))
def get(self, rpath, lpath, recursive=False, **kwargs):
if recursive:
# trigger directory build
self.ls("")
rpath = self.expand_path(rpath, recursive=recursive)
fs = fsspec.filesystem("file", auto_mkdir=True)
targets = other_paths(rpath, lpath)
if recursive:
data = self.cat([r for r in rpath if not self.isdir(r)])
else:
data = self.cat(rpath)
for remote, local in zip(rpath, targets):
if remote in data:
fs.pipe_file(local, data[remote])
def cat(self, path, recursive=False, on_error="raise", **kwargs):
if isinstance(path, str) and recursive:
raise NotImplementedError
if isinstance(path, list) and (recursive or any("*" in p for p in path)):
raise NotImplementedError
proto_dict = _protocol_groups(path, self.references)
out = {}
for proto, paths in proto_dict.items():
fs = self.fss[proto]
urls, starts, ends = [], [], []
for p in paths:
# find references or label not-found. Early exit if any not
# found and on_error is "raise"
try:
u, s, e = self._cat_common(p)
urls.append(u)
starts.append(s)
ends.append(e)
except FileNotFoundError as err:
if on_error == "raise":
raise
if on_error != "omit":
out[p] = err
# process references into form for merging
urls2 = []
starts2 = []
ends2 = []
paths2 = []
whole_files = set()
for u, s, e, p in zip(urls, starts, ends, paths):
if isinstance(u, bytes):
# data
out[p] = u
elif s is None:
# whole file - limits are None, None, but no further
# entries take for this file
whole_files.add(u)
urls2.append(u)
starts2.append(s)
ends2.append(e)
paths2.append(p)
for u, s, e, p in zip(urls, starts, ends, paths):
# second run to account for files that are to be loaded whole
if s is not None and u not in whole_files:
urls2.append(u)
starts2.append(s)
ends2.append(e)
paths2.append(p)
# merge and fetch consolidated ranges
new_paths, new_starts, new_ends = merge_offset_ranges(
list(urls2),
list(starts2),
list(ends2),
sort=True,
max_gap=self.max_gap,
max_block=self.max_block,
)
bytes_out = fs.cat_ranges(new_paths, new_starts, new_ends)
# unbundle from merged bytes - simple approach
for u, s, e, p in zip(urls, starts, ends, paths):
if p in out:
continue # was bytes, already handled
for np, ns, ne, b in zip(new_paths, new_starts, new_ends, bytes_out):
if np == u and (ns is None or ne is None):
if isinstance(b, Exception):
out[p] = b
else:
out[p] = b[s:e]
elif np == u and s >= ns and e <= ne:
if isinstance(b, Exception):
out[p] = b
else:
out[p] = b[s - ns : (e - ne) or None]
for k, v in out.copy().items():
# these were valid references, but fetch failed, so transform exc
if isinstance(v, Exception) and k in self.references:
ex = out[k]
new_ex = ReferenceNotReachable(k, self.references[k])
new_ex.__cause__ = ex
if on_error == "raise":
raise new_ex
elif on_error != "omit":
out[k] = new_ex
if len(out) == 1 and isinstance(path, str) and "*" not in path:
return _first(out)
return out
def _process_references(self, references, template_overrides=None):
vers = references.get("version", None)
if vers is None:
self._process_references0(references)
elif vers == 1:
self._process_references1(references, template_overrides=template_overrides)
else:
raise ValueError(f"Unknown reference spec version: {vers}")
# TODO: we make dircache by iterating over all entries, but for Spec >= 1,
# can replace with programmatic. Is it even needed for mapper interface?
def _process_references0(self, references):
"""Make reference dict for Spec Version 0"""
self.references = references
def _process_references1(self, references, template_overrides=None):
if not self.simple_templates or self.templates:
import jinja2
self.references = {}
self._process_templates(references.get("templates", {}))
@lru_cache(1000)
def _render_jinja(u):
return jinja2.Template(u).render(**self.templates)
for k, v in references.get("refs", {}).items():
if isinstance(v, str):
if v.startswith("base64:"):
self.references[k] = base64.b64decode(v[7:])
self.references[k] = v
elif self.templates:
u = v[0]
if "{{" in u:
if self.simple_templates:
u = (
u.replace("{{", "{")
.replace("}}", "}")
.format(**self.templates)
)
else:
u = _render_jinja(u)
self.references[k] = [u] if len(v) == 1 else [u, v[1], v[2]]
else:
self.references[k] = v
self.references.update(self._process_gen(references.get("gen", [])))
def _process_templates(self, tmp):
self.templates = {}
if self.template_overrides is not None:
tmp.update(self.template_overrides)
for k, v in tmp.items():
if "{{" in v:
import jinja2
self.templates[k] = lambda temp=v, **kwargs: jinja2.Template(
temp
).render(**kwargs)
else:
self.templates[k] = v
def _process_gen(self, gens):
out = {}
for gen in gens:
dimension = {
k: v
if isinstance(v, list)
else range(v.get("start", 0), v["stop"], v.get("step", 1))
for k, v in gen["dimensions"].items()
}
products = (
dict(zip(dimension.keys(), values))
for values in itertools.product(*dimension.values())
)
for pr in products:
import jinja2
key = jinja2.Template(gen["key"]).render(**pr, **self.templates)
url = jinja2.Template(gen["url"]).render(**pr, **self.templates)
if ("offset" in gen) and ("length" in gen):
offset = int(
jinja2.Template(gen["offset"]).render(**pr, **self.templates)
)
length = int(
jinja2.Template(gen["length"]).render(**pr, **self.templates)
)
out[key] = [url, offset, length]
elif ("offset" in gen) ^ ("length" in gen):
raise ValueError(
"Both 'offset' and 'length' are required for a "
"reference generator entry if either is provided."
)
else:
out[key] = [url]
return out
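    # Sketch of one generator entry from a version-1 spec (all values are
    # hypothetical): with self.templates = {"u": "https://host"} and
    #   gen = {"key": "data/{{i}}", "url": "{{u}}/file_{{i}}.bin",
    #          "offset": "{{ i * 100 }}", "length": "100",
    #          "dimensions": {"i": {"stop": 2}}},
    # _process_gen([gen]) returns
    #   {"data/0": ["https://host/file_0.bin", 0, 100],
    #    "data/1": ["https://host/file_1.bin", 100, 100]}.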
def _dircache_from_items(self):
self.dircache = {"": []}
it = self.references.items()
for path, part in it:
if isinstance(part, (bytes, str)):
size = len(part)
elif len(part) == 1:
size = None
else:
_, start, size = part
par = path.rsplit("/", 1)[0] if "/" in path else ""
par0 = par
while par0 and par0 not in self.dircache:
# build parent directories
self.dircache[par0] = []
self.dircache.setdefault(
par0.rsplit("/", 1)[0] if "/" in par0 else "", []
).append({"name": par0, "type": "directory", "size": 0})
par0 = self._parent(par0)
self.dircache[par].append({"name": path, "type": "file", "size": size})
def _open(self, path, mode="rb", block_size=None, cache_options=None, **kwargs):
data = self.cat_file(path) # load whole chunk into memory
return io.BytesIO(data)
def ls(self, path, detail=True, **kwargs):
path = self._strip_protocol(path)
if isinstance(self.references, LazyReferenceMapper):
try:
return self.references.ls(path, detail)
except KeyError:
pass
raise FileNotFoundError(f"'{path}' is not a known key")
if not self.dircache:
self._dircache_from_items()
out = self._ls_from_cache(path)
if out is None:
raise FileNotFoundError(path)
if detail:
return out
return [o["name"] for o in out]
def exists(self, path, **kwargs): # overwrite auto-sync version
return self.isdir(path) or self.isfile(path)
def isdir(self, path): # overwrite auto-sync version
if self.dircache:
return path in self.dircache
elif isinstance(self.references, LazyReferenceMapper):
return path in self.references.listdir("")
else:
# this may be faster than building dircache for single calls, but
# by looping will be slow for many calls; could cache it?
return any(_.startswith(f"{path}/") for _ in self.references)
def isfile(self, path): # overwrite auto-sync version
return path in self.references
async def _ls(self, path, detail=True, **kwargs): # calls fast sync code
return self.ls(path, detail, **kwargs)
def find(self, path, maxdepth=None, withdirs=False, detail=False, **kwargs):
if withdirs:
return super().find(
path, maxdepth=maxdepth, withdirs=withdirs, detail=detail, **kwargs
)
if path:
path = self._strip_protocol(path)
r = sorted(k for k in self.references if k.startswith(path))
else:
r = sorted(self.references)
if detail:
if not self.dircache:
self._dircache_from_items()
return {k: self._ls_from_cache(k)[0] for k in r}
else:
return r
def info(self, path, **kwargs):
out = self.references.get(path)
if out is not None:
if isinstance(out, (str, bytes)):
# decode base64 here
return {"name": path, "type": "file", "size": len(out)}
elif len(out) > 1:
return {"name": path, "type": "file", "size": out[2]}
else:
out0 = [{"name": path, "type": "file", "size": None}]
else:
out = self.ls(path, True)
out0 = [o for o in out if o["name"] == path]
if not out0:
return {"name": path, "type": "directory", "size": 0}
if out0[0]["size"] is None:
# if this is a whole remote file, update size using remote FS
prot, _ = split_protocol(self.references[path][0])
out0[0]["size"] = self.fss[prot].size(self.references[path][0])
return out0[0]
async def _info(self, path, **kwargs): # calls fast sync code
return self.info(path)
async def _rm_file(self, path, **kwargs):
self.references.pop(
path, None
) # ignores FileNotFound, just as well for directories
self.dircache.clear() # this is a bit heavy handed
async def _pipe_file(self, path, data):
# can be str or bytes
self.references[path] = data
self.dircache.clear() # this is a bit heavy handed
async def _put_file(self, lpath, rpath):
# puts binary
with open(lpath, "rb") as f:
self.references[rpath] = f.read()
self.dircache.clear() # this is a bit heavy handed
def save_json(self, url, **storage_options):
"""Write modified references into new location"""
out = {}
for k, v in self.references.items():
if isinstance(v, bytes):
try:
out[k] = v.decode("ascii")
except UnicodeDecodeError:
out[k] = (b"base64:" + base64.b64encode(v)).decode()
else:
out[k] = v
with fsspec.open(url, "wb", **storage_options) as f:
f.write(json.dumps({"version": 1, "refs": out}).encode())