Upload HDLM model with complete HF integration
config.yaml ADDED  +88 -0

@@ -0,0 +1,88 @@
+ngpus: 4
+gradient_accumulation_steps: 8
+pretrain_autoregressive_path: /home/toolkit/research-diffcodegen/exp_local/openwebtext/mdlm-autoregressive/org-DiTAR-absorb-v2/checkpoints-meta/checkpoint.pth
+tokenizer:
+  tokens: 50257
+  model: gpt2
+training:
+  batch_size: 512
+  accum: ${gradient_accumulation_steps}
+  n_iters: 1000000
+  snapshot_freq: 100
+  log_freq: 10
+  eval_freq: 100
+  snapshot_freq_for_preemption: 3000
+  weight: standard
+  snapshot_sampling: true
+  ema: 0.9999
+  warmup_iter: -1
+data:
+  train: openwebtext-train
+  valid: wikitext103
+  cache_dir: /home/toolkit/research-diffcodegen/data
+  debug: false
+graph:
+  type: QGamma
+  gamma: 0.01
+  file: /home/toolkit/research-diffcodegen/data
+  report_all: false
+  expanded_sigma: true
+noise:
+  type: loglinear
+  sigma_min: 0.0001
+  sigma_max: 2.0
+  ar_diffusion: false
+  expanded_sigma: ${graph.expanded_sigma}
+sampling:
+  predictor: analytic
+  steps_per_level: 1
+  noise_removal: true
+  strategy: direct
+  strategy_param: 0.9
+annealing:
+  type: block
+  efficient: false
+  width: 1024
+  tau: 2048
+  eval_tau: 512
+  steps_per_level: ${sampling.steps_per_level}
+  sampling_method: SAR
+  diffusion_loss_weight: 1.0
+  ce_loss_weight: 4.0
+  sampling_eps: 0.0001
+  attention:
+    context_type: block_causal
+    block_type: full
+  match_inference: true
+eval:
+  batch_size: 32
+  perplexity: true
+  perplexity_batch_size: 16
+optim:
+  weight_decay: 0.0
+  optimizer: AdamW
+  lr: 0.0003
+  beta1: 0.9
+  beta2: 0.999
+  eps: 1.0e-08
+  warmup: 10000
+  grad_clip: 1.0
+  scheduler: lambda
+experiment:
+  name: QGamma0.01-v2
+  wandb_project: debug-QGamma
+model:
+  name: gamma_hdlm
+  type: ddit
+  hidden_size: 768
+  cond_dim: 128
+  length: 1024
+  n_blocks: 12
+  n_heads: 12
+  scale_by_sigma: false
+  dropout: 0.1
+  transformer_sigma_conditioning: true
+  hybrid_sigma_embedding: true
+  post_process_logits: true
+  use_timestep_embedding: true
+model_type: gamma_hybrid
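
Note that the ${...} values in this config (training.accum and noise.expanded_sigma) are interpolation references, not literal strings. A minimal sketch of reading the file, assuming OmegaConf-style resolution — the library choice is an assumption, since the commit does not state how the config is consumed:

# Assumption: the ${...} syntax is OmegaConf interpolation; OmegaConf itself
# is not named anywhere in this commit.
from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")

# Interpolations resolve on access:
#   training.accum       -> ${gradient_accumulation_steps} -> 8
#   noise.expanded_sigma -> ${graph.expanded_sigma}        -> true
assert cfg.training.accum == cfg.gradient_accumulation_steps
assert cfg.noise.expanded_sigma == cfg.graph.expanded_sigma

# Materialize a fully resolved plain dict, e.g. for experiment logging.
resolved = OmegaConf.to_container(cfg, resolve=True)
print(resolved["experiment"]["wandb_project"])  # debug-QGamma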