Pretraining error #18
by saivineetha - opened
I was trying to pre-train the model on a text file, but the training loss drops to zero after 15 epochs. Can you help me with this? Here is the relevant part of my training script:
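```python
# Build a BitsAndBytesConfig when 4-/8-bit loading is requested; otherwise fall back
# to full-precision loading with no quantization.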
if script_args.load_in_kbits in [4, 8]:
    load_in_4bit = script_args.load_in_kbits == 4
    load_in_8bit = script_args.load_in_kbits == 8
    if script_args.modules_to_save is not None:
        load_in_8bit_skip_modules = script_args.modules_to_save.split(",")
    else:
        load_in_8bit_skip_modules = None
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=script_args.load_in_kbits == 4,
        load_in_8bit=script_args.load_in_kbits == 8,
        llm_int8_threshold=6.0,
        load_in_8bit_skip_modules=load_in_8bit_skip_modules,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=script_args.double_quant,
        bnb_4bit_quant_type=script_args.quant_type,  # {'fp4', 'nf4'}
    )
else:
    load_in_4bit = False
    load_in_8bit = False
    quantization_config = None
if quantization_config is not None:
    logger.info(f"quantization_config:{quantization_config.to_dict()}")
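# Load the base model: resolve torch_dtype, pin this process to its LOCAL_RANK GPU via
# device_map, and pass the quantization settings into from_pretrained. If no
# model_name_or_path is given, a fresh model is initialized from `config` and trained
# from scratch.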
if script_args.model_name_or_path:
    torch_dtype = (
        script_args.torch_dtype
        if script_args.torch_dtype in ["auto", None]
        else getattr(torch, script_args.torch_dtype)
    )
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        from_tf=bool(".ckpt" in script_args.model_name_or_path),
        config=config,
        # cache_dir=script_args.cache_dir,
        # revision=model_args.model_revision,
        use_auth_token=True if script_args.use_auth_token else None,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        device_map=device_map,
        load_in_4bit=load_in_4bit,
        load_in_8bit=load_in_8bit,
        quantization_config=quantization_config,
    )
else:
    model = AutoModelForCausalLM.from_config(config)
    n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
    logger.info(
        f"Training new model from scratch - Total size={n_params/2**20:.2f}M params"
    )
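# For k-bit training, prepare_model_for_kbit_training typically freezes the base
# weights, upcasts norm/embedding layers, and wires up gradient checkpointing when
# requested; use_cache is then disabled because KV caching is incompatible with
# gradient checkpointing during training.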
if script_args.load_in_kbits in [4, 8]:
    model = prepare_model_for_kbit_training(
        model, use_gradient_checkpointing=script_args.gradient_checkpointing
    )
model.config.use_cache = False
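# Keep the embedding matrix in sync with the tokenizer: if extra tokens were added to
# the tokenizer, resize the model's token embeddings to match.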
model_vocab_size = model.get_output_embeddings().weight.size(0)
tokenizer_vocab_size = len(tokenizer)
logger.info(f"Model vocab size: {model_vocab_size}")
logger.info(f"Tokenizer vocab size: {tokenizer_vocab_size}")
if model_vocab_size != tokenizer_vocab_size:
    logger.info(f"Resize model vocab size to {tokenizer_vocab_size}")
    model.resize_token_embeddings(len(tokenizer))
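# Attach LoRA adapters: either resume from an existing PEFT checkpoint (peft_path) or
# build a new LoraConfig from the script arguments and wrap the model with
# get_peft_model.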
if script_args.peft_path is not None:
    logger.info("Peft from pre-trained model")
    model = PeftModel.from_pretrained(
        model, script_args.peft_path, device_map=device_map
    )
else:
    logger.info("Init new peft model")
    target_modules = script_args.trainable.split(",")
    modules_to_save = script_args.modules_to_save
    if modules_to_save is not None:
        modules_to_save = modules_to_save.split(",")
    lora_rank = script_args.lora_rank
    lora_dropout = script_args.lora_dropout
    lora_alpha = script_args.lora_alpha
    logger.info(f"target_modules: {target_modules}")
    logger.info(f"lora_rank: {lora_rank}")
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules=target_modules,
        inference_mode=False,
        r=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        modules_to_save=modules_to_save,
    )
    model = get_peft_model(model, peft_config)
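# Cast LoRA layers, normalization layers, and the lm_head/embed_tokens weights to
# float16 (the commented-out lines show the bf16/fp16 conditional variants).
# nn.Module.to() converts parameters in place, so reassigning `module` is not required
# for the cast to take effect.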
for name, module in model.named_modules():
    if isinstance(module, LoraLayer):
        module = module.to(torch.float16)
        # if script_args.bf16:
        #     module = module.to(torch.bfloat16)
        # if script_args.fp16:
        #     module = module.to(torch.float16)
    if "norm" in name:
        module = module.to(torch.float16)
    if "lm_head" in name or "embed_tokens" in name:
        if hasattr(module, "weight"):
            module = module.to(torch.float16)
            # if script_args.bf16 and module.weight.dtype == torch.float32:
            #     module = module.to(torch.bfloat16)
            # if script_args.fp16 and module.weight.dtype == torch.float32:
            #     module = module.to(torch.float16)
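# Log the trainable parameters and patch model.state_dict so that checkpoints saved by
# the Trainer contain only the PEFT (LoRA) weights returned by
# get_peft_model_state_dict, not the full base model.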
model.print_trainable_parameters()
logger.info(f"model.modules_to_save: {model.modules_to_save}")
old_state_dict = model.state_dict
model.state_dict = (
    lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
).__get__(model, type(model))
```