YulianSa's picture
init
829e08b
from environs import Env
from torch import Tensor
from beartype import beartype
from beartype.door import is_bearable
from jaxtyping import (
Float,
Int,
Bool,
jaxtyped
)
# environment
# Load a `.env` file (if environs finds one) into the process environment so
# that settings such as TYPECHECK can be supplied there instead of exported.
env = Env()
env.read_env()
# functions

def always(value):
    """Build a function that ignores all of its arguments and returns ``value``.

    Any positional and keyword arguments passed to the returned callable are
    accepted and discarded.
    """
    return lambda *_args, **_kwargs: value
def identity(t):
    """Return ``t`` unchanged.

    Serves as the no-op replacement for the ``typecheck`` decorator when
    type checking is disabled, so decorated callables are returned as-is.
    """
    unchanged = t
    return unchanged
class TorchTyping:
    """Bind a jaxtyping dtype to ``torch.Tensor``.

    Despite its name, jaxtyping also works for PyTorch. Its dtypes are
    normally subscripted as ``Float[array_type, "shape"]``; this wrapper fixes
    ``array_type`` to ``Tensor`` so callers can write ``Float["shape"]``.
    """

    def __init__(self, abstract_dtype):
        # The jaxtyping dtype being specialised (e.g. Float, Int, Bool).
        self.abstract_dtype = abstract_dtype

    def __getitem__(self, shapes: str):
        """Return ``abstract_dtype[Tensor, shapes]``."""
        subscript = (Tensor, shapes)
        return self.abstract_dtype[subscript]
# Rebind the jaxtyping dtypes to Tensor-specialised wrappers, so that e.g.
# `Float['b n d']` means `jaxtyping.Float[torch.Tensor, 'b n d']`.
Float, Int, Bool = map(TorchTyping, (Float, Int, Bool))
# The TYPECHECK env variable (default: False) controls whether runtime
# checking via beartype + jaxtyping is enabled. When it is off, `typecheck`
# is a pass-through decorator and `beartype_isinstance` accepts every value.
should_typecheck = env.bool('TYPECHECK', False)

if should_typecheck:
    # jaxtyping shape/dtype checking, delegating type checks to beartype.
    typecheck = jaxtyped(typechecker=beartype)
    beartype_isinstance = is_bearable
else:
    typecheck = identity
    beartype_isinstance = always(True)
# Public API of this module. `__all__` must contain the exported *names* as
# strings: `from <module> import *` raises TypeError when it finds a non-str
# entry (the original listed the objects themselves).
__all__ = [
    'Float',
    'Int',
    'Bool',
    'typecheck',
    'beartype_isinstance',
]