# Project: thadillo — commit b08ba59
# "Fix: Use /data/models/finetuned for HF Spaces training"
"""
Model Manager for Fine-Tuned Model Deployment and Versioning
Handles loading, deploying, and rolling back fine-tuned models.
"""
import os
import json
import shutil
from typing import Optional, Dict
from datetime import datetime
import logging
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
logger = logging.getLogger(__name__)
class ModelManager:
    """Manage fine-tuned model deployment and versioning.

    Each training run's model lives in its own directory under
    ``models_dir`` (named ``run_<id>``). Which run is "active" is tracked
    via the ``FineTuningRun.is_active_model`` flag in the database; at most
    one run is active, and no active run means the base model is in use.
    """
    def __init__(self, models_dir: str = None):
        """
        Initialize ModelManager.
        Args:
            models_dir: Base directory for storing fine-tuned models
                        (defaults to MODELS_DIR env var or '/data/models/finetuned',
                        the writable persistent volume on HF Spaces)
        Raises:
            PermissionError: if the models directory cannot be created.
        """
        if models_dir is None:
            # Use environment variable or /data path for HF Spaces
            models_dir = os.getenv('MODELS_DIR', '/data/models/finetuned')
        self.models_dir = models_dir
        # Zero-shot NLI checkpoint used when no fine-tuned run is active.
        self.base_model_name = "facebook/bart-large-mnli"
        # Create directory if it doesn't exist
        try:
            os.makedirs(models_dir, exist_ok=True)
        except PermissionError:
            logger.error(f"Permission denied creating models directory: {models_dir}")
            raise

    def get_model_path(self, run_id: int) -> str:
        """Get path to model for a specific training run."""
        return os.path.join(self.models_dir, f"run_{run_id}")

    def load_model(self, run_id: Optional[int] = None):
        """
        Load a fine-tuned model or base model.
        Args:
            run_id: Training run ID (None for base model)
        Returns:
            Tuple of (model, tokenizer)
        Raises:
            FileNotFoundError: if run_id is given but no model dir exists for it.
        """
        if run_id is None:
            logger.info("Loading base model")
            model_name = self.base_model_name
        else:
            model_path = self.get_model_path(run_id)
            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model not found: {model_path}")
            logger.info(f"Loading fine-tuned model from run {run_id}")
            model_name = model_path
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # ignore_mismatched_sizes: the fine-tuned classification head may have
        # a different label count than the base checkpoint's head.
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            ignore_mismatched_sizes=True
        )
        return model, tokenizer

    def deploy_model(self, run_id: int, db_session) -> Dict:
        """
        Deploy a fine-tuned model (set as active).
        Args:
            run_id: Training run ID to deploy
            db_session: Database session for updating FineTuningRun
        Returns:
            Dict with deployment info (run_id, deployed_at, model_path)
        Raises:
            FileNotFoundError: if the model directory does not exist on disk.
            ValueError: if the run record is missing or not 'completed'.
        """
        from app.models.models import FineTuningRun
        logger.info(f"Deploying model from run {run_id}")
        # Verify model exists on disk before touching the DB
        model_path = self.get_model_path(run_id)
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model not found: {model_path}")
        # Get the run record
        run = db_session.query(FineTuningRun).filter_by(id=run_id).first()
        if not run:
            raise ValueError(f"Training run {run_id} not found")
        if run.status != 'completed':
            raise ValueError(f"Cannot deploy non-completed run (status: {run.status})")
        # Deactivate all other models so exactly one run is active at a time
        db_session.query(FineTuningRun).update({'is_active_model': False})
        # Activate this model
        run.is_active_model = True
        db_session.commit()
        logger.info(f"Model from run {run_id} is now active")
        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
        # kept as-is so the isoformat output (no tz offset) stays stable.
        return {
            'run_id': run_id,
            'deployed_at': datetime.utcnow().isoformat(),
            'model_path': model_path
        }

    def rollback_to_baseline(self, db_session) -> Dict:
        """
        Rollback to base model (deactivate all fine-tuned models).
        Args:
            db_session: Database session
        Returns:
            Dict with rollback info (rolled_back_at, deactivated_models,
            active_model='base')
        """
        from app.models.models import FineTuningRun
        logger.info("Rolling back to base model")
        # Count first (for reporting), then deactivate all fine-tuned models
        active_count = db_session.query(FineTuningRun).filter_by(is_active_model=True).count()
        db_session.query(FineTuningRun).update({'is_active_model': False})
        db_session.commit()
        logger.info(f"Deactivated {active_count} fine-tuned model(s)")
        return {
            'rolled_back_at': datetime.utcnow().isoformat(),
            'deactivated_models': active_count,
            'active_model': 'base'
        }

    def get_active_model_info(self, db_session) -> Optional[Dict]:
        """
        Get information about the currently active model.
        Args:
            db_session: Database session
        Returns:
            Dict with active model info, or None if base model is active
        """
        from app.models.models import FineTuningRun
        active_run = db_session.query(FineTuningRun).filter_by(is_active_model=True).first()
        if not active_run:
            return None
        return {
            'run_id': active_run.id,
            'model_path': self.get_model_path(active_run.id),
            'created_at': active_run.created_at.isoformat() if active_run.created_at else None,
            'results': active_run.get_results(),
            'config': active_run.get_config()
        }

    def export_model(self, run_id: int, export_path: str) -> str:
        """
        Export model for backup or sharing.

        Copies the run's model directory to ``export_path/model_run_<id>``
        and writes a model_card.json describing it.
        Args:
            run_id: Training run ID
            export_path: Destination path for export
        Returns:
            Path to exported model
        Raises:
            FileNotFoundError: if no model directory exists for run_id.
        """
        logger.info(f"Exporting model from run {run_id}")
        model_path = self.get_model_path(run_id)
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model not found: {model_path}")
        # Create export directory
        os.makedirs(export_path, exist_ok=True)
        # Copy all model files (overwrite a stale previous export if present)
        export_model_path = os.path.join(export_path, f"model_run_{run_id}")
        shutil.copytree(model_path, export_model_path, dirs_exist_ok=True)
        # Create model card
        model_card = {
            'run_id': run_id,
            'export_date': datetime.utcnow().isoformat(),
            'base_model': self.base_model_name,
            'model_type': 'BART with LoRA fine-tuning',
            'task': 'Multi-class text classification',
            'categories': ['Vision', 'Problem', 'Objectives', 'Directives', 'Values', 'Actions']
        }
        with open(os.path.join(export_model_path, 'model_card.json'), 'w') as f:
            json.dump(model_card, f, indent=2)
        logger.info(f"Model exported to {export_model_path}")
        return export_model_path

    def import_model(self, import_path: str, run_id: int) -> str:
        """
        Import a previously exported model.
        Args:
            import_path: Path to imported model directory
            run_id: Training run ID to assign
        Returns:
            Path to imported model in models directory
        Raises:
            FileNotFoundError: if import_path does not exist.
            ValueError: if import_path lacks any recognizable model file.
        """
        logger.info(f"Importing model to run {run_id}")
        if not os.path.exists(import_path):
            raise FileNotFoundError(f"Import path not found: {import_path}")
        # Verify it's a valid model directory. transformers >= 4.30 saves
        # weights as model.safetensors (not pytorch_model.bin), and PEFT/LoRA
        # adapters use adapter_model.* — accept any of these markers.
        required_files = [
            'config.json',
            'pytorch_model.bin',           # legacy transformers checkpoint
            'model.safetensors',           # modern transformers checkpoint
            'adapter_model.bin',           # legacy PEFT/LoRA adapter
            'adapter_model.safetensors',   # modern PEFT/LoRA adapter
        ]
        has_required = any(os.path.exists(os.path.join(import_path, f)) for f in required_files)
        if not has_required:
            raise ValueError(f"Import path does not contain a valid model: {import_path}")
        # Copy to models directory
        model_path = self.get_model_path(run_id)
        shutil.copytree(import_path, model_path, dirs_exist_ok=True)
        logger.info(f"Model imported to {model_path}")
        return model_path

    def delete_model(self, run_id: int) -> None:
        """
        Delete a fine-tuned model from disk.

        Missing models are logged as a warning, not an error.
        Args:
            run_id: Training run ID
        """
        logger.info(f"Deleting model from run {run_id}")
        model_path = self.get_model_path(run_id)
        if os.path.exists(model_path):
            shutil.rmtree(model_path)
            logger.info(f"Model deleted: {model_path}")
        else:
            logger.warning(f"Model not found: {model_path}")

    def get_model_size(self, run_id: int) -> Dict:
        """
        Get size information for a model.
        Args:
            run_id: Training run ID
        Returns:
            Dict with size info: {'exists': False} when no model dir, else
            exists/total_size_bytes/total_size_mb/file_count/path.
        """
        model_path = self.get_model_path(run_id)
        if not os.path.exists(model_path):
            return {'exists': False}
        # Calculate directory size by walking every file in the tree
        total_size = 0
        file_count = 0
        for dirpath, dirnames, filenames in os.walk(model_path):
            for filename in filenames:
                filepath = os.path.join(dirpath, filename)
                total_size += os.path.getsize(filepath)
                file_count += 1
        return {
            'exists': True,
            'total_size_bytes': total_size,
            'total_size_mb': round(total_size / (1024 * 1024), 2),
            'file_count': file_count,
            'path': model_path
        }

    def list_available_models(self, db_session) -> list:
        """
        List all available fine-tuned models.
        Args:
            db_session: Database session
        Returns:
            List of dicts with model info (one per completed run; includes
            whether the model files still exist on disk)
        """
        from app.models.models import FineTuningRun
        runs = db_session.query(FineTuningRun).filter_by(status='completed').all()
        models = []
        for run in runs:
            size_info = self.get_model_size(run.id)
            models.append({
                'run_id': run.id,
                'created_at': run.created_at.isoformat() if run.created_at else None,
                'is_active': run.is_active_model,
                'results': run.get_results(),
                'model_exists': size_info.get('exists', False),
                'size_mb': size_info.get('total_size_mb', 0)
            })
        return models