winglian commited on
Commit
931e606
2 Parent(s): 6b50200 7f09106

Merge pull request #179 from OpenAccess-AI-Collective/fix-max_seq_len

Browse files
Files changed (1) hide show
  1. src/axolotl/utils/models.py +8 -1
src/axolotl/utils/models.py CHANGED
@@ -255,8 +255,15 @@ def load_model(
255
  )
256
  # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
257
  # when training starts
258
- if config.max_seq_len and cfg.sequence_len > config.max_seq_len:
259
  config.max_seq_len = cfg.sequence_len
 
 
 
 
 
 
 
260
  model = AutoModelForCausalLM.from_pretrained(
261
  base_model,
262
  config=config,
 
255
  )
256
  # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
257
  # when training starts
258
+ if hasattr(config, "max_seq_len") and cfg.sequence_len > config.max_seq_len:
259
  config.max_seq_len = cfg.sequence_len
260
+ logging.warning(f"increasing context length to {cfg.sequence_len}")
261
+ elif (
262
+ hasattr(config, "max_sequence_length")
263
+ and cfg.sequence_len > config.max_sequence_length
264
+ ):
265
+ config.max_sequence_length = cfg.sequence_len
266
+ logging.warning(f"increasing context length to {cfg.sequence_len}")
267
  model = AutoModelForCausalLM.from_pretrained(
268
  base_model,
269
  config=config,