import logging


def _check_qlora(cfg):
    """Raise ``ValueError`` when a qlora adapter is paired with unsupported load flags.

    Flags are inspected in a fixed order (8-bit, gptq, 4-bit), so when several
    conflicts are present the first one in that order is the one reported.
    Whether the adapter is being merged (``cfg.merge_lora``) selects the wording
    of the 8-bit/gptq errors and flips the 4-bit requirement.
    """
    merging = bool(cfg.merge_lora)

    # 8-bit loading and gptq are rejected in both modes; only the message differs.
    if cfg.load_in_8bit:
        if merging:
            raise ValueError("Can't merge qlora if loaded in 8bit")
        raise ValueError("Can't load qlora in 8bit")
    if cfg.gptq:
        if merging:
            raise ValueError("Can't merge qlora if gptq")
        raise ValueError("Can't load qlora if gptq")

    # 4-bit loading is forbidden when merging and mandatory otherwise.
    if merging:
        if cfg.load_in_4bit:
            raise ValueError("Can't merge qlora if loaded in 4bit")
    elif not cfg.load_in_4bit:
        raise ValueError("Require cfg.load_in_4bit to be True for qlora")


def validate_config(cfg):
    """Check *cfg* for deprecated or unsupported option combinations.

    Hard conflicts raise ``ValueError``; softer concerns are reported through
    ``logging.warning`` and validation carries on. Nothing is returned.

    Args:
        cfg: configuration object whose options are read as attributes. Options
            that are not set are assumed to evaluate falsy (e.g. ``None``) --
            TODO confirm against the config class.

    Raises:
        ValueError: if the deprecated ``load_4bit`` option is set, if a
            ``qlora`` adapter is combined with incompatible load flags, or if
            ``push_dataset_to_hub`` is set while ``hf_use_auth_token`` is not
            exactly ``True``.
    """
    if cfg.load_4bit:
        raise ValueError(
            "cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
        )

    if cfg.adapter == "qlora":
        _check_qlora(cfg)

    # Advisory only: LoRA without 8-bit loading is allowed, just warned about.
    lora_without_8bit = not cfg.load_in_8bit and cfg.adapter == "lora"
    if lora_without_8bit:
        logging.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")

    if cfg.trust_remote_code:
        logging.warning(
            "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
        )

    # Identity check: only the literal ``True`` satisfies it; other truthy
    # values for hf_use_auth_token still raise.
    if cfg.push_dataset_to_hub and cfg.hf_use_auth_token is not True:
        raise ValueError(
            "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
        )

    # TODO: MPT 7b -- carried-over note reads "no 8bit adamw w bf16";
    # see https://github.com/facebookresearch/bitsandbytes/issues/25