# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

import json
import os
import pickle
from contextlib import contextmanager
from typing import IO, Any, BinaryIO, Callable, Dict, Iterator, TextIO, Union

import numpy as np
import torch
import yaml
__all__ = [
    "load",
    "save",
    "load_json",
    "save_json",
    "load_jsonl",
    "save_jsonl",
    "load_mat",
    "save_mat",
    "load_npy",
    "save_npy",
    "load_npz",
    "save_npz",
    "load_pkl",
    "save_pkl",
    "load_pt",
    "save_pt",
    "load_txt",
    "save_txt",
    "load_yaml",
    "save_yaml",
]
@contextmanager
def file_descriptor(f: Union[str, IO], mode: str = "r") -> Iterator[IO]:
    # Accept either a path or an already-open file object; only close what we opened.
    opened = False
    try:
        if isinstance(f, str):
            f = open(f, mode)
            opened = True
        yield f
    finally:
        if opened:
            f.close()
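
# Usage sketch (illustrative, not part of the original module): file_descriptor
# works with both a path and an already-open handle, and only closes handles it
# opened itself.
#
#     with file_descriptor("config.json", mode="r") as fd:  # path: opened and closed here
#         data = json.load(fd)
#
#     with open("config.json") as fh:
#         with file_descriptor(fh) as fd:  # existing handle: left open for the caller
#             data = json.load(fd)
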
def load_json(f: Union[str, TextIO], **kwargs) -> Any:
    with file_descriptor(f, mode="r") as fd:
        return json.load(fd, **kwargs)


def save_json(f: Union[str, TextIO], obj: Any, **kwargs) -> None:
    with file_descriptor(f, mode="w") as fd:
        json.dump(obj, fd, **kwargs)


def load_jsonl(f: Union[str, TextIO], **kwargs) -> Any:
    with file_descriptor(f, mode="r") as fd:
        return [json.loads(line, **kwargs) for line in fd]


def save_jsonl(f: Union[str, TextIO], obj: Any, **kwargs) -> None:
    with file_descriptor(f, mode="w") as fd:
        fd.write("\n".join(json.dumps(datum, **kwargs) for datum in obj))
def load_mat(f: Union[str, BinaryIO], **kwargs) -> Any:
    # scipy is imported lazily so it stays an optional dependency.
    import scipy.io

    return scipy.io.loadmat(f, **kwargs)


def save_mat(f: Union[str, BinaryIO], obj: Any, **kwargs) -> None:
    import scipy.io

    scipy.io.savemat(f, obj, **kwargs)


def load_npy(f: Union[str, BinaryIO], **kwargs) -> Any:
    return np.load(f, **kwargs)


def save_npy(f: Union[str, BinaryIO], obj: Any, **kwargs) -> None:
    np.save(f, obj, **kwargs)


def load_npz(f: Union[str, BinaryIO], **kwargs) -> Any:
    return np.load(f, **kwargs)


def save_npz(f: Union[str, BinaryIO], obj: Any, **kwargs) -> None:
    # obj is stored as a single unnamed array ("arr_0"); keyword arguments are
    # forwarded to np.savez and saved as named arrays.
    np.savez(f, obj, **kwargs)
def load_pkl(f: Union[str, BinaryIO], **kwargs) -> Any:
    with file_descriptor(f, mode="rb") as fd:
        try:
            return pickle.load(fd, **kwargs)
        except UnicodeDecodeError:
            if "encoding" in kwargs:
                raise
            # Fall back to latin1, which handles pickles written by Python 2.
            fd.seek(0)
            return pickle.load(fd, encoding="latin1", **kwargs)


def save_pkl(f: Union[str, BinaryIO], obj: Any, **kwargs) -> None:
    with file_descriptor(f, mode="wb") as fd:
        pickle.dump(obj, fd, **kwargs)
def load_pt(f: Union[str, BinaryIO], **kwargs) -> Any:
    return torch.load(f, **kwargs)


def save_pt(f: Union[str, BinaryIO], obj: Any, **kwargs) -> None:
    torch.save(obj, f, **kwargs)


def load_yaml(f: Union[str, TextIO]) -> Any:
    with file_descriptor(f, mode="r") as fd:
        return yaml.safe_load(fd)


def save_yaml(f: Union[str, TextIO], obj: Any, **kwargs) -> None:
    with file_descriptor(f, mode="w") as fd:
        yaml.safe_dump(obj, fd, **kwargs)


def load_txt(f: Union[str, TextIO]) -> Any:
    with file_descriptor(f, mode="r") as fd:
        return fd.read()


def save_txt(f: Union[str, TextIO], obj: Any, **kwargs) -> None:
    with file_descriptor(f, mode="w") as fd:
        fd.write(obj)
__io_registry: Dict[str, Dict[str, Callable]] = {
    ".txt": {"load": load_txt, "save": save_txt},
    ".json": {"load": load_json, "save": save_json},
    ".jsonl": {"load": load_jsonl, "save": save_jsonl},
    ".mat": {"load": load_mat, "save": save_mat},
    ".npy": {"load": load_npy, "save": save_npy},
    ".npz": {"load": load_npz, "save": save_npz},
    ".pkl": {"load": load_pkl, "save": save_pkl},
    ".pt": {"load": load_pt, "save": save_pt},
    ".pth": {"load": load_pt, "save": save_pt},
    ".pth.tar": {"load": load_pt, "save": save_pt},
    ".yaml": {"load": load_yaml, "save": save_yaml},
    ".yml": {"load": load_yaml, "save": save_yaml},
}
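
# Extending the registry (sketch): map a new extension to its own load/save pair.
# The ".csv" handler below is hypothetical and not part of this module; registering
# only a loader is enough, since load()/save() check which callables are present.
#
#     import csv
#
#     def load_csv(f, **kwargs):
#         with file_descriptor(f, mode="r") as fd:
#             return list(csv.reader(fd, **kwargs))
#
#     __io_registry[".csv"] = {"load": load_csv}
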
def load(fpath: str, **kwargs) -> Any:
    assert isinstance(fpath, str), type(fpath)
    # Try longer extensions first so compound suffixes such as ".pth.tar" take
    # precedence over shorter ones.
    for extension in sorted(__io_registry.keys(), key=len, reverse=True):
        if fpath.endswith(extension) and "load" in __io_registry[extension]:
            return __io_registry[extension]["load"](fpath, **kwargs)
    raise NotImplementedError(f'"{fpath}" cannot be loaded.')


def save(fpath: str, obj: Any, **kwargs) -> None:
    assert isinstance(fpath, str), type(fpath)
    # Guard against a bare filename: os.makedirs("") raises even with exist_ok=True.
    dirname = os.path.dirname(fpath)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    for extension in sorted(__io_registry.keys(), key=len, reverse=True):
        if fpath.endswith(extension) and "save" in __io_registry[extension]:
            __io_registry[extension]["save"](fpath, obj, **kwargs)
            return
    raise NotImplementedError(f'"{fpath}" cannot be saved.')
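
# Minimal round-trip sketch (illustrative path, only run when executed directly):
# save() dispatches on the file extension and creates the parent directory if
# needed; load() reverses it.
if __name__ == "__main__":
    save("/tmp/io_demo/example.json", {"hello": "world"}, indent=2)
    assert load("/tmp/io_demo/example.json") == {"hello": "world"}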