import types from contextlib import contextmanager # The idea for this parameter is that we forbid bare assignment # to torch.backends..enabled and friends when running our # test suite, where it's very easy to forget to undo the change # later. __allow_nonbracketed_mutation_flag = True def disable_global_flags(): global __allow_nonbracketed_mutation_flag __allow_nonbracketed_mutation_flag = False def flags_frozen(): return not __allow_nonbracketed_mutation_flag @contextmanager def __allow_nonbracketed_mutation(): global __allow_nonbracketed_mutation_flag old = __allow_nonbracketed_mutation_flag __allow_nonbracketed_mutation_flag = True try: yield finally: __allow_nonbracketed_mutation_flag = old class ContextProp: def __init__(self, getter, setter): self.getter = getter self.setter = setter def __get__(self, obj, objtype): return self.getter() def __set__(self, obj, val): if not flags_frozen(): self.setter(val) else: raise RuntimeError( "not allowed to set %s flags " "after disable_global_flags; please use flags() context manager instead" % obj.__name__ ) class PropModule(types.ModuleType): def __init__(self, m, name): super().__init__(name) self.m = m def __getattr__(self, attr): return self.m.__getattribute__(attr) from torch.backends import ( cpu as cpu, cuda as cuda, cudnn as cudnn, mkl as mkl, mkldnn as mkldnn, mps as mps, openmp as openmp, quantized as quantized, )