# kataria_opticals_api / classifier.py
# Last commit 7f3db4a by codernotme: "update" (verified)
import os
import joblib
from huggingface_hub import hf_hub_download
from geometry import extract_features
from landmarks import get_landmarks
REPO_ID = "codernotme/kataria_optical"
MODEL_PATH = "face_shape_model.pkl"
# Global model cache
_model = None
def _get_feature_vector(features):
return [
features.get("lw_ratio", 0),
features.get("jaw_ratio", 0),
features.get("forehead_ratio", 0),
]
def load_model():
    """Return the face-shape classifier, loading it on first use.

    Looks for ``MODEL_PATH`` relative to the current working directory; when
    that file does not exist it is downloaded from the Hugging Face Hub repo
    ``REPO_ID``. A successfully loaded model is cached in the module-level
    ``_model`` and returned directly on later calls.

    Returns:
        The deserialized model, or None if the download or the load failed.
        Failures are not cached, so a subsequent call tries again.
    """
    global _model
    # Fast path: already loaded.
    if _model is not None:
        return _model

    model_file = MODEL_PATH
    if not os.path.exists(model_file):
        try:
            print(f"Downloading {MODEL_PATH} from HF Hub...")
            # hf_hub_download returns the local path of the fetched file.
            model_file = hf_hub_download(repo_id=REPO_ID, filename=MODEL_PATH)
        except Exception as exc:
            print(f"Could not download from HF Hub: {exc}")
            return None

    # NOTE(review): joblib.load unpickles arbitrary objects; only ever point
    # this at a trusted file (the local copy or REPO_ID).
    try:
        _model = joblib.load(model_file)
        print("Loaded face shape model.")
    except Exception as exc:
        print(f"Failed to load model: {exc}")
    return _model
def classify_face_shape(image_input):
    """Classify the face shape in an image with the cached classifier.

    Args:
        image_input: File path, PIL Image, or numpy array (whatever
            ``get_landmarks`` accepts).

    Returns:
        dict: Label -> probability rounded to 4 decimals, ordered from most
        to least likely. ``{"Unknown": 1.0}`` when no model could be loaded,
        ``image_input`` is None, or the model exposes no ``classes_``.
        ``{"Error": 1.0}`` if landmark extraction, feature extraction or
        prediction raises (the error is printed).
    """
    model = load_model()
    if model is None or image_input is None:
        return {"Unknown": 1.0}
    try:
        landmarks = get_landmarks(image_input)
        feats = extract_features(landmarks)
        vector = _get_feature_vector(feats)
        probabilities = model.predict_proba([vector])[0]
        labels = list(getattr(model, "classes_", []))
        if not labels:
            return {"Unknown": 1.0}

        # Normalise the raw probabilities and round only once at the end:
        # rounding before normalising compounds rounding error. Sorting on
        # the raw values keeps the true ranking even when two labels round
        # to the same 4-decimal number.
        raw = [
            (str(label), float(score))
            for label, score in zip(labels, probabilities)
        ]
        total = sum(score for _, score in raw) or 1.0  # guard all-zero output
        ranked = sorted(raw, key=lambda item: item[1], reverse=True)
        return {label: round(score / total, 4) for label, score in ranked}
    except Exception as e:
        # Deliberate top-level boundary: callers always get a dict back.
        print(f"Prediction error: {e}")
        return {"Error": 1.0}