# load_basis_model.py
# Load and initialize the base MOMENT model before finetuning

import torch
import logging

from momentfm import MOMENTPipeline
from transformer_model.scripts.config_transformer import (
    FORECAST_HORIZON,
    FREEZE_ENCODER,
    FREEZE_EMBEDDER,
    FREEZE_HEAD,
    WEIGHT_DECAY,
    HEAD_DROPOUT,
    SEQ_LEN
)

# Setup logging
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

def load_moment_model():
    """
    Loads and configures the MOMENT model for forecasting.
    """
    logging.info("Loading MOMENT model...")

    model = MOMENTPipeline.from_pretrained(
        "AutonLab/MOMENT-1-large",
        model_kwargs={
            'task_name': 'forecasting',
            'forecast_horizon': FORECAST_HORIZON,  # default = 1
            'head_dropout': HEAD_DROPOUT,          # default = 0.1
            'weight_decay': WEIGHT_DECAY,          # default = 0.0
            'freeze_encoder': FREEZE_ENCODER,      # default = True
            'freeze_embedder': FREEZE_EMBEDDER,    # default = True
            'freeze_head': FREEZE_HEAD             # default = False
        }
    )
    model.init()
    logging.info("Model initialized successfully.")
    return model

def print_trainable_params(model):
    """
    Logs all trainable (unfrozen) parameters of the model.
    """
    logging.info("Unfrozen parameters:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            logging.info(f"  {name}")

def test_dummy_forward(model):
    """
    Performs a dummy forward pass to verify the model runs without error.
    """
    logging.info("Running dummy forward pass with random input to verify the model executes.")
    dummy_x = torch.randn(16, 1, SEQ_LEN)  # (batch_size, n_channels, seq_len)
    output = model(x_enc=dummy_x)
    logging.info("Dummy forward pass successful.")
if __name__ == "__main__":
    model = load_moment_model()
    print_trainable_params(model)
    test_dummy_forward(model)
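
# Illustrative usage from a separate finetuning script (the module path below is
# assumed from the config import above, not confirmed by this repository):
#
#   from transformer_model.scripts.load_basis_model import load_moment_model, build_finetuning_optimizer
#
#   model = load_moment_model()
#   optimizer = build_finetuning_optimizer(model)
#   # ...finetuning loop over the prepared dataloaders goes here...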