""" | |
SmolLM3 Model Wrapper | |
Handles model loading, tokenizer, and training setup | |
""" | |
import os | |
import torch | |
import torch.nn as nn | |
from transformers import ( | |
AutoModelForCausalLM, | |
AutoTokenizer, | |
AutoConfig, | |
TrainingArguments, | |
Trainer | |
) | |
from typing import Optional, Dict, Any | |
import logging | |
logger = logging.getLogger(__name__) | |


class SmolLM3Model:
    """Wrapper for SmolLM3 model and tokenizer"""

    def __init__(
        self,
        model_name: str = "HuggingFaceTB/SmolLM3-3B",
        max_seq_length: int = 4096,
        config: Optional[Any] = None,
        device_map: Optional[str] = None,
        torch_dtype: Optional[torch.dtype] = None
    ):
        self.model_name = model_name
        self.max_seq_length = max_seq_length
        self.config = config

        # Set device and dtype
        if torch_dtype is None:
            if torch.cuda.is_available():
                # Check if config specifies mixed precision
                if config and hasattr(config, 'fp16') and config.fp16:
                    # Use fp16 if explicitly configured
                    self.torch_dtype = torch.float16
                elif config and hasattr(config, 'bf16') and config.bf16:
                    # Use bf16 if explicitly configured
                    self.torch_dtype = torch.bfloat16
                else:
                    # Default to bf16 for better compatibility
                    self.torch_dtype = torch.bfloat16
            else:
                self.torch_dtype = torch.float32
        else:
            self.torch_dtype = torch_dtype

        if device_map is None:
            self.device_map = "auto" if torch.cuda.is_available() else "cpu"
        else:
            self.device_map = device_map

        # Load tokenizer and model
        self._load_tokenizer()
        self._load_model()

    def _load_tokenizer(self):
        """Load the tokenizer"""
        logger.info("Loading tokenizer from %s", self.model_name)
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                use_fast=True
            )

            # Set pad token if not present
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            logger.info("Tokenizer loaded successfully. Vocab size: %d", self.tokenizer.vocab_size)
        except Exception as e:
            logger.error("Failed to load tokenizer: %s", e)
            raise

    def _load_model(self):
        """Load the model"""
        logger.info("Loading model from %s", self.model_name)
        try:
            # Load model configuration
            model_config = AutoConfig.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )

            # Update configuration if needed
            if hasattr(model_config, 'max_position_embeddings'):
                model_config.max_position_embeddings = self.max_seq_length

            # SmolLM3-specific optimizations for long context
            if hasattr(model_config, 'rope_scaling'):
                # Enable YaRN scaling for long context
                model_config.rope_scaling = {"type": "yarn", "factor": 2.0}
                logger.info("Enabled YaRN scaling for long context")

            # Load model
            model_kwargs = {
                "torch_dtype": self.torch_dtype,
                "device_map": self.device_map,
                "trust_remote_code": True
            }

            # Only add flash attention if the model supports it
            if hasattr(self.config, 'use_flash_attention') and self.config.use_flash_attention:
                try:
                    # Check whether the model config exposes flash attention support
                    test_config = AutoConfig.from_pretrained(self.model_name, trust_remote_code=True)
                    if hasattr(test_config, 'use_flash_attention_2'):
                        model_kwargs["use_flash_attention_2"] = True
                        logger.info("Enabled Flash Attention 2 for better long context performance")
                except Exception:
                    # If flash attention is not supported, skip it
                    pass

            # Try to load the model, fallback to fp16 if bf16 fails
            try:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    config=model_config,
                    **model_kwargs
                )
            except RuntimeError as e:
                if "bfloat16" in str(e) or "BFloat16" in str(e):
                    logger.warning("BFloat16 not supported, falling back to Float16")
                    model_kwargs["torch_dtype"] = torch.float16
                    self.torch_dtype = torch.float16
                    self.model = AutoModelForCausalLM.from_pretrained(
                        self.model_name,
                        config=model_config,
                        **model_kwargs
                    )
                else:
                    raise

            # Enable gradient checkpointing if specified
            if self.config and getattr(self.config, 'use_gradient_checkpointing', False):
                self.model.gradient_checkpointing_enable()

            logger.info("Model loaded successfully. Parameters: %s", f"{self.model.num_parameters():,}")
            logger.info("Max sequence length: %d", self.max_seq_length)
        except Exception as e:
            logger.error("Failed to load model: %s", e)
            raise

    def get_training_arguments(self, output_dir: str, **kwargs) -> TrainingArguments:
        """Get training arguments for the Trainer"""
        if self.config is None:
            raise ValueError("Config is required to get training arguments")

        # Merge config with kwargs
        training_args = {
            "output_dir": output_dir,
            "per_device_train_batch_size": self.config.batch_size,
            "per_device_eval_batch_size": self.config.batch_size,
            "gradient_accumulation_steps": self.config.gradient_accumulation_steps,
            "learning_rate": self.config.learning_rate,
            "weight_decay": self.config.weight_decay,
            "warmup_steps": self.config.warmup_steps,
            "max_steps": self.config.max_iters,
            "save_steps": self.config.save_steps,
            "eval_steps": self.config.eval_steps,
            "logging_steps": self.config.logging_steps,
            "save_total_limit": self.config.save_total_limit,
            "eval_strategy": self.config.eval_strategy,
            "metric_for_best_model": self.config.metric_for_best_model,
            "greater_is_better": self.config.greater_is_better,
            "load_best_model_at_end": self.config.load_best_model_at_end,
            "fp16": self.config.fp16,
            "bf16": self.config.bf16,
            # Only enable DDP if multiple GPUs are available
            "ddp_backend": self.config.ddp_backend if torch.cuda.device_count() > 1 else None,
            "ddp_find_unused_parameters": self.config.ddp_find_unused_parameters if torch.cuda.device_count() > 1 else False,
"report_to": None, # Disable external logging - use None instead of "none" | |
"remove_unused_columns": False, | |
"dataloader_pin_memory": getattr(self.config, 'dataloader_pin_memory', False), | |
"group_by_length": getattr(self.config, 'group_by_length', True), | |
"length_column_name": "length", | |
"ignore_data_skip": False, | |
"seed": 42, | |
"data_seed": 42, | |
"dataloader_num_workers": getattr(self.config, 'dataloader_num_workers', 4), | |
"max_grad_norm": getattr(self.config, 'max_grad_norm', 1.0), | |
"optim": self.config.optimizer, | |
"lr_scheduler_type": self.config.scheduler, | |
"warmup_ratio": 0.1, | |
"save_strategy": "steps", | |
"logging_strategy": "steps", | |
"prediction_loss_only": True, | |
} | |

        # Ensure boolean parameters are properly typed
        bool_keys = (
            "dataloader_pin_memory", "group_by_length", "prediction_loss_only",
            "ignore_data_skip", "remove_unused_columns", "ddp_find_unused_parameters",
            "fp16", "bf16", "load_best_model_at_end", "greater_is_better",
        )
        for key in bool_keys:
            if key in training_args:
                training_args[key] = bool(training_args[key])

        # Add dataloader_prefetch_factor if it exists in config
        if hasattr(self.config, 'dataloader_prefetch_factor'):
            try:
                # Test if the parameter is supported by creating a dummy TrainingArguments
                TrainingArguments(output_dir="/tmp/test", dataloader_prefetch_factor=2)
                training_args["dataloader_prefetch_factor"] = self.config.dataloader_prefetch_factor
                logger.info("Added dataloader_prefetch_factor: %d", self.config.dataloader_prefetch_factor)
            except Exception as e:
                logger.warning("dataloader_prefetch_factor not supported in this transformers version: %s", e)
                # Remove the parameter if it's not supported
                training_args.pop("dataloader_prefetch_factor", None)

        # Override with kwargs
        training_args.update(kwargs)

        # Clean up any None values that might cause issues
        training_args = {k: v for k, v in training_args.items() if v is not None}

        return TrainingArguments(**training_args)

    def save_pretrained(self, path: str):
        """Save model and tokenizer"""
        logger.info("Saving model and tokenizer to %s", path)
        os.makedirs(path, exist_ok=True)
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)

        # Save configuration
        if self.config:
            import json
            config_dict = {k: v for k, v in self.config.__dict__.items()
                           if not k.startswith('_')}
            with open(os.path.join(path, 'training_config.json'), 'w') as f:
                json.dump(config_dict, f, indent=2, default=str)

    def load_checkpoint(self, checkpoint_path: str):
        """Load model from checkpoint"""
        logger.info("Loading checkpoint from %s", checkpoint_path)
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                checkpoint_path,
                torch_dtype=self.torch_dtype,
                device_map=self.device_map,
                trust_remote_code=True
            )
            logger.info("Checkpoint loaded successfully")
        except Exception as e:
            logger.error("Failed to load checkpoint: %s", e)
            raise
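

if __name__ == "__main__":
    # Minimal usage sketch, not part of the training pipeline itself: it shows how
    # SmolLM3Model is instantiated and how TrainingArguments are derived from a config.
    # It assumes a plain namespace exposing the attributes read above; every value
    # below is illustrative only and should be replaced with a real training config.
    from types import SimpleNamespace

    logging.basicConfig(level=logging.INFO)

    demo_config = SimpleNamespace(
        fp16=False,
        bf16=True,
        use_flash_attention=False,
        use_gradient_checkpointing=True,
        batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=2e-5,
        weight_decay=0.01,
        warmup_steps=100,
        max_iters=1000,
        save_steps=500,
        eval_steps=100,
        logging_steps=10,
        save_total_limit=3,
        eval_strategy="steps",
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        load_best_model_at_end=True,
        ddp_backend="nccl",
        ddp_find_unused_parameters=False,
        optimizer="adamw_torch",
        scheduler="cosine",
    )

    # Downloads and loads SmolLM3-3B, then builds TrainingArguments from the config.
    model = SmolLM3Model(
        model_name="HuggingFaceTB/SmolLM3-3B",
        max_seq_length=4096,
        config=demo_config,
    )
    training_arguments = model.get_training_arguments(output_dir="./outputs/smollm3-demo")
    print(training_arguments)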