# VenusFactory: src/utils/logger.py
import argparse
import logging
import os
from typing import Optional, Tuple, Union

import torch.nn as nn
from transformers import PreTrainedModel


def setup_logging(args: argparse.Namespace) -> logging.Logger:
    """Set up logging to both the console and a file in args.output_dir."""
    # Create logger
    logger = logging.getLogger('training')
    logger.setLevel(logging.INFO)

    # Create a shared formatter for both handlers
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Console handler
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # File handler; create the output directory first so FileHandler
    # does not fail when it is missing
    os.makedirs(args.output_dir, exist_ok=True)
    log_file = os.path.join(
        args.output_dir,
        f'{args.output_model_name.split(".")[0]}_training.log'
    )
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Log the full run configuration once at startup
    logger.info("Starting training with configuration:")
    for key, value in vars(args).items():
        logger.info(f"{key}: {value}")

    return logger
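
# A minimal usage sketch, assuming args carries the two attributes this
# module reads (output_dir, output_model_name); the values are hypothetical:
#
#   from argparse import Namespace
#   args = Namespace(output_dir="ckpt", output_model_name="adapter.pt", lr=1e-4)
#   logger = setup_logging(args)   # also writes ckpt/adapter_training.log
#   logger.info("ready")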


def count_parameters(model: Union[nn.Module, PreTrainedModel]) -> Tuple[int, int]:
    """
    Count total and trainable parameters in a model.

    Args:
        model: PyTorch model or Hugging Face model

    Returns:
        Tuple of (total_params, trainable_params)
    """
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params
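
# A minimal usage sketch (the model below is hypothetical; any nn.Module works):
#
#   tiny = nn.Linear(128, 64)                    # 128 * 64 + 64 = 8,256 params
#   total, trainable = count_parameters(tiny)    # (8256, 8256)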


def format_parameter_count(count: int) -> str:
    """Format parameter count with appropriate unit."""
    if count >= 1e9:
        return f"{count/1e9:.2f}B"
    elif count >= 1e6:
        return f"{count/1e6:.2f}M"
    elif count >= 1e3:
        return f"{count/1e3:.2f}K"
    return str(count)
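
# Expected behavior (counts are illustrative):
#
#   format_parameter_count(8_256)          -> "8.26K"
#   format_parameter_count(650_000_000)    -> "650.00M"
#   format_parameter_count(3_000_000_000)  -> "3.00B"
#   format_parameter_count(512)            -> "512"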


def print_model_parameters(
    model: nn.Module,
    plm_model: PreTrainedModel,
    logger: Optional[logging.Logger] = None,
):
    """
    Print parameter statistics for both the adapter and the PLM.

    Args:
        model: Adapter model
        plm_model: Pre-trained language model
        logger: Optional logger for output; falls back to print()
    """
    # Count adapter and PLM parameters separately
    adapter_total, adapter_trainable = count_parameters(model)
    plm_total, plm_trainable = count_parameters(plm_model)

    combined_total = adapter_total + plm_total
    combined_trainable = adapter_trainable + plm_trainable
    # Guard against division by zero for empty models
    trainable_pct = (combined_trainable / combined_total * 100) if combined_total else 0.0

    # Prepare output lines (plain strings where no placeholders are needed)
    output = [
        "------------------------",
        "Model Parameters Statistics:",
        "------------------------",
        "Adapter Model:",
        f"  Total parameters: {format_parameter_count(adapter_total)}",
        f"  Trainable parameters: {format_parameter_count(adapter_trainable)}",
        "Pre-trained Model:",
        f"  Total parameters: {format_parameter_count(plm_total)}",
        f"  Trainable parameters: {format_parameter_count(plm_trainable)}",
        "Combined:",
        f"  Total parameters: {format_parameter_count(combined_total)}",
        f"  Trainable parameters: {format_parameter_count(combined_trainable)}",
        f"  Trainable percentage: {trainable_pct:.2f}%",
        "------------------------",
    ]

    # Emit through the logger when available, otherwise to stdout
    if logger:
        for line in output:
            logger.info(line)
    else:
        for line in output:
            print(line)
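
# A minimal usage sketch (the checkpoint name is an assumption; any
# nn.Module / PreTrainedModel pair works):
#
#   from transformers import AutoModel
#   plm = AutoModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
#   for p in plm.parameters():
#       p.requires_grad = False                  # freeze the backbone
#   adapter = nn.Linear(plm.config.hidden_size, 2)
#   print_model_parameters(adapter, plm)         # PLM trainable params: 0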