"""Train and evaluate LitEfficientNet on MNIST with PyTorch Lightning (CPU-friendly), then upload checkpoints to S3."""
import os

import lightning as pl
from lightning.pytorch.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
    ModelSummary,
    RichProgressBar,
)
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from loguru import logger

from src.dataloader import MNISTDataModule
from src.model import LitEfficientNet
from src.utils.aws_s3_services import S3Handler

# Ensure the logs directory exists before attaching the file sink below.
os.makedirs("logs", exist_ok=True)

# Configure Loguru: rotate the training log at 1 MB, record INFO and above.
logger.add("logs/training.log", rotation="1 MB", level="INFO")
def main():
    """
    Train, test, and publish a LitEfficientNet classifier on MNIST (CPU-friendly).

    Steps:
        1. Build the MNIST data module and the model.
        2. Configure callbacks (checkpointing, early stopping, LR monitoring,
           rich progress bar, model summary) and CSV/TensorBoard loggers.
        3. Fit and test the model with a Lightning Trainer.
        4. Write a completion flag and upload the checkpoints folder to S3.
    """
    # Data Module
    logger.info("Setting up data module...")
    data_module = MNISTDataModule(batch_size=256)

    # Model
    logger.info("Setting up model...")
    model = LitEfficientNet(model_name="tf_efficientnet_lite0", num_classes=10, lr=1e-3)
    logger.info(model)

    # Callbacks
    logger.info("Setting up callbacks...")
    checkpoint_callback = ModelCheckpoint(
        monitor="val_acc",
        dirpath="checkpoints/",
        filename="best_model",
        save_top_k=1,
        mode="max",  # val_acc: higher is better
        auto_insert_metric_name=False,
        verbose=True,
        save_last=True,
        enable_version_counter=False,  # keep a single, stable best_model.ckpt name
    )
    early_stopping_callback = EarlyStopping(
        monitor="val_acc",
        patience=5,  # Extended patience for advanced models
        mode="max",
        verbose=True,
    )
    lr_monitor = LearningRateMonitor(logging_interval="epoch")  # Log learning rate
    rich_progress = RichProgressBar()
    model_summary = ModelSummary(
        max_depth=1
    )  # Show only the first level of model layers

    # Loggers
    logger.info("Setting up loggers...")
    csv_logger = CSVLogger("logs/", name="mnist_csv")
    tb_logger = TensorBoardLogger("logs/", name="mnist_tb")

    # Trainer Configuration (accelerator/devices auto-selected; works on CPU)
    logger.info("Setting up trainer...")
    trainer = pl.Trainer(
        max_epochs=2,
        callbacks=[
            checkpoint_callback,
            early_stopping_callback,
            lr_monitor,
            rich_progress,
            model_summary,
        ],
        logger=[csv_logger, tb_logger],
        deterministic=True,
        accelerator="auto",
        devices="auto",
    )

    # Train the model
    logger.info("Training the model...")
    trainer.fit(model, datamodule=data_module)

    # Test the model. Trainer.test() with `datamodule=` invokes
    # data_module.setup(stage="test") itself, so no manual setup call is needed.
    logger.info("Testing the model...")
    trainer.test(model, datamodule=data_module)

    # Write checkpoints/train_done.flag so downstream jobs can detect completion.
    # Ensure the directory exists even if no checkpoint was ever saved.
    os.makedirs("checkpoints", exist_ok=True)
    with open("checkpoints/train_done.flag", "w") as f:
        f.write("Training done.")

    # Upload checkpoints to S3
    s3_handler = S3Handler(bucket_name="deep-bucket-s3")
    s3_handler.upload_folder(
        "checkpoints",
        "checkpoints_test",
    )
# Script entry point: run the full train/test/upload pipeline.
if __name__ == "__main__":
    main()