| r"""WIT Retrieval + Captioning Pre-Training.""" |
|
|
| import ml_collections |
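

# Nominal number of training examples; used only to derive steps_per_epoch
# and the total number of training steps below.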
TRAIN_DATA_SIZE = 1_000_000_000


def get_config() -> ml_collections.ConfigDict:
  """Returns the base experiment configuration."""
  config = ml_collections.ConfigDict()
  config.experiment_name = 'image_caption_debug'

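  # Optimizer.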
  config.optimizer = 'adafactor'
  n_device = 128
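  # 12 * 2 * 128 = 3072 examples per global batch.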
  batch_size = 12 * 2 * n_device
  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.momentum = None
  config.optimizer_configs.weight_decay_rate = 2e-3
  config.optimizer_configs.clipping_threshold = 5.0
  config.optimizer_configs.skip_scale_and_bias_regularization = True

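  # Per-parameter-group treatment: `frozen_patterns` lists regexes of weights
  # to freeze, and each `not_frozen_patterns` entry pairs a name regex with
  # what appears to be a per-group learning-rate multiplier.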
  config.frozen_patterns = []
  config.not_frozen_patterns = [('value_perceiver/.*', 0.3),
                                ('shared_token_embedder/.*', 0.1),
                                ('query_head/.*', 0.2), ('out_decoder/.*', 1.0),
                                ('key_head/.*', 0.2), ('head_out/.*', 0.2),
                                ('fusion_encoder/.*', 0.5),
                                ('att_transform/.*', 0.3),
                                ('dataset_gate/.*', 0.5)]

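  # Gradient clipping by global norm (separate from the optimizer-level
  # `clipping_threshold` above).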
  config.grad_clip_configs = ml_collections.ConfigDict()
  config.grad_clip_configs.clip_method = 'clip_by_global_norm'
  config.grad_clip_configs.clip_value = 1.0

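  # Knowledge bases available for retrieval, with one (here empty) config per
  # source; their count feeds `config.model.n_knowledge_source` below.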
  config.kb_dataset_names = ['wit_table', 'cc12m_table', 'vqa_table']
  config.kb_dataset_configs = [{}, {}, {}]

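  # Training.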
  config.batch_size = batch_size
  config.eval_batch_size = batch_size
  config.rng_seed = 0
  config.update_num = False
  config.num_training_epochs = 1
  config.data_dtype_str = 'bfloat16'

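  # Model.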
  config.model_name = 'knowledge_fid'
  config.model = ml_collections.ConfigDict()
  config.model.image_model = 'vit'
  config.model.t5_name = 't5_1_1_base'

  config.model.num_fusion_layers = 6
  config.model.n_compressed_tokens = 32
  config.model.key_dim = 512
  config.model.dropout_rate = 0.0
  config.model.temperature = 0.2
  config.model.retr_k = 10
  config.model.retr_data_ratio = 0.2
  config.model.label_smoothing = 1e-2
  config.model.vit_name = 'B/16'
  config.model.vit_model_path = 'JFT3b-B/16'

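  # Freezing and retrieval behavior. `vit_num_frozen_layers = 1 / 2` is
  # presumably interpreted as a fraction of the ViT depth, i.e. the bottom
  # half of the image encoder stays frozen.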
  config.model.t5_frozen_base = False
  config.model.vit_num_frozen_layers = 1 / 2
  config.model.retrieve_local = False
  config.model.use_psudo_retr = True
  config.model.disentangle = True
  config.model.gap = True
  config.model.retrieval_ratio = 1e-2
  config.model.n_knowledge_source = len(config.kb_dataset_names)
  config.model.qa = False
  config.frozen_memory = False

  config.vocab_size = 32120
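  # Decoding: a single decode without beam search, i.e. effectively greedy.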
  config.autoregressive_decoding = ml_collections.ConfigDict()
  config.autoregressive_decoding.num_decodes = 1
  config.autoregressive_decoding.beam_search = False

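  # Dataset.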
  config.dataset_name = 'web_image_text_generation'
  config.dataset_configs = ml_collections.ConfigDict()

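  # Learning-rate schedule. With TRAIN_DATA_SIZE = 1e9 and a global batch
  # size of 3072, steps_per_epoch and total_steps both come out to 325,520.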
  config.num_train_examples = TRAIN_DATA_SIZE
  steps_per_epoch = TRAIN_DATA_SIZE // config.batch_size
  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.total_steps = int(config.num_training_epochs *
                                      steps_per_epoch)
  config.lr_configs.learning_rate_schedule = 'compound'
  config.lr_configs.factors = 'constant * rsqrt_decay * linear_warmup'
  config.lr_configs.warmup_steps = 10000
  config.lr_configs.timescale = 10000
  config.lr_configs.base_learning_rate = 1e-4
  config.lr_configs.end_learning_rate = 1e-6
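  # Rough shape of the schedule above, assuming the usual 'compound'
  # parameterization (each named factor is evaluated per step and multiplied
  # together; the exact factor definitions live in the trainer's schedule
  # library):
  #   lr(step) ~ base_learning_rate                 # constant
  #              * min(1.0, step / warmup_steps)    # linear_warmup
  #              * rsqrt_decay(step)                # ~ 1/sqrt(step) once
  #                                                 #   step > timescale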
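  # Logging and checkpointing.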
  config.log_summary_steps = 100
  config.log_eval_steps = 1000
  config.checkpoint_steps = 5000
  config.write_summary = True
  config.xprof = True
  config.checkpoint = True
  config.debug_train = False
  config.debug_eval = False

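  # Initialization from pretrained checkpoints.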
  config.init_from = ml_collections.ConfigDict()
  config.init_from.load_key_encoder = False
  config.init_from.encoder = ml_collections.ConfigDict()
  config.init_from.encoder.init_from_vit = False
  config.init_from.encoder.checkpoint_path = None

  return config