| |
| import os |
|
|
| from .parrots_wrapper import TORCH_VERSION |
|
|
| parrots_jit_option = os.getenv('PARROTS_JIT_OPTION') |
|
|
| if TORCH_VERSION == 'parrots' and parrots_jit_option == 'ON': |
| from parrots.jit import pat as jit |
| else: |
|
|
def jit(func=None,
        check_input=None,
        full_shape=True,
        derivate=False,
        coderize=False,
        optimize=False):
    """Fallback ``jit`` decorator that leaves the decorated function as is.

    This definition is used when ``parrots.jit.pat`` is not imported, i.e.
    unless ``TORCH_VERSION == 'parrots'`` and the ``PARROTS_JIT_OPTION``
    environment variable equals ``'ON'``. It supports both decorator
    spellings:

    * ``@jit`` -- ``func`` is the decorated callable and is returned
      unchanged (the very same object).
    * ``@jit(...)`` -- ``func`` is ``None``; a decorator is returned that
      wraps the target in a pass-through function.

    Args:
        func (callable, optional): Function being decorated when the
            decorator is applied without parentheses. Defaults to None.
        check_input: Accepted for signature compatibility; ignored here.
        full_shape: Accepted for signature compatibility; ignored here.
        derivate: Accepted for signature compatibility; ignored here.
        coderize: Accepted for signature compatibility; ignored here.
        optimize: Accepted for signature compatibility; ignored here.

    Returns:
        callable: ``func`` itself when it is given, otherwise a decorator
        that produces a pass-through wrapper around its argument.
    """
    # Imported locally so this block does not rely on a module-level import.
    from functools import wraps

    def decorator(target):

        # ``wraps`` copies __name__, __doc__, __module__, ... and sets
        # __wrapped__, so introspection still sees the original function.
        @wraps(target)
        def passthrough(*args, **kwargs):
            return target(*args, **kwargs)

        return passthrough

    if func is None:
        # Called as ``@jit(...)``: hand back the actual decorator.
        return decorator
    # Called as bare ``@jit``: nothing to compile, return the function as is.
    return func
|
|
|
|
| if TORCH_VERSION == 'parrots': |
| from parrots.utils.tester import skip_no_elena |
| else: |
|
|
def skip_no_elena(func):
    """Fallback ``skip_no_elena`` decorator that never skips anything.

    This definition is used when ``TORCH_VERSION`` is not ``'parrots'``, in
    place of ``parrots.utils.tester.skip_no_elena``. It returns a wrapper
    that forwards every call to ``func`` unchanged.

    Args:
        func (callable): The function being decorated.

    Returns:
        callable: A wrapper that calls ``func`` with the same positional and
        keyword arguments and returns its result.
    """
    # Imported locally so this block does not rely on a module-level import.
    from functools import wraps

    # ``wraps`` keeps the original name, docstring and ``__wrapped__`` link,
    # so introspection of the decorated function still sees the original.
    @wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper
|
|