# Spaces:
# Runtime error
import types
from contextlib import contextmanager
# Test-suite guard. Assigning torch.backends.<cudnn|mkldnn>.enabled (and
# similar flags) directly is easy to do and easy to forget to undo. While this
# flag is False, such bare, non-bracketed assignments are rejected and the
# flags() context manager must be used instead. Mutation starts out allowed.
__allow_nonbracketed_mutation_flag = True
def disable_global_flags():
    """Forbid bare assignment to backend flags from this point on.

    Clears the module-level mutation flag. Afterwards ``flags_frozen()``
    reports ``True`` and ``ContextProp.__set__`` raises ``RuntimeError``
    instead of applying the assigned value.
    """
    global __allow_nonbracketed_mutation_flag
    __allow_nonbracketed_mutation_flag = False
def flags_frozen():
    """Report whether bare assignment to backend flags is currently forbidden.

    This is the negation of the module-level mutation flag: ``True`` after
    ``disable_global_flags()`` and ``False`` while mutation is allowed.
    """
    mutation_allowed = __allow_nonbracketed_mutation_flag
    return not mutation_allowed
@contextmanager
def __allow_nonbracketed_mutation():
    """Temporarily permit bare assignment to backend flags.

    Use as ``with __allow_nonbracketed_mutation(): ...``. On entry the current
    module-level mutation flag is saved and forced to ``True``, so
    ``flags_frozen()`` is ``False`` and ``ContextProp`` setters accept values.
    On exit the saved value is restored, whether the body finishes normally or
    raises.
    """
    global __allow_nonbracketed_mutation_flag
    old = __allow_nonbracketed_mutation_flag
    __allow_nonbracketed_mutation_flag = True
    try:
        yield
    finally:
        # Restore the previous state even if the with-body raised.
        __allow_nonbracketed_mutation_flag = old
class ContextProp:
    """Data descriptor that routes attribute reads and writes through callables.

    Reading the attribute returns ``getter()``. Assigning to it calls
    ``setter(val)``. Assignment is refused with ``RuntimeError`` while
    ``flags_frozen()`` is true, i.e. after ``disable_global_flags()`` has
    cleared the module-level mutation flag.
    """

    def __init__(self, getter, setter):
        # Zero-argument callable that produces the current value.
        self.getter = getter
        # One-argument callable that applies a newly assigned value.
        self.setter = setter

    def __get__(self, obj, objtype):
        # The owning instance/class is not consulted; the value comes solely
        # from ``getter``.
        return self.getter()

    def __set__(self, obj, val):
        if flags_frozen():
            # ``obj`` is presumably a module-like object (e.g. a PropModule)
            # exposing ``__name__``; it is only used to build the message.
            raise RuntimeError(
                "not allowed to set %s flags "
                "after disable_global_flags; please use flags() context manager instead"
                % obj.__name__
            )
        self.setter(val)
class PropModule(types.ModuleType):
    """Module object that forwards unresolved attribute lookups to ``m``.

    The proxy is created with its own module ``name``. Regular lookup on the
    proxy happens first; only when it fails does ``__getattr__`` fetch the
    attribute from the wrapped object ``m`` via ``m.__getattribute__``.
    """

    def __init__(self, m, name):
        super().__init__(name)
        # Wrapped object that supplies attributes the proxy itself lacks.
        self.m = m

    def __getattr__(self, attr):
        # Python calls __getattr__ only after normal lookup failed, so
        # attributes on the proxy (or its class) shadow those of ``m``.
        wrapped = self.m
        return wrapped.__getattribute__(attr)
from torch.backends import ( | |
cpu as cpu, | |
cuda as cuda, | |
cudnn as cudnn, | |
mkl as mkl, | |
mkldnn as mkldnn, | |
mps as mps, | |
openmp as openmp, | |
quantized as quantized, | |
) | |