Spaces:
Runtime error
Runtime error
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import fnmatch | |
| import logging | |
| import os | |
| import tarfile | |
| from typing import IO, Union | |
| LOGGER = logging.getLogger("NeMo") | |
| try: | |
| from zarr.storage import BaseStore | |
| HAVE_ZARR = True | |
| except Exception as e: | |
| LOGGER.warning(f"Cannot import zarr, support for zarr-based checkpoints is not available. {type(e).__name__}: {e}") | |
| BaseStore = object | |
| HAVE_ZARR = False | |
| class TarPath: | |
| """ | |
| A class that represents a path inside a TAR archive and behaves like pathlib.Path. | |
| Expected use is to create a TarPath for the root of the archive first, and then derive | |
| paths to other files or directories inside the archive like so: | |
| with TarPath('/path/to/archive.tar') as archive: | |
| myfile = archive / 'filename.txt' | |
| if myfile.exists(): | |
| data = myfile.read() | |
| ... | |
| Only read and enumeration operations are supported. | |
| """ | |
| def __init__(self, tar: Union[str, tarfile.TarFile, 'TarPath'], *parts): | |
| self._needs_to_close = False | |
| self._relpath = '' | |
| if isinstance(tar, TarPath): | |
| self._tar = tar._tar | |
| self._relpath = os.path.join(tar._relpath, *parts) | |
| elif isinstance(tar, tarfile.TarFile): | |
| self._tar = tar | |
| if parts: | |
| self._relpath = os.path.join(*parts) | |
| elif isinstance(tar, str): | |
| self._needs_to_close = True | |
| self._tar = tarfile.open(tar, 'r') | |
| if parts: | |
| self._relpath = os.path.join(*parts) | |
| else: | |
| raise ValueError(f"Unexpected argument type for TarPath: {type(tar).__name__}") | |
| def __del__(self): | |
| if self._needs_to_close: | |
| self._tar.close() | |
| def __truediv__(self, key) -> 'TarPath': | |
| return TarPath(self._tar, os.path.join(self._relpath, key)) | |
| def __str__(self) -> str: | |
| return os.path.join(self._tar.name, self._relpath) | |
| def tarobject(self): | |
| """ | |
| Returns the wrapped tar object. | |
| """ | |
| return self._tar | |
| def relpath(self): | |
| """ | |
| Returns the relative path of the path. | |
| """ | |
| return self._relpath | |
| def name(self): | |
| """ | |
| Returns the name of the path. | |
| """ | |
| return os.path.split(self._relpath)[1] | |
| def suffix(self): | |
| """ | |
| Returns the suffix of the path. | |
| """ | |
| name = self.name | |
| i = name.rfind('.') | |
| if 0 < i < len(name) - 1: | |
| return name[i:] | |
| else: | |
| return '' | |
| def __enter__(self): | |
| self._tar.__enter__() | |
| return self | |
| def __exit__(self, *args): | |
| return self._tar.__exit__(*args) | |
| def exists(self): | |
| """ | |
| Checks if the path exists. | |
| """ | |
| try: | |
| self._tar.getmember(self._relpath) | |
| return True | |
| except KeyError: | |
| try: | |
| self._tar.getmember(os.path.join('.', self._relpath)) | |
| return True | |
| except KeyError: | |
| return False | |
| def is_file(self): | |
| """ | |
| Checks if the path is a file. | |
| """ | |
| try: | |
| self._tar.getmember(self._relpath).isreg() | |
| return True | |
| except KeyError: | |
| try: | |
| self._tar.getmember(os.path.join('.', self._relpath)).isreg() | |
| return True | |
| except KeyError: | |
| return False | |
| def is_dir(self): | |
| """ | |
| Checks if the path is a directory. | |
| """ | |
| try: | |
| self._tar.getmember(self._relpath).isdir() | |
| return True | |
| except KeyError: | |
| try: | |
| self._tar.getmember(os.path.join('.', self._relpath)).isdir() | |
| return True | |
| except KeyError: | |
| return False | |
| def open(self, mode: str) -> IO[bytes]: | |
| """ | |
| Opens a file in the archive. | |
| """ | |
| if mode != 'r' and mode != 'rb': | |
| raise NotImplementedError() | |
| file = None | |
| try: | |
| # Try the relative path as-is first | |
| file = self._tar.extractfile(self._relpath) | |
| except KeyError: | |
| try: | |
| # Try the relative path with "./" prefix | |
| file = self._tar.extractfile(os.path.join('.', self._relpath)) | |
| except KeyError: | |
| raise FileNotFoundError() | |
| if file is None: | |
| raise FileNotFoundError() | |
| return file | |
| def glob(self, pattern): | |
| """ | |
| Returns an iterator over the files in the directory, matching the pattern. | |
| """ | |
| for member in self._tar.getmembers(): | |
| # Remove the "./" prefix, if any | |
| name = member.name[2:] if member.name.startswith('./') else member.name | |
| # If we're in a subdirectory, make sure the file is too, and remove that subdir component | |
| if self._relpath: | |
| if not name.startswith(self._relpath + '/'): | |
| continue | |
| name = name[len(self._relpath) + 1 :] | |
| # See if the name matches the pattern | |
| if fnmatch.fnmatch(name, pattern): | |
| yield TarPath(self._tar, os.path.join(self._relpath, name)) | |
| def rglob(self, pattern): | |
| """ | |
| Returns an iterator over the files in the directory, including subdirectories. | |
| """ | |
| for member in self._tar.getmembers(): | |
| # Remove the "./" prefix, if any | |
| name = member.name[2:] if member.name.startswith('./') else member.name | |
| # If we're in a subdirectory, make sure the file is too, and remove that subdir component | |
| if self._relpath: | |
| if not name.startswith(self._relpath + '/'): | |
| continue | |
| name = name[len(self._relpath) + 1 :] | |
| # See if any tail of the path matches the pattern, return full path if that's true | |
| parts = name.split('/') | |
| for i in range(len(parts)): | |
| subname = '/'.join(parts[i:]) | |
| if fnmatch.fnmatch(subname, pattern): | |
| yield TarPath(self._tar, os.path.join(self._relpath, name)) | |
| break | |
| def iterdir(self): | |
| """ | |
| Returns an iterator over the files in the directory. | |
| """ | |
| return self.glob('*') | |
| class ZarrPathStore(BaseStore): | |
| """ | |
| An implementation of read-only Store for zarr library | |
| that works with pathlib.Path or TarPath objects. | |
| """ | |
| def __init__(self, tarpath: TarPath): | |
| assert HAVE_ZARR, "Package zarr>=2.18.2,<3.0.0 is required to use ZarrPathStore" | |
| self._path = tarpath | |
| self._writable = False | |
| self._erasable = False | |
| def __getitem__(self, key): | |
| with (self._path / key).open('rb') as file: | |
| return file.read() | |
| def __contains__(self, key): | |
| return (self._path / key).is_file() | |
| def __iter__(self): | |
| return self.keys() | |
| def __len__(self): | |
| return sum(1 for _ in self.keys()) | |
| def __setitem__(self, key, value): | |
| raise NotImplementedError() | |
| def __delitem__(self, key): | |
| raise NotImplementedError() | |
| def keys(self): | |
| """ | |
| Returns an iterator over the keys in the store. | |
| """ | |
| return self._path.iterdir() | |