|
import importlib |
|
from abc import ABC, abstractmethod |
|
from pickle import ( |
|
_getattribute, |
|
_Pickler, |
|
whichmodule as _pickle_whichmodule, |
|
) |
|
from types import ModuleType |
|
from typing import Any, Dict, List, Optional, Tuple |
|
|
|
from ._mangling import demangle, get_mangle_prefix, is_mangled |
|
|
|
# Names exported by `from <this module> import *`.
__all__ = [
    "ObjNotFoundError",
    "ObjMismatchError",
    "Importer",
    "OrderedImporter",
]
|
|
|
|
|
class ObjNotFoundError(Exception):
    """Signals that an importer could not locate an object by looking up its name."""
|
|
|
|
|
class ObjMismatchError(Exception):
    """Signals that the object an importer found under a name is not the one supplied by the caller."""
|
|
|
|
|
class Importer(ABC):
    """Represents an environment to import modules from.

    By default, you can figure out what module an object belongs to by checking
    __module__ and importing the result using __import__ or importlib.import_module.

    torch.package introduces module importers other than the default one.
    Each PackageImporter introduces a new namespace. Potentially a single
    name (e.g. 'foo.bar') is present in multiple namespaces.

    It supports two main operations:
        import_module: module_name -> module object
        get_name: object -> (parent module name, name of obj within module)

    The guarantee is that the following round-trip will succeed or throw an
    ObjNotFoundError/ObjMismatchError:
        module_name, obj_name = env.get_name(obj)
        module = env.import_module(module_name)
        obj2 = getattr(module, obj_name)  # resolve each component if obj_name is dotted
        assert obj is obj2
    """

    # Mapping of module name -> module searched by the default `whichmodule`.
    # It is only annotated here; subclasses relying on the default
    # `whichmodule` must set it.
    modules: Dict[str, ModuleType]

    @abstractmethod
    def import_module(self, module_name: str) -> ModuleType:
        """Import `module_name` from this environment.

        The contract is the same as for importlib.import_module.
        """
        pass

    @staticmethod
    def _resolve_attribute(root: Any, dotted_name: str) -> Any:
        """Resolve a dotted attribute path (e.g. "Outer.inner") starting at `root`.

        Local replacement for the private `pickle._getattribute` helper, which
        is not part of pickle's public API and whose signature/return value is
        not stable across CPython releases.

        Raises:
            AttributeError: a path component is missing, or the path goes
                through `<locals>` (function-local objects are not reachable
                by name).
        """
        current = root
        for part in dotted_name.split("."):
            if part == "<locals>":
                raise AttributeError(
                    f"Can't get local attribute {dotted_name!r} on {root!r}"
                )
            current = getattr(current, part)
        return current

    def get_name(self, obj: Any, name: Optional[str] = None) -> Tuple[str, str]:
        """Given an object, return a name that can be used to retrieve the
        object from this environment.

        Args:
            obj: An object to get the module-environment-relative name for.
            name: If set, use this name instead of looking up __name__ or __qualname__ on `obj`.
                This is only here to match how Pickler handles __reduce__ functions that return a string,
                don't use otherwise.
        Returns:
            A tuple (parent_module_name, attr_name) that can be used to retrieve `obj` from this environment.
            Use it like:
                mod = importer.import_module(parent_module_name)
                obj = getattr(mod, attr_name)  # resolve each component if attr_name is dotted

        Raises:
            ObjNotFoundError: we couldn't retrieve `obj` by name (import failed or
                the attribute lookup failed).
            ObjMismatchError: we found a different object with the same name as `obj`.
            AttributeError: no `name` was given/derived and `obj` has neither
                `__qualname__` nor `__name__`.
        """
        if name is None and obj and _Pickler.dispatch.get(type(obj)) is None:
            # Mirror Pickler: an object whose __reduce__ returns a string is
            # serialized as a global reference under that name.
            reduce = getattr(obj, "__reduce__", None)
            if reduce is not None:
                try:
                    rv = reduce()
                    if isinstance(rv, str):
                        name = rv
                except Exception:
                    # Best effort only; fall back to __qualname__/__name__ below.
                    pass
        if name is None:
            name = getattr(obj, "__qualname__", None)
        if name is None:
            name = obj.__name__

        orig_module_name = self.whichmodule(obj, name)
        # Strip any mangle prefix so the name can be resolved via import_module.
        module_name = demangle(orig_module_name)

        # Round-trip check: the name must lead back to this exact object.
        try:
            module = self.import_module(module_name)
            obj2 = self._resolve_attribute(module, name)
        except (ImportError, KeyError, AttributeError):
            raise ObjNotFoundError(
                f"{obj} was not found as {module_name}.{name}"
            ) from None

        if obj is obj2:
            return module_name, name

        def get_obj_info(obj):
            # Describe where an object comes from: its (possibly mangled)
            # module name, a human-readable location, and the importer to list.
            assert name is not None
            module_name = self.whichmodule(obj, name)
            is_mangled_ = is_mangled(module_name)
            location = (
                get_mangle_prefix(module_name)
                if is_mangled_
                else "the current Python environment"
            )
            importer_name = (
                f"the importer for {get_mangle_prefix(module_name)}"
                if is_mangled_
                else "'sys_importer'"
            )
            return module_name, location, importer_name

        obj_module_name, obj_location, obj_importer_name = get_obj_info(obj)
        obj2_module_name, obj2_location, obj2_importer_name = get_obj_info(obj2)
        msg = (
            f"\n\nThe object provided is from '{obj_module_name}', "
            f"which is coming from {obj_location}."
            f"\nHowever, when we import '{obj2_module_name}', it's coming from {obj2_location}."
            "\nTo fix this, make sure this 'PackageExporter's importer lists "
            f"{obj_importer_name} before {obj2_importer_name}."
        )
        raise ObjMismatchError(msg)

    def whichmodule(self, obj: Any, name: str) -> str:
        """Find the module name an object belongs to.

        This should be considered internal for end-users, but developers of
        an importer can override it to customize the behavior.

        Taken from pickle.py, but modified to exclude the search into sys.modules:
        returns `obj.__module__` when set; otherwise searches `self.modules`
        (skipping `__main__`, `__mp_main__` and None entries) for a module in
        which `name` resolves to `obj`; otherwise returns "__main__".
        """
        module_name = getattr(obj, "__module__", None)
        if module_name is not None:
            return module_name

        # Iterate a snapshot: resolving attributes may import modules and
        # mutate `self.modules` while we loop.
        for module_name, module in self.modules.copy().items():
            if module_name in ("__main__", "__mp_main__") or module is None:
                continue
            try:
                if self._resolve_attribute(module, name) is obj:
                    return module_name
            except AttributeError:
                pass

        return "__main__"
|
|
|
|
|
class _SysImporter(Importer):
    """Importer backed by the running interpreter's regular import machinery."""

    def import_module(self, module_name: str):
        # Delegate straight to the standard import system.
        loaded = importlib.import_module(module_name)
        return loaded

    def whichmodule(self, obj: Any, name: str) -> str:
        # Reuse pickle's own lookup, which (unlike Importer.whichmodule) also
        # searches sys.modules.
        owner = _pickle_whichmodule(obj, name)
        return owner
|
|
|
|
|
# Shared default importer: resolves modules through the regular Python
# import system (see _SysImporter).
sys_importer = _SysImporter()
|
|
|
|
|
class OrderedImporter(Importer):
    """A compound importer that takes a list of importers and tries them one at a time.

    The first importer in the list that returns a result "wins".
    """

    def __init__(self, *args):
        # Importers are consulted in the order given; their types are checked
        # when import_module walks the list.
        self._importers: List[Importer] = list(args)

    def _is_torchpackage_dummy(self, module):
        """Returns true iff this module is an empty PackageNode in a torch.package.

        If you intern `a.b` but never use `a` in your code, then `a` will be an
        empty module with no source. This can break cases where we are trying to
        re-package an object after adding a real dependency on `a`, since
        OrderedImporter will resolve `a` to the dummy package and stop there.

        See: https://github.com/pytorch/pytorch/pull/71520#issuecomment-1029603769
        """
        if not getattr(module, "__torch_package__", False):
            return False
        if not hasattr(module, "__path__"):
            return False
        if not hasattr(module, "__file__"):
            return True
        return module.__file__ is None

    def import_module(self, module_name: str) -> ModuleType:
        """Import `module_name` from the first importer that yields a real module.

        Importers are tried in order. An importer that raises
        ModuleNotFoundError, or returns an empty torch.package dummy package,
        is skipped.

        Raises:
            TypeError: an entry in the importer list is not an Importer.
            ModuleNotFoundError: no importer produced a non-dummy module. The
                error from the last failing importer is re-raised if there was
                one; otherwise a standard "No module named ..." error is raised.
        """
        last_err = None
        for importer in self._importers:
            if not isinstance(importer, Importer):
                raise TypeError(
                    f"{importer} is not a Importer. "
                    "All importers in OrderedImporter must inherit from Importer."
                )
            try:
                module = importer.import_module(module_name)
                if self._is_torchpackage_dummy(module):
                    # NOTE(review): if every importer yields only a dummy
                    # package, nothing is returned and ModuleNotFoundError is
                    # raised below — confirm this is intended.
                    continue
                return module
            except ModuleNotFoundError as err:
                last_err = err

        if last_err is not None:
            raise last_err
        # Match the stdlib's message and populate `.name` so callers that
        # inspect the error can tell which module was missing.
        raise ModuleNotFoundError(
            f"No module named '{module_name}'", name=module_name
        )

    def whichmodule(self, obj: Any, name: str) -> str:
        """Return the first module name other than "__main__" reported by an importer.

        Falls back to "__main__" when every importer reports "__main__" (or the
        importer list is empty).
        """
        for importer in self._importers:
            module_name = importer.whichmodule(obj, name)
            if module_name != "__main__":
                return module_name

        return "__main__"
|
|