import importlib from abc import ABC, abstractmethod from pickle import ( # type: ignore[attr-defined] # type: ignore[attr-defined] _getattribute, _Pickler, whichmodule as _pickle_whichmodule, ) from types import ModuleType from typing import Any, Dict, List, Optional, Tuple from ._mangling import demangle, get_mangle_prefix, is_mangled __all__ = ["ObjNotFoundError", "ObjMismatchError", "Importer", "OrderedImporter"] class ObjNotFoundError(Exception): """Raised when an importer cannot find an object by searching for its name.""" pass class ObjMismatchError(Exception): """Raised when an importer found a different object with the same name as the user-provided one.""" pass class Importer(ABC): """Represents an environment to import modules from. By default, you can figure out what module an object belongs by checking __module__ and importing the result using __import__ or importlib.import_module. torch.package introduces module importers other than the default one. Each PackageImporter introduces a new namespace. Potentially a single name (e.g. 'foo.bar') is present in multiple namespaces. It supports two main operations: import_module: module_name -> module object get_name: object -> (parent module name, name of obj within module) The guarantee is that following round-trip will succeed or throw an ObjNotFoundError/ObjMisMatchError. module_name, obj_name = env.get_name(obj) module = env.import_module(module_name) obj2 = getattr(module, obj_name) assert obj1 is obj2 """ modules: Dict[str, ModuleType] @abstractmethod def import_module(self, module_name: str) -> ModuleType: """Import `module_name` from this environment. The contract is the same as for importlib.import_module. """ pass def get_name(self, obj: Any, name: Optional[str] = None) -> Tuple[str, str]: """Given an object, return a name that can be used to retrieve the object from this environment. Args: obj: An object to get the module-environment-relative name for. name: If set, use this name instead of looking up __name__ or __qualname__ on `obj`. This is only here to match how Pickler handles __reduce__ functions that return a string, don't use otherwise. Returns: A tuple (parent_module_name, attr_name) that can be used to retrieve `obj` from this environment. Use it like: mod = importer.import_module(parent_module_name) obj = getattr(mod, attr_name) Raises: ObjNotFoundError: we couldn't retrieve `obj by name. ObjMisMatchError: we found a different object with the same name as `obj`. """ if name is None and obj and _Pickler.dispatch.get(type(obj)) is None: # Honor the string return variant of __reduce__, which will give us # a global name to search for in this environment. # TODO: I guess we should do copyreg too? reduce = getattr(obj, "__reduce__", None) if reduce is not None: try: rv = reduce() if isinstance(rv, str): name = rv except Exception: pass if name is None: name = getattr(obj, "__qualname__", None) if name is None: name = obj.__name__ orig_module_name = self.whichmodule(obj, name) # Demangle the module name before importing. If this obj came out of a # PackageImporter, `__module__` will be mangled. See mangling.md for # details. module_name = demangle(orig_module_name) # Check that this name will indeed return the correct object try: module = self.import_module(module_name) obj2, _ = _getattribute(module, name) except (ImportError, KeyError, AttributeError): raise ObjNotFoundError( f"{obj} was not found as {module_name}.{name}" ) from None if obj is obj2: return module_name, name def get_obj_info(obj): assert name is not None module_name = self.whichmodule(obj, name) is_mangled_ = is_mangled(module_name) location = ( get_mangle_prefix(module_name) if is_mangled_ else "the current Python environment" ) importer_name = ( f"the importer for {get_mangle_prefix(module_name)}" if is_mangled_ else "'sys_importer'" ) return module_name, location, importer_name obj_module_name, obj_location, obj_importer_name = get_obj_info(obj) obj2_module_name, obj2_location, obj2_importer_name = get_obj_info(obj2) msg = ( f"\n\nThe object provided is from '{obj_module_name}', " f"which is coming from {obj_location}." f"\nHowever, when we import '{obj2_module_name}', it's coming from {obj2_location}." "\nTo fix this, make sure this 'PackageExporter's importer lists " f"{obj_importer_name} before {obj2_importer_name}." ) raise ObjMismatchError(msg) def whichmodule(self, obj: Any, name: str) -> str: """Find the module name an object belongs to. This should be considered internal for end-users, but developers of an importer can override it to customize the behavior. Taken from pickle.py, but modified to exclude the search into sys.modules """ module_name = getattr(obj, "__module__", None) if module_name is not None: return module_name # Protect the iteration by using a list copy of self.modules against dynamic # modules that trigger imports of other modules upon calls to getattr. for module_name, module in self.modules.copy().items(): if ( module_name == "__main__" or module_name == "__mp_main__" # bpo-42406 or module is None ): continue try: if _getattribute(module, name)[0] is obj: return module_name except AttributeError: pass return "__main__" class _SysImporter(Importer): """An importer that implements the default behavior of Python.""" def import_module(self, module_name: str): return importlib.import_module(module_name) def whichmodule(self, obj: Any, name: str) -> str: return _pickle_whichmodule(obj, name) sys_importer = _SysImporter() class OrderedImporter(Importer): """A compound importer that takes a list of importers and tries them one at a time. The first importer in the list that returns a result "wins". """ def __init__(self, *args): self._importers: List[Importer] = list(args) def _is_torchpackage_dummy(self, module): """Returns true iff this module is an empty PackageNode in a torch.package. If you intern `a.b` but never use `a` in your code, then `a` will be an empty module with no source. This can break cases where we are trying to re-package an object after adding a real dependency on `a`, since OrderedImportere will resolve `a` to the dummy package and stop there. See: https://github.com/pytorch/pytorch/pull/71520#issuecomment-1029603769 """ if not getattr(module, "__torch_package__", False): return False if not hasattr(module, "__path__"): return False if not hasattr(module, "__file__"): return True return module.__file__ is None def import_module(self, module_name: str) -> ModuleType: last_err = None for importer in self._importers: if not isinstance(importer, Importer): raise TypeError( f"{importer} is not a Importer. " "All importers in OrderedImporter must inherit from Importer." ) try: module = importer.import_module(module_name) if self._is_torchpackage_dummy(module): continue return module except ModuleNotFoundError as err: last_err = err if last_err is not None: raise last_err else: raise ModuleNotFoundError(module_name) def whichmodule(self, obj: Any, name: str) -> str: for importer in self._importers: module_name = importer.whichmodule(obj, name) if module_name != "__main__": return module_name return "__main__"