EasyDeL
EasyDeL is an open-source framework designed to streamline the training of machine learning models. With a primary focus on JAX, EasyDeL provides convenient and effective tools for training and serving Flax/JAX models on TPUs and GPUs.
Usage Examples
Loading from EasyDeLState (*.easy files)
```python
from easydel import EasyDeLState, AutoShardAndGatherFunctions
from jax import numpy as jnp, lax

# Build the functions that shard parameters across devices and gather them back.
shard_fns, gather_fns = AutoShardAndGatherFunctions.from_pretrained(
    "REPO_ID",  # The PyTorch checkpoint must be saved in this repo for the shard/gather functions to be found automatically; otherwise, see the docs.
    backend="gpu",
    depth_target=["params", "params"],
    flatten=False
)

# Load the saved state and shard it with shard_fns while loading.
state = EasyDeLState.load_state(
    "*.easy",  # path to your .easy state file
    dtype=jnp.float16,
    param_dtype=jnp.float16,
    precision=lax.Precision("fastest"),
    verbose=True,
    state_shard_fns=shard_fns
)
# The state is ready to use.
```
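To make the roles of `shard_fns` and `gather_fns` more concrete, here is a minimal plain-JAX sketch. It is not EasyDeL's implementation: the toy parameter tree, the `dp` mesh axis, and the helpers `make_shard_fn` and `gather_fn` are hypothetical. It only assumes that a shard function places one parameter on devices with some partitioning, and that a gather function brings it back to host memory.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n_dev = len(jax.devices())
# 1-D device mesh; a single CPU device is enough to run this sketch.
mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))

# Hypothetical toy parameter tree, nested under ["params", "params"].
params = {"params": {"params": {
    "dense": {"kernel": jnp.ones((4 * n_dev, 4)), "bias": jnp.zeros((4,))},
}}}

def make_shard_fn(x):
    # Split 2-D weights by rows across the "dp" axis; replicate everything else.
    spec = P("dp", None) if x.ndim == 2 else P()
    sharding = NamedSharding(mesh, spec)
    return lambda arr: jax.device_put(arr, sharding)

def gather_fn(arr):
    # Bring a (possibly sharded) device array back to host memory as NumPy.
    return np.asarray(jax.device_get(arr))

# One function per leaf, mirroring the structure of the parameter tree.
shard_fns = jax.tree_util.tree_map(make_shard_fn, params)
gather_fns = jax.tree_util.tree_map(lambda _: gather_fn, params)

# Apply each leaf's function to the matching parameter.
sharded = jax.tree_util.tree_map(lambda fn, x: fn(x), shard_fns, params)
restored = jax.tree_util.tree_map(lambda fn, x: fn(x), gather_fns, sharded)

print(sharded["params"]["params"]["dense"]["kernel"].sharding)
```

Because every leaf gets its own function, different parameters can use different partition specs; here 2-D kernels are split by rows while 1-D biases are replicated.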
Loading with AutoEasyDeLModelForCausalLM (from PyTorch weights)
```python
from easydel import AutoEasyDeLModelForCausalLM
from jax import numpy as jnp, lax

# Load PyTorch weights from the repo and convert them to EasyDeL parameters.
model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    "REPO_ID",
    dtype=jnp.float16,
    param_dtype=jnp.float16,
    precision=lax.Precision("fastest"),
    auto_shard_params=True,  # shard the loaded parameters automatically
)
# The model and parameters are ready to use.
```
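The `dtype`, `param_dtype`, and `precision` arguments above are assumed here to follow the usual Flax conventions: `param_dtype` is the dtype the weights are stored in, `dtype` is the dtype the computation runs in, and `precision` is forwarded to JAX matrix multiplies. The sketch below shows those three settings in plain JAX on a hypothetical toy weight; it is not EasyDeL code.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical float32 weights, e.g. freshly converted from a PyTorch checkpoint.
params = {"kernel": jnp.ones((4, 3), dtype=jnp.float32)}

# param_dtype: cast the stored weights to float16.
params_fp16 = jax.tree_util.tree_map(lambda p: p.astype(jnp.float16), params)

# precision: "fastest" is a string alias JAX accepts for Precision.DEFAULT.
precision = lax.Precision("fastest")

# dtype: run the computation itself in float16.
x = jnp.ones((2, 4), dtype=jnp.float16)
y = jnp.dot(x, params_fp16["kernel"], precision=precision)
print(y.dtype, y.shape)
```

`Precision.DEFAULT` is the fastest and least accurate matmul mode on accelerators; `"high"` and `"highest"` trade speed for accuracy.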
Loading with AutoEasyDeLModelForCausalLM (from EasyDeL weights)
```python
from easydel import AutoEasyDeLModelForCausalLM
from jax import numpy as jnp, lax

model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    "REPO_ID",
    dtype=jnp.float16,
    param_dtype=jnp.float16,
    precision=lax.Precision("fastest"),
    auto_shard_params=True,  # shard the loaded parameters automatically
    from_torch=False  # load EasyDeL weights directly instead of converting from PyTorch
)
# The model and parameters are ready to use.
```