"""Training utilities: logging setup and model parameter statistics."""
import argparse
import logging
import os
from typing import Any, Dict, Tuple, Union

import torch.nn as nn
from transformers import PreTrainedModel
def setup_logging(args: argparse.Namespace) -> logging.Logger:
    """Set up the shared 'training' logger with console and file handlers.

    Args:
        args: Parsed command-line arguments; must provide ``output_dir``
            and ``output_model_name`` attributes. (The previous
            ``Dict[str, Any]`` annotation was wrong — the code uses
            attribute access and ``vars(args)``.)

    Returns:
        The configured ``logging.Logger``. Safe to call more than once:
        handlers from a previous call are removed first, so log lines are
        not duplicated.
    """
    logger = logging.getLogger('training')
    logger.setLevel(logging.INFO)

    # Drop handlers left over from an earlier call — logging.getLogger
    # returns the same object every time, and stacking handlers would
    # emit every record once per previous setup.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    # Create formatters and handlers
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Console handler
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # File handler — make sure the target directory exists before
    # FileHandler tries to open the file.
    os.makedirs(args.output_dir, exist_ok=True)
    log_file = os.path.join(args.output_dir, f'{args.output_model_name.split(".")[0]}_training.log')
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Log initial info
    logger.info("Starting training with configuration:")
    for key, value in vars(args).items():
        logger.info(f"{key}: {value}")
    return logger
def count_parameters(model: nn.Module) -> Tuple[int, int]:
    """
    Count total and trainable parameters in a model.

    Any ``nn.Module`` works, including Hugging Face ``PreTrainedModel``
    instances (they subclass ``nn.Module``), so the former
    ``Union[nn.Module, PreTrainedModel]`` annotation was redundant.

    Args:
        model: PyTorch model or Hugging Face model

    Returns:
        Tuple of (total_params, trainable_params)
    """
    total_params = 0
    trainable_params = 0
    # Single pass instead of iterating model.parameters() twice.
    for param in model.parameters():
        n = param.numel()
        total_params += n
        if param.requires_grad:
            trainable_params += n
    return total_params, trainable_params
def format_parameter_count(count: int) -> str:
    """Format parameter count with appropriate unit (B/M/K or raw)."""
    # Largest-first threshold table replaces the if/elif ladder.
    scales = ((1e9, "B"), (1e6, "M"), (1e3, "K"))
    for threshold, suffix in scales:
        if count >= threshold:
            return f"{count / threshold:.2f}{suffix}"
    return str(count)
def print_model_parameters(model: nn.Module, plm_model: nn.Module, logger=None):
    """
    Print parameter statistics for both adapter and PLM models.

    Args:
        model: Adapter model
        plm_model: Pre-trained language model (any ``nn.Module``; only
            ``.parameters()`` is used, so the annotation is widened from
            ``PreTrainedModel`` — backward compatible)
        logger: Optional logger for output; falls back to ``print``
    """
    # Count adapter parameters
    adapter_total, adapter_trainable = count_parameters(model)
    # Count PLM parameters
    plm_total, plm_trainable = count_parameters(plm_model)

    combined_total = adapter_total + plm_total
    combined_trainable = adapter_trainable + plm_trainable
    # Guard against ZeroDivisionError for parameter-less models; report
    # 0.00% in that degenerate case.
    trainable_pct = (combined_trainable / combined_total * 100) if combined_total else 0.0

    # Prepare output strings (no f-prefix on placeholder-free literals)
    output = [
        "------------------------",
        "Model Parameters Statistics:",
        "------------------------",
        "Adapter Model:",
        f"  Total parameters: {format_parameter_count(adapter_total)}",
        f"  Trainable parameters: {format_parameter_count(adapter_trainable)}",
        "Pre-trained Model:",
        f"  Total parameters: {format_parameter_count(plm_total)}",
        f"  Trainable parameters: {format_parameter_count(plm_trainable)}",
        "Combined:",
        f"  Total parameters: {format_parameter_count(combined_total)}",
        f"  Trainable parameters: {format_parameter_count(combined_trainable)}",
        f"  Trainable percentage: {trainable_pct:.2f}%",
        "------------------------"
    ]

    # Emit via the logger when given, otherwise to stdout.
    emit = logger.info if logger else print
    for line in output:
        emit(line)