# crispr-array-detection/inference/model_loader.py
"""
Model loader with singleton pattern for CRISPR BERT model.
Ensures the model is loaded only once and reused across requests.
"""
import os
import logging
from pathlib import Path
from typing import Optional
import threading
import numpy as np
import tensorflow as tf
from huggingface_hub import hf_hub_download
from .custom_layers import get_custom_objects
from .tokenizer import WINDOW_SIZE
logger = logging.getLogger(__name__)
# Singleton state
_model: Optional[tf.keras.Model] = None
_embedding_model: Optional[tf.keras.Model] = None
_model_lock = threading.RLock()  # re-entrant: get_embedding_model() calls get_model() while holding it
# HuggingFace model repository
HF_MODEL_REPO = os.environ.get("CRISPR_HF_REPO", "genomenet/crispr-bert-model")
HF_MODEL_FILENAME = os.environ.get("CRISPR_HF_FILENAME", "best.h5")
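# Private repos: huggingface_hub reads an auth token from the HF_TOKEN
# environment variable (or a cached `huggingface-cli login`).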
# Local model path (optional override)
DEFAULT_MODEL_PATH = os.environ.get("CRISPR_MODEL_PATH", "")
# Embedding layer name for hidden state extraction
# Note: Fine-tuned model has 22 blocks (0-21), base BERT has 24 (0-23)
EMBEDDING_LAYER = os.environ.get(
"CRISPR_EMBEDDING_LAYER",
"layer_transformer_block_21"
)
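# Example override (hypothetical block index; list the real layer names with
# model.summary() before relying on one):
#
#     export CRISPR_EMBEDDING_LAYER=layer_transformer_block_11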
def setup_gpu():
"""Configure GPU memory growth to avoid OOM errors."""
gpus = tf.config.list_physical_devices("GPU")
if gpus:
for gpu in gpus:
try:
tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
logger.warning(f"GPU memory growth setting failed: {e}")
logger.info(f"GPUs available: {[g.name for g in gpus]}")
return True
else:
logger.warning("No GPU found. Running on CPU.")
return False
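
# Note: memory growth must be configured before TensorFlow initializes any GPU,
# which is why get_model() runs setup_gpu() ahead of the first load_model().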
def load_model(model_path: Optional[str] = None) -> tf.keras.Model:
"""
Load the CRISPR detection model.
Downloads from HuggingFace Hub if no local path is provided.
Args:
model_path: Path to model file (.h5 or .keras)
Returns:
Loaded Keras model
"""
# Use provided path, environment variable, or download from HF Hub
if model_path:
path = Path(model_path)
elif DEFAULT_MODEL_PATH:
path = Path(DEFAULT_MODEL_PATH)
else:
# Download from HuggingFace Hub
logger.info(f"Downloading model from HuggingFace: {HF_MODEL_REPO}/{HF_MODEL_FILENAME}")
path = Path(hf_hub_download(
repo_id=HF_MODEL_REPO,
filename=HF_MODEL_FILENAME
))
logger.info(f"Model downloaded to: {path}")
if not path.exists():
raise FileNotFoundError(
f"Model file not found: {path}\n"
f"Please set CRISPR_MODEL_PATH or ensure HF_MODEL_REPO is accessible."
)
logger.info(f"Loading model from: {path}")
custom_objects = get_custom_objects()
model = tf.keras.models.load_model(str(path), custom_objects=custom_objects, compile=False)
logger.info(f"Model loaded. Input shape: {model.input_shape}, Output shape: {model.output_shape}")
return model
def build_embedding_model(model: tf.keras.Model, layer_name: str = EMBEDDING_LAYER) -> tf.keras.Model:
"""
Build a sub-model that outputs hidden states from a specific layer.
Args:
model: Full CRISPR detection model
layer_name: Name of the layer to extract embeddings from
Returns:
Keras model that outputs embeddings
"""
try:
embedding_output = model.get_layer(layer_name).output
except ValueError:
        # List candidate layers so the error message is actionable
        available_layers = [layer.name for layer in model.layers if "transformer" in layer.name.lower()]
raise ValueError(
f"Layer '{layer_name}' not found in model. "
f"Available transformer layers: {available_layers}"
)
embedding_model = tf.keras.Model(
inputs=model.inputs,
outputs=embedding_output,
name="embedding_model"
)
logger.info(f"Embedding model built. Output shape: {embedding_model.output_shape}")
return embedding_model
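
# Sketch of pulling per-position hidden states (assumes tokenized integer input
# of shape (batch, WINDOW_SIZE), produced by the .tokenizer module):
#
#     emb_model = get_embedding_model()
#     hidden = emb_model(token_ids, training=False)  # (batch, WINDOW_SIZE, hidden_dim)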
def get_model(model_path: Optional[str] = None) -> tf.keras.Model:
"""
Get the singleton model instance.
Thread-safe lazy loading of the model.
Args:
model_path: Optional path to model file
Returns:
Loaded Keras model
"""
global _model
if _model is None:
with _model_lock:
if _model is None:
setup_gpu()
_model = load_model(model_path)
return _model
def get_embedding_model(model_path: Optional[str] = None, layer_name: str = EMBEDDING_LAYER) -> tf.keras.Model:
"""
Get the singleton embedding model instance.
Args:
model_path: Optional path to model file
layer_name: Name of layer to extract embeddings from
Returns:
Embedding extraction model
"""
global _embedding_model
if _embedding_model is None:
with _model_lock:
if _embedding_model is None:
model = get_model(model_path)
_embedding_model = build_embedding_model(model, layer_name)
return _embedding_model
def warmup_model(model: Optional[tf.keras.Model] = None):
"""
Warm up the model by running a dummy inference.
    This triggers graph compilation and avoids a slow first request.
Args:
model: Model to warm up (uses singleton if not provided)
"""
if model is None:
model = get_model()
logger.info("Warming up model...")
# Determine expected input dtype
expected_dtype = model.inputs[0].dtype
if expected_dtype.is_floating:
dtype = np.float32
elif expected_dtype == tf.int64:
dtype = np.int64
else:
dtype = np.int32
# Create dummy input
dummy = np.ones((1, WINDOW_SIZE), dtype=dtype)
# Run inference
_ = model(dummy, training=False)
logger.info("Model warmup complete.")
def get_model_info() -> dict:
"""
Get information about the loaded model.
Returns:
Dictionary with model metadata
"""
model = get_model()
return {
"input_shape": str(model.input_shape),
"output_shape": str(model.output_shape),
"input_dtype": str(model.inputs[0].dtype.name),
"num_parameters": int(model.count_params()),
"num_layers": len(model.layers),
}
def is_model_loaded() -> bool:
"""Check if the model has been loaded."""
return _model is not None
def get_gpu_status() -> dict:
"""Get GPU availability status."""
gpus = tf.config.list_physical_devices("GPU")
return {
"gpu_available": len(gpus) > 0,
"gpu_count": len(gpus),
"gpu_names": [g.name for g in gpus],
}
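

if __name__ == "__main__":
    # Minimal smoke test (a sketch): load the model, warm it up, print metadata.
    logging.basicConfig(level=logging.INFO)
    warmup_model()
    print(get_model_info())
    print(get_gpu_status())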