Spaces:
Running
on
Zero
Running
on
Zero
from environs import Env | |
from torch import Tensor | |
from beartype import beartype | |
from beartype.door import is_bearable | |
from jaxtyping import ( | |
Float, | |
Int, | |
Bool, | |
jaxtyped | |
) | |
# environment
# build the env reader, then let it load variables via read_env()

env = Env()
env.read_env()
# function | |
def always(value):
    """Build a callable that ignores whatever it is called with and returns `value`."""
    return lambda *args, **kwargs: value
def identity(t):
    """Return `t` unchanged (used as the no-op decorator when type checking is off)."""
    return t
# jaxtyping is a misnomer, works for pytorch | |
class TorchTyping:
    """Wrap an abstract dtype so it can be subscripted with a shape string alone.

    Indexing an instance with `shapes` forwards to
    `abstract_dtype[Tensor, shapes]`, i.e. the array type is fixed to
    `torch.Tensor` and only the shape specification is supplied by the caller.
    """

    def __init__(self, abstract_dtype):
        # the wrapped dtype object; it must itself support `[...]` indexing
        self.abstract_dtype = abstract_dtype

    def __getitem__(self, shapes: str):
        """Return the wrapped dtype indexed by `(Tensor, shapes)`."""
        dtype = self.abstract_dtype
        return dtype[Tensor, shapes]
# rebind the imported dtype names to their TorchTyping wrappers
# (the right-hand tuple is fully built from the original objects before any name is reassigned)

Float, Int, Bool = map(TorchTyping, (Float, Int, Bool))
# type checking with beartype + jaxtyping is opt-in through the TYPECHECK env variable (default off)

should_typecheck = env.bool('TYPECHECK', False)

if should_typecheck:
    # real checking: jaxtyped decorator backed by beartype, and beartype's is_bearable
    typecheck = jaxtyped(typechecker = beartype)
    beartype_isinstance = is_bearable
else:
    # disabled: decorator is a pass-through and every isinstance-style check reports True
    typecheck = identity
    beartype_isinstance = always(True)
# public API of this module
# `__all__` must contain the *names* as strings; listing the objects themselves
# makes `from <module> import *` raise a TypeError

__all__ = [
    'Float',
    'Int',
    'Bool',
    'typecheck',
    'beartype_isinstance',
]