File size: 459 Bytes
bb5cd12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from .config import MultimodalConfig
from .magma import Magma
from .language_model import get_gptj
from .transforms import get_transforms
from .utils import (
    count_parameters,
    is_main,
    cycle,
    get_tokenizer,
    parse_args,
    wandb_log,
    wandb_init,
    save_model,
    load_model,
    print_main,
    configure_param_groups,
    log_table,
)
from .train_loop import eval_step, inference_step, train_step
from .datasets import collate_fn