taraky's picture
Upload folder using huggingface_hub
b7f3196 verified
"""
Utilities for Healthcare Classification System
This module contains shared constants and utilities for the healthcare
classification system.
"""
from classifier.head import ClassifierHead
from classifier.config import load_env
import os
from sentence_transformers import SentenceTransformer
import torch
from datetime import datetime
from pathlib import Path
# Load environment variables (including HF_TOKEN) at import time, before any
# of the model-loading helpers below can be called.
# NOTE(review): load_env comes from classifier.config; presumably it populates
# os.environ (e.g. from a .env file) - confirm against that module.
load_env()
# Model and training configuration.
# MODEL_NAME is the identifier handed to SentenceTransformer in get_models().
# It is defined exactly once so the value cannot drift between two copies.
MODEL_NAME = "sentence-transformers/embeddinggemma-300m-medical"
# NOTE(review): presumably the published id of the trained classifier head;
# it is not referenced in this part of the file - confirm against callers.
CLASSIFIER_NAME = "davidgray/health-query-triage"
# Label names; len(CATEGORIES) is the default number of classifier outputs.
CATEGORIES: list[str] = ["medical", "insurance"]
# Location of saved checkpoints and the strftime pattern for timestamps.
# NOTE(review): presumably checkpoints are named with DATETIME_FORMAT - confirm.
CHECKPOINT_PATH = "classifier/checkpoints"
DATETIME_FORMAT = "%Y%m%d_%H%M%S"
# Device selection: prefer the torch.accelerator API and fall back to explicit
# backend probing when that API is missing (older PyTorch raises AttributeError).
# NOTE(review): the first path stores the accelerator's ``.type`` (or "cpu"),
# while the fallback stores a torch.device object - callers should only rely on
# DEVICE being accepted by ``.to(...)``.
try:
    if torch.accelerator.is_available():
        DEVICE = torch.accelerator.current_accelerator().type
    else:
        DEVICE = "cpu"
except AttributeError:
    # Probe order is MPS first, then CUDA, then plain CPU.
    DEVICE = torch.device(
        "mps"
        if torch.backends.mps.is_available()
        else "cuda"
        if torch.cuda.is_available()
        else "cpu"
    )
print(f"Using {DEVICE} device")
def get_models(model_id: str | None = None, num_labels: int = len(CATEGORIES)) -> tuple[SentenceTransformer, ClassifierHead]:
    """
    Load the embedding model named by MODEL_NAME and obtain a classification head.

    Args:
        model_id: Identifier passed to ``ClassifierHead.from_pretrained``. When
            it is ``None`` or empty, a head is constructed with
            ``ClassifierHead(num_labels)`` instead.
        num_labels: Argument for a newly constructed head. Ignored when
            ``model_id`` is given. Defaults to ``len(CATEGORIES)``.

    Returns:
        tuple: (embedding_model, classifier_head), each moved to ``DEVICE``.

    Raises:
        RuntimeError: If creating the embedding model or the head raises any
            ``Exception``. The original error is attached as ``__cause__``.
    """
    try:
        model_body = SentenceTransformer(
            MODEL_NAME,
            prompts={
                'classification': 'task: classification | query: ',
                'retrieval (query)': 'task: search result | query: ',
                'retrieval (document)': 'title: {title | "none"} | text: ',
            },
            # Inputs are encoded with the classification prompt unless a caller
            # selects another prompt name explicitly.
            default_prompt_name='classification',
        )
        if model_id:
            model_head = ClassifierHead.from_pretrained(model_id)
        else:
            model_head = ClassifierHead(num_labels)
    except Exception as e:
        print(f"Error loading model {MODEL_NAME}: {e}")
        print("Please ensure you have an internet connection and the transformers library installed.")
        # Chain explicitly so the traceback shows the root cause of the failure.
        raise RuntimeError("Failed to load the embedding model.") from e
    return model_body.to(DEVICE), model_head.to(DEVICE)
def get_latest_checkpoint(checkpoint_path: str) -> str:
    """
    Return the path of the entry whose name sorts last in ``checkpoint_path``.

    Entries are compared by name as plain strings, so the result is the
    "latest" checkpoint only when names sort chronologically.
    NOTE(review): presumably entries are named with DATETIME_FORMAT
    timestamps, which do sort chronologically - confirm against the writer.

    Args:
        checkpoint_path: Directory to list.

    Returns:
        ``checkpoint_path`` joined with the greatest entry name.

    Raises:
        IndexError: If the directory contains no entries.
        OSError: Propagated from ``os.listdir`` (for example when the
            directory does not exist).
    """
    entries = os.listdir(checkpoint_path)
    if not entries:
        # Keep IndexError (what ``sorted(...)[-1]`` raised) but say what is wrong.
        raise IndexError(f"No checkpoints found in {checkpoint_path!r}")
    # Names within a directory are unique, so max() selects the same entry as
    # sorting and taking the last element, without the full sort.
    return os.path.join(checkpoint_path, max(entries))