# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Fallback-aware ``jit`` and ``skip_no_elena`` decorators.

When ``TORCH_VERSION`` is ``'parrots'`` the real decorators are imported from
the ``parrots`` package; otherwise this module defines pass-through stand-ins
with the same names so that decorated code still imports and runs.
"""
import functools
import os

from mmengine.utils.dl_utils.parrots_wrapper import TORCH_VERSION

# Read once at import time; only the exact string 'ON' selects parrots' jit.
parrots_jit_option = os.getenv('PARROTS_JIT_OPTION')

if TORCH_VERSION == 'parrots' and parrots_jit_option == 'ON':
    from parrots.jit import pat as jit
else:

    def jit(func=None,
            check_input=None,
            full_shape=True,
            derivate=False,
            coderize=False,
            optimize=False):
        """No-op stand-in used when the parrots jit is not selected.

        Usable both as ``@jit`` and as ``@jit(...)``. Every argument other
        than ``func`` is accepted and ignored by this fallback.
        NOTE(review): the keyword names presumably mirror
        ``parrots.jit.pat`` -- confirm against the parrots API.

        Args:
            func: The callable being decorated, or ``None`` when the
                decorator is invoked with arguments (``@jit(...)``).
            check_input: Ignored by this fallback.
            full_shape: Ignored by this fallback.
            derivate: Ignored by this fallback.
            coderize: Ignored by this fallback.
            optimize: Ignored by this fallback.

        Returns:
            ``func`` itself when it is given; otherwise a decorator that
            wraps its target in a pass-through function.
        """

        def wrapper(func):

            # Preserve __name__/__doc__/__wrapped__ of the decorated callable
            # so introspection and tracebacks still point at the original.
            @functools.wraps(func)
            def wrapper_inner(*args, **kargs):
                return func(*args, **kargs)

            return wrapper_inner

        if func is None:
            return wrapper
        return func


if TORCH_VERSION == 'parrots':
    from parrots.utils.tester import skip_no_elena
else:

    def skip_no_elena(func):
        """Pass-through stand-in used when not running on parrots.

        Args:
            func: The callable being decorated.

        Returns:
            A wrapper that forwards all positional and keyword arguments to
            ``func`` and returns its result unchanged.
        """

        # Keep the decorated callable's metadata (name, docstring, module)
        # on the wrapper instead of exposing a generic ``wrapper``.
        @functools.wraps(func)
        def wrapper(*args, **kargs):
            return func(*args, **kargs)

        return wrapper