# NOTE: scraped page banner ("Spaces: Runtime error"); not part of the module source.
import importlib | |
from abc import ABC, abstractmethod | |
from pickle import ( # type: ignore[attr-defined] # type: ignore[attr-defined] | |
_getattribute, | |
_Pickler, | |
whichmodule as _pickle_whichmodule, | |
) | |
from types import ModuleType | |
from typing import Any, Dict, List, Optional, Tuple | |
from ._mangling import demangle, get_mangle_prefix, is_mangled | |
__all__ = ["ObjNotFoundError", "ObjMismatchError", "Importer", "OrderedImporter"] | |
class ObjNotFoundError(Exception):
    """Signals that an importer could not locate an object under the name it searched for."""
class ObjMismatchError(Exception):
    """Signals that an importer resolved the name to a different object than the one the user supplied."""
class Importer(ABC):
    """Represents an environment to import modules from.

    By default, you can figure out what module an object belongs to by checking
    __module__ and importing the result using __import__ or importlib.import_module.

    torch.package introduces module importers other than the default one.
    Each PackageImporter introduces a new namespace. Potentially a single
    name (e.g. 'foo.bar') is present in multiple namespaces.

    It supports two main operations:
        import_module: module_name -> module object
        get_name: object -> (parent module name, name of obj within module)

    The guarantee is that following round-trip will succeed or throw an
    ObjNotFoundError/ObjMismatchError.
        module_name, obj_name = env.get_name(obj)
        module = env.import_module(module_name)
        obj2 = getattr(module, obj_name)
        assert obj1 is obj2
    """

    # Mapping of module name -> module object. The base `whichmodule` scans it;
    # it is only annotated here, so subclasses relying on that scan must set it.
    modules: Dict[str, ModuleType]

    @abstractmethod
    def import_module(self, module_name: str) -> ModuleType:
        """Import `module_name` from this environment.

        The contract is the same as for importlib.import_module.

        Subclasses must override this; the base class provides no
        implementation.
        """

    def get_name(self, obj: Any, name: Optional[str] = None) -> Tuple[str, str]:
        """Given an object, return a name that can be used to retrieve the
        object from this environment.

        Args:
            obj: An object to get the module-environment-relative name for.
            name: If set, use this name instead of looking up __name__ or __qualname__ on `obj`.
                This is only here to match how Pickler handles __reduce__ functions that return a string,
                don't use otherwise.

        Returns:
            A tuple (parent_module_name, attr_name) that can be used to retrieve `obj` from this environment.
            Use it like:
                mod = importer.import_module(parent_module_name)
                obj = getattr(mod, attr_name)

        Raises:
            ObjNotFoundError: we couldn't retrieve `obj` by name.
            ObjMismatchError: we found a different object with the same name as `obj`.
        """
        # NOTE(review): `obj and` evaluates the object's truthiness, so falsy
        # objects skip this branch entirely -- confirm that is intended.
        if name is None and obj and _Pickler.dispatch.get(type(obj)) is None:
            # Honor the string return variant of __reduce__, which will give us
            # a global name to search for in this environment.
            # TODO: I guess we should do copyreg too?
            reduce = getattr(obj, "__reduce__", None)
            if reduce is not None:
                try:
                    rv = reduce()
                    if isinstance(rv, str):
                        name = rv
                except Exception:
                    # Best effort only: if __reduce__ fails we fall back to
                    # the __qualname__/__name__ lookup below.
                    pass

        if name is None:
            name = getattr(obj, "__qualname__", None)
        if name is None:
            name = obj.__name__

        orig_module_name = self.whichmodule(obj, name)
        # Demangle the module name before importing. If this obj came out of a
        # PackageImporter, `__module__` will be mangled. See mangling.md for
        # details.
        module_name = demangle(orig_module_name)

        # Check that this name will indeed return the correct object.
        # NOTE(review): `_getattribute` is a private pickle helper; this code
        # assumes it returns an (object, parent) pair -- confirm against the
        # Python versions being targeted.
        try:
            module = self.import_module(module_name)
            obj2, _ = _getattribute(module, name)
        except (ImportError, KeyError, AttributeError):
            raise ObjNotFoundError(
                f"{obj} was not found as {module_name}.{name}"
            ) from None

        if obj is obj2:
            return module_name, name

        def get_obj_info(obj):
            """Return (module name, human-readable location, importer description) for `obj`."""
            assert name is not None
            module_name = self.whichmodule(obj, name)
            if is_mangled(module_name):
                # A mangled module name means the object came out of a package;
                # its mangle prefix identifies which one.
                prefix = get_mangle_prefix(module_name)
                return module_name, prefix, f"the importer for {prefix}"
            return module_name, "the current Python environment", "'sys_importer'"

        # The name resolved to a different object: explain where each of the
        # two objects comes from and how to reorder the importers.
        obj_module_name, obj_location, obj_importer_name = get_obj_info(obj)
        obj2_module_name, obj2_location, obj2_importer_name = get_obj_info(obj2)
        msg = (
            f"\n\nThe object provided is from '{obj_module_name}', "
            f"which is coming from {obj_location}."
            f"\nHowever, when we import '{obj2_module_name}', it's coming from {obj2_location}."
            "\nTo fix this, make sure this 'PackageExporter's importer lists "
            f"{obj_importer_name} before {obj2_importer_name}."
        )
        raise ObjMismatchError(msg)

    def whichmodule(self, obj: Any, name: str) -> str:
        """Find the module name an object belongs to.

        This should be considered internal for end-users, but developers of
        an importer can override it to customize the behavior.

        Taken from pickle.py, but modified to exclude the search into sys.modules

        Args:
            obj: The object whose owning module is wanted.
            name: The (possibly dotted) attribute name to look `obj` up under.

        Returns:
            `obj.__module__` when it is set; otherwise the name of the first
            module in `self.modules` where `name` resolves to `obj`; otherwise
            "__main__".
        """
        module_name = getattr(obj, "__module__", None)
        if module_name is not None:
            return module_name

        # Protect the iteration by using a list copy of self.modules against dynamic
        # modules that trigger imports of other modules upon calls to getattr.
        for module_name, module in self.modules.copy().items():
            if (
                module_name == "__main__"
                or module_name == "__mp_main__"  # bpo-42406
                or module is None
            ):
                continue
            try:
                if _getattribute(module, name)[0] is obj:
                    return module_name
            except AttributeError:
                pass

        return "__main__"
class _SysImporter(Importer):
    """Importer that resolves modules and owners the way plain Python does."""

    def import_module(self, module_name: str):
        """Import `module_name` through `importlib.import_module`."""
        imported = importlib.import_module(module_name)
        return imported

    def whichmodule(self, obj: Any, name: str) -> str:
        """Return the owning module name as reported by `pickle.whichmodule`."""
        owner = _pickle_whichmodule(obj, name)
        return owner


# Shared module-level instance of the default importer.
sys_importer = _SysImporter()
class OrderedImporter(Importer):
    """Composite importer that consults a sequence of importers in order.

    Whichever importer in the sequence first produces a result "wins".
    """

    def __init__(self, *args):
        # Keep the importers in the order given; that order is the priority.
        self._importers: List[Importer] = [*args]

    def _is_torchpackage_dummy(self, module):
        """Return True iff `module` is an empty PackageNode in a torch.package.

        If you intern `a.b` but never use `a` in your code, then `a` will be an
        empty module with no source. This can break cases where we are trying to
        re-package an object after adding a real dependency on `a`, since
        OrderedImporter would resolve `a` to the dummy package and stop there.

        See: https://github.com/pytorch/pytorch/pull/71520#issuecomment-1029603769
        """
        # Only modules flagged as coming from a torch.package can be dummies.
        if not getattr(module, "__torch_package__", False):
            return False
        # A dummy is a package, i.e. it carries a `__path__`.
        if not hasattr(module, "__path__"):
            return False
        # A missing `__file__` and a `__file__` of None both count as "no source".
        return getattr(module, "__file__", None) is None

    def import_module(self, module_name: str) -> ModuleType:
        """Import `module_name` from the first importer that yields a real module.

        Importers that raise ModuleNotFoundError, or that return an empty
        torch.package dummy module, are skipped.

        Raises:
            TypeError: an entry in the importer list is not an `Importer`.
            ModuleNotFoundError: no importer produced a usable module; the
                most recent importer failure is re-raised when there was one.
        """
        most_recent_failure = None
        for candidate in self._importers:
            if not isinstance(candidate, Importer):
                raise TypeError(
                    f"{candidate} is not a Importer. "
                    "All importers in OrderedImporter must inherit from Importer."
                )
            try:
                found = candidate.import_module(module_name)
                if not self._is_torchpackage_dummy(found):
                    return found
            except ModuleNotFoundError as failure:
                most_recent_failure = failure

        if most_recent_failure is None:
            # Nothing failed outright, yet nothing usable was returned either.
            raise ModuleNotFoundError(module_name)
        raise most_recent_failure

    def whichmodule(self, obj: Any, name: str) -> str:
        """Return the first answer other than "__main__" given by the importers.

        Falls back to "__main__" when every importer reports "__main__" or
        the importer list is empty.
        """
        fallback = "__main__"
        for member in self._importers:
            owner = member.whichmodule(obj, name)
            if owner != fallback:
                return owner
        return fallback