Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,051 Bytes
829e08b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
from environs import Env
from torch import Tensor
from beartype import beartype
from beartype.door import is_bearable
from jaxtyping import (
Float,
Int,
Bool,
jaxtyped
)
# environment
# shared environs reader; queried further down for the TYPECHECK setting
env = Env()
# NOTE(review): presumably loads variables from a local .env file (if one exists)
# before any lookups happen - confirm against the environs documentation
env.read_env()
# function
def always(value):
    """Return a callable that ignores its arguments and returns ``value``.

    The returned function accepts any positional and keyword arguments,
    so it can stand in for a callable of any signature.
    """
    def inner(*args, **kwargs):
        # arguments are accepted only for signature compatibility
        return value

    return inner
def identity(t):
    """Return ``t`` unchanged (used as a no-op stand-in for a decorator)."""
    return t
# jaxtyping is a misnomer, works for pytorch
class TorchTyping:
    """Subscriptable wrapper that pins the array type of a dtype to ``Tensor``.

    Indexing an instance with a shape string returns
    ``abstract_dtype[Tensor, shapes]``, so callers only spell out the shape.
    """

    def __init__(self, abstract_dtype):
        """Store the subscriptable dtype object that lookups are delegated to."""
        self.abstract_dtype = abstract_dtype

    def __repr__(self):
        return f"{type(self).__name__}({self.abstract_dtype!r})"

    def __getitem__(self, shapes: str):
        """Return ``abstract_dtype[Tensor, shapes]``; ``shapes`` is forwarded as-is."""
        return self.abstract_dtype[Tensor, shapes]
# rebind the imported dtype names to their Tensor-bound wrappers, in one step
Float, Int, Bool = map(TorchTyping, (Float, Int, Bool))
# the TYPECHECK env variable decides whether beartype + jaxtyping are active;
# when it is off, both helpers degrade to no-ops
should_typecheck = env.bool('TYPECHECK', False)

if should_typecheck:
    typecheck = jaxtyped(typechecker = beartype)
    beartype_isinstance = is_bearable
else:
    typecheck = identity
    beartype_isinstance = always(True)
# public API of this module; entries must be *names* (strings), because
# `from module import *` raises TypeError on non-str items in __all__
__all__ = [
    'Float',
    'Int',
    'Bool',
    'typecheck',
    'beartype_isinstance',
]
|