import argparse
import logging
import os
from typing import Tuple, Union

import torch.nn as nn
from transformers import PreTrainedModel

def setup_logging(args: argparse.Namespace) -> logging.Logger:
    """Set up logging configuration."""
    # Create logger
    logger = logging.getLogger('training')
    logger.setLevel(logging.INFO)

    # Create a formatter shared by all handlers
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Console handler
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # File handler; create the output directory first so FileHandler can open the file
    os.makedirs(args.output_dir, exist_ok=True)
    log_file = os.path.join(args.output_dir, f'{args.output_model_name.split(".")[0]}_training.log')
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Log the run configuration
    logger.info("Starting training with configuration:")
    for key, value in vars(args).items():
        logger.info(f"{key}: {value}")

    return logger
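
# Example usage (a sketch, not prescribed by this module): setup_logging expects an
# argparse-style namespace with at least `output_dir` and `output_model_name`
# attributes; the flag names and defaults below are illustrative assumptions.
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--output_dir', default='ckpt')
#     parser.add_argument('--output_model_name', default='adapter.pt')
#     logger = setup_logging(parser.parse_args([]))
#     # writes to ckpt/adapter_training.log and echoes to the console
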
def count_parameters(model: Union[nn.Module, PreTrainedModel]) -> Tuple[int, int]:
    """
    Count total and trainable parameters in a model.

    Args:
        model: PyTorch model or Hugging Face model

    Returns:
        Tuple of (total_params, trainable_params)
    """
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params
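
# Example (sketch): with a frozen backbone only the adapter's weights remain
# trainable, so count_parameters reports zero trainable parameters for it.
#
#     for p in plm_model.parameters():
#         p.requires_grad = False
#     total, trainable = count_parameters(plm_model)
#     assert trainable == 0
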
def format_parameter_count(count: int) -> str:
    """Format a parameter count with an appropriate unit suffix."""
    if count >= 1e9:
        return f"{count/1e9:.2f}B"
    elif count >= 1e6:
        return f"{count/1e6:.2f}M"
    elif count >= 1e3:
        return f"{count/1e3:.2f}K"
    return str(count)
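
# Example values (sketch):
#     format_parameter_count(6_738_415_616)  # -> '6.74B'
#     format_parameter_count(35_000_000)     # -> '35.00M'
#     format_parameter_count(512)            # -> '512'
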
def print_model_parameters(model: nn.Module, plm_model: PreTrainedModel, logger=None):
    """
    Print parameter statistics for both adapter and PLM models.

    Args:
        model: Adapter model
        plm_model: Pre-trained language model
        logger: Optional logger for output
    """
    # Count adapter parameters
    adapter_total, adapter_trainable = count_parameters(model)

    # Count PLM parameters
    plm_total, plm_trainable = count_parameters(plm_model)

    # Prepare output lines
    output = [
        "------------------------",
        "Model Parameters Statistics:",
        "------------------------",
        "Adapter Model:",
        f"  Total parameters: {format_parameter_count(adapter_total)}",
        f"  Trainable parameters: {format_parameter_count(adapter_trainable)}",
        "Pre-trained Model:",
        f"  Total parameters: {format_parameter_count(plm_total)}",
        f"  Trainable parameters: {format_parameter_count(plm_trainable)}",
        "Combined:",
        f"  Total parameters: {format_parameter_count(adapter_total + plm_total)}",
        f"  Trainable parameters: {format_parameter_count(adapter_trainable + plm_trainable)}",
        f"  Trainable percentage: {((adapter_trainable + plm_trainable) / (adapter_total + plm_total)) * 100:.2f}%",
        "------------------------"
    ]

    # Emit via the logger if one was provided, otherwise print to stdout
    if logger:
        for line in output:
            logger.info(line)
    else:
        for line in output:
            print(line)
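
# End-to-end sketch, assuming a Hugging Face backbone plus a small adapter head;
# the checkpoint name and head shape are illustrative assumptions.
#
#     from transformers import AutoModel
#     plm = AutoModel.from_pretrained('bert-base-uncased')
#     for p in plm.parameters():
#         p.requires_grad = False
#     adapter = nn.Linear(plm.config.hidden_size, 2)
#     print_model_parameters(adapter, plm)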