# magma / magma /__init__.py
# stellaathena's picture
# This should work
# bb5cd12
# Package-level re-exports: each name imported below from a sibling module
# becomes available directly on this package's namespace.
# NOTE(review): import order is left untouched — it may matter if the
# submodules import from one another; confirm before reordering.
from .config import MultimodalConfig
from .magma import Magma
from .language_model import get_gptj
from .transforms import get_transforms
# Helper names pulled in from the local `utils` module.
from .utils import (
count_parameters,
is_main,
cycle,
get_tokenizer,
parse_args,
wandb_log,
wandb_init,
save_model,
load_model,
print_main,
configure_param_groups,
log_table,
)
# Step functions from `train_loop` and the batch collation function from `datasets`.
from .train_loop import eval_step, inference_step, train_step
from .datasets import collate_fn