Merge pull request #179 from OpenAccess-AI-Collective/fix-max_seq_len
Browse files
src/axolotl/utils/models.py
CHANGED
@@ -255,8 +255,15 @@ def load_model(
         )
         # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
         # when training starts
-        if config
+        if hasattr(config, "max_seq_len") and cfg.sequence_len > config.max_seq_len:
             config.max_seq_len = cfg.sequence_len
+            logging.warning(f"increasing context length to {cfg.sequence_len}")
+        elif (
+            hasattr(config, "max_sequence_length")
+            and cfg.sequence_len > config.max_sequence_length
+        ):
+            config.max_sequence_length = cfg.sequence_len
+            logging.warning(f"increasing context length to {cfg.sequence_len}")
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             config=config,