r"""A config to load and eval key model using the core train.py. |
|
|
|
The runtime varies widely depending on the model, but each one should reproduce |
|
the corresponding paper's numbers. |
|
This configuration makes use of the "arg" to get_config to select which model |
|
to run, so a few examples are given below: |
|
|
|
Run and evaluate a BiT-M ResNet-50x1 model that was transferred to i1k: |
|
|
|
big_vision.train \ |
|
--config big_vision/configs/load_and_eval.py:name=bit_paper,batch_size=8 \ |
|
--config.model_init M-imagenet2012 --config.model.width 1 --config.model.depth 50 |
|
|
|
Run and evaluate the recommended ViT-B/32 from "how to train your vit" paper: |
|
|
|
big_vision.train \ |
|
--config big_vision/configs/load_and_eval.py:name=vit_i21k,batch_size=8 \ |
|
--config.model.variant B/32 --config.model_init howto-i21k-B/32 |
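
The same pattern works for the other model functions defined below. As a
further sketch (the checkpoint path here is only an illustrative placeholder,
not a published artifact), evaluate the ViT-S/16 defined by vit_i1k from a
local checkpoint:

big_vision.train \
    --config big_vision/configs/load_and_eval.py:name=vit_i1k,batch_size=8 \
    --config.model.variant S/16 --config.model_init /path/to/checkpoint.npz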
|
""" |
|
|
|
import big_vision.configs.common as bvcc |
|
from big_vision.configs.common_fewshot import get_fewshot_lsr |
|
|
|
|
|
def eval_only(config, batch_size, spec_for_init): |
|
"""Set a few configs that turn trainer into (almost) eval-only.""" |
|
config.total_steps = 0 |
|
config.input = {} |
|
config.input.batch_size = batch_size |
|
config.input.data = dict(name='bv:dummy', spec=spec_for_init) |
|
config.optax_name = 'identity' |
|
config.lr = 0.0 |
|
|
|
config.mesh = [('data', -1)] |
|
config.sharding_strategy = [('params/.*', 'fsdp(axis="data")')] |
|
config.sharding_rules = [('act_batch', ('data',))] |
|
|
|
return config |
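
# With total_steps=0 the trainer performs no update steps, so running train.py
# with this config effectively just loads `model_init` (if set) and runs the
# evaluators configured under `config.evals`. A minimal standalone usage
# sketch of `eval_only` (mirroring what get_config below does):
#
#   config = bvcc.parse_arg('', name='bit_paper', batch_size=4)
#   eval_only(config, config.batch_size,
#             spec_for_init=dict(image=dict(shape=(224, 224, 3),
#                                           dtype='float32')))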
|
|
|
|
|
def get_config(arg=''):
  config = bvcc.parse_arg(arg, name='bit_paper', batch_size=4)

  eval_only(config, config.batch_size, spec_for_init=dict(
      image=dict(shape=(224, 224, 3), dtype='float32'),
  ))

  config.evals = dict(fewshot=get_fewshot_lsr())

  # Dispatch to the model-specific config function selected by the `name` arg.
  globals()[config.name](config)
  return config
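
# For example, get_config('name=vit_i21k,batch_size=8') parses the arg, sets
# up the eval-only trainer, and then calls vit_i21k(config) to fill in the
# model and its evaluators; the same holds for bit_paper, vit_i1k and
# mlp_mixer_i1k below.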
|
|
|
|
|
def bit_paper(config):
  config.num_classes = 1000

  config.model_name = 'bit_paper'
  config.model_init = 'M-imagenet2012'
  config.model = dict(width=1, depth=50)

  def get_eval(split, lbl, dataset='imagenet2012_real'):
    return dict(
        type='classification',
        data=dict(name=dataset, split=split),
        loss_name='softmax_xent',
        cache='none',
        pp_fn=(
            'decode|resize(384)|value_range(-1, 1)'
            f'|onehot(1000, key="{lbl}", key_result="labels")'
            '|keep("image", "labels")'
        ),
    )

  # Evaluate on ImageNet with the original labels, with ReaL labels, and on
  # ImageNet-v2.
  config.evals.test = get_eval('validation', 'original_label')
  config.evals.real = get_eval('validation', 'real_label')
  config.evals.v2 = get_eval('test', 'label', 'imagenet_v2')
|
|
|
|
|
def vit_i1k(config):
  config.num_classes = 1000

  config.model_name = 'vit'
  config.model_init = ''
  config.model = dict(variant='S/16', pool_type='gap', posemb='sincos2d',
                      rep_size=True)

  config.evals.val = dict(
      type='classification',
      data=dict(name='imagenet2012', split='validation'),
      pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)'
            '|onehot(1000, key="label", key_result="labels")'
            '|keep("image", "labels")',
      loss_name='softmax_xent',
      cache='none',
  )
|
|
|
|
|
def mlp_mixer_i1k(config):
  config.num_classes = 1000

  config.model_name = 'mlp_mixer'
  config.model_init = ''
  config.model = dict(variant='L/16')

  config.evals.val = dict(
      type='classification',
      data=dict(name='imagenet2012', split='validation'),
      pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)'
            '|onehot(1000, key="label", key_result="labels")'
            '|keep("image", "labels")',
      loss_name='softmax_xent',
      cache='none',
  )
|
|
|
|
|
def vit_i21k(config):
  config.num_classes = 21843

  config.model_name = 'vit'
  config.model_init = ''
  config.model = dict(variant='B/32', pool_type='tok')

  config.evals.val = dict(
      type='classification',
      data=dict(name='imagenet21k', split='full[:51200]'),
      pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)'
            '|onehot(21843)|keep("image", "labels")',
      loss_name='sigmoid_xent',
      cache='none',
  )
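

# A hedged sketch of how another model could be plugged in (the variant and
# checkpoint below are placeholders, not published artifacts): define a
# top-level function named after the desired `name` arg, fill in the model and
# its evaluators, and select it via `:name=my_vit_i1k` on the command line.
#
# def my_vit_i1k(config):
#   config.num_classes = 1000
#   config.model_name = 'vit'
#   config.model_init = '/path/to/checkpoint.npz'  # Placeholder path.
#   config.model = dict(variant='B/16', pool_type='tok')
#   config.evals.val = dict(
#       type='classification',
#       data=dict(name='imagenet2012', split='validation'),
#       pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)'
#             '|onehot(1000, key="label", key_result="labels")'
#             '|keep("image", "labels")',
#       loss_name='softmax_xent',
#       cache='none',
#   )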
|
|