fix tokenizer loading, got openllama 3b working
examples/{lora-alpaca-7b → lora-openllama-3b}/config.yml RENAMED

```diff
@@ -1,5 +1,5 @@
-base_model:
-base_model_config:
+base_model: openlm-research/open_llama_3b_600bt_preview
+base_model_config: openlm-research/open_llama_3b_600bt_preview
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
 load_in_8bit: true
@@ -32,9 +32,9 @@ wandb_watch:
 wandb_run_id:
 wandb_log_model:
 output_dir: ./lora-out
-batch_size:
-micro_batch_size:
-num_epochs:
+batch_size: 16
+micro_batch_size: 4
+num_epochs: 3
 optimizer: adamw_bnb_8bit
 torchdistx_path:
 lr_scheduler: cosine
```
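The new values pin down the effective batch: with `batch_size: 16` and `micro_batch_size: 4`, gradients are accumulated over the ratio of the two before each optimizer step. A minimal sketch of that arithmetic, assuming the usual batch/micro-batch relationship (the helper below is illustrative, not axolotl's actual code):

```python
# Illustrative helper, not axolotl's implementation: derive the gradient
# accumulation steps implied by the config values above.
def gradient_accumulation_steps(batch_size: int, micro_batch_size: int) -> int:
    assert batch_size % micro_batch_size == 0, "micro batch must divide batch"
    return batch_size // micro_batch_size

# 16 / 4 -> 4: four forward/backward passes of 4 samples are accumulated
# before each optimizer step, giving an effective batch of 16.
print(gradient_accumulation_steps(16, 4))  # 4
```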
    	
src/axolotl/utils/models.py CHANGED

```diff
@@ -211,12 +211,12 @@ def load_model(
         try:
             if is_llama_derived_model and "LlamaTokenizer" in globals():
                 tokenizer = LlamaTokenizer.from_pretrained(
-
+                    base_model_config,
                     trust_remote_code=True if cfg.trust_remote_code is True else False,
                 )
             else:
                 tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
-
+                    base_model_config,
                     trust_remote_code=True if cfg.trust_remote_code is True else False,
                 )
         except:
```
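In isolation, the fixed path looks like the sketch below: both branches now pass `base_model_config` as the positional `pretrained_model_name_or_path`, so the tokenizer resolves from the same Hugging Face repo as the model config. A minimal sketch assuming the names and values from the diff and config above; the plain `trust_remote_code` boolean stands in for the `cfg` plumbing inside `load_model()`.

```python
# Standalone sketch of the fixed tokenizer loading. Values mirror the
# OpenLLaMA 3B config above; the boolean replaces cfg.trust_remote_code.
import transformers
from transformers import LlamaTokenizer

base_model_config = "openlm-research/open_llama_3b_600bt_preview"
tokenizer_type = "LlamaTokenizer"
is_llama_derived_model = True
trust_remote_code = False

if is_llama_derived_model and "LlamaTokenizer" in globals():
    # Llama-derived models get the dedicated LlamaTokenizer class,
    # loaded from the model-config repo as in the fix.
    tokenizer = LlamaTokenizer.from_pretrained(
        base_model_config,
        trust_remote_code=trust_remote_code,
    )
else:
    # Otherwise resolve the configured tokenizer class from transformers.
    tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
        base_model_config,
        trust_remote_code=trust_remote_code,
    )

print(type(tokenizer).__name__)  # LlamaTokenizer
```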