| """ |
| Wrapper for dummy majority vote fusion model. |
| """ |
|
|
| import importlib.util |
| import sys |
| from pathlib import Path |
| from typing import Any, Dict, List |
|
|
| from app.core.errors import FusionError, ConfigurationError |
| from app.core.logging import get_logger |
| from app.models.wrappers.base_wrapper import BaseFusionWrapper |
|
|
| logger = get_logger(__name__) |
|
|
|
|
class DummyMajorityFusionWrapper(BaseFusionWrapper):
    """
    Wrapper for dummy majority vote fusion models.

    These models are hosted on Hugging Face and contain a predict.py
    with a predict() function that performs majority voting on submodel outputs.
    ``load()`` imports that file from ``local_path``; ``predict()`` forwards the
    submodel outputs to it and normalises the returned dictionary.
    """

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        """
        Initialize the wrapper.

        Args:
            repo_id: Hugging Face repository ID (e.g., "DeepFakeDetector/fusion-majority-test")
            config: Configuration from config.json; its optional "submodels"
                entry lists the submodel repository IDs.
            local_path: Local path where the model files are stored
        """
        super().__init__(repo_id, config, local_path)
        self._submodel_repos: List[str] = config.get("submodels", [])
        logger.info(f"Initialized DummyMajorityFusionWrapper for {repo_id}")
        logger.info(f"Submodels: {self._submodel_repos}")

    @property
    def submodel_repos(self) -> List[str]:
        """Get list of submodel repository IDs."""
        return self._submodel_repos

    def load(self) -> None:
        """
        Load the fusion predict function from the downloaded repository.

        Dynamically imports predict.py from ``local_path`` and stores its
        ``predict`` callable on ``self._predict_fn``.

        Raises:
            ConfigurationError: If predict.py is missing, cannot be imported,
                or does not expose a callable ``predict``.
        """
        fusion_path = Path(self.local_path) / "predict.py"

        if not fusion_path.exists():
            raise ConfigurationError(
                message=f"predict.py not found in {self.local_path}",
                details={"repo_id": self.repo_id, "expected_path": str(fusion_path)}
            )

        try:
            # Unique module name so several fusion repos can be loaded side by side.
            module_name = f"hf_model_{self.name.replace('-', '_')}_fusion"

            spec = importlib.util.spec_from_file_location(module_name, fusion_path)
            if spec is None or spec.loader is None:
                raise ConfigurationError(
                    message=f"Could not load spec for {fusion_path}",
                    details={"repo_id": self.repo_id}
                )

            module = importlib.util.module_from_spec(spec)
            # Register before executing so the module can be found by name
            # while its own top-level code runs.
            sys.modules[module_name] = module
            try:
                spec.loader.exec_module(module)

                predict_fn = getattr(module, "predict", None)
                if not callable(predict_fn):
                    raise ConfigurationError(
                        message="predict.py does not have a 'predict' function",
                        details={"repo_id": self.repo_id}
                    )
            except BaseException:
                # Do not leave a broken or unusable module registered.
                sys.modules.pop(module_name, None)
                raise

            self._predict_fn = predict_fn
            logger.info(f"Loaded fusion predict function from {self.repo_id}")

        except ConfigurationError:
            raise
        except Exception as e:
            logger.error(f"Failed to load fusion function from {self.repo_id}: {e}")
            raise ConfigurationError(
                message=f"Failed to load fusion model: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            ) from e

    def predict(
        self,
        submodel_outputs: Dict[str, Dict[str, Any]],
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run fusion prediction on submodel outputs.

        Args:
            submodel_outputs: Dictionary mapping submodel name to its prediction output
            **kwargs: Additional arguments passed to the fusion function

        Returns:
            Standardized prediction dictionary with:
            - pred_int: 0 or 1
            - pred: "real" or "fake"
            - prob_fake: float
            - meta: dict

        Raises:
            FusionError: If the model has not been loaded, or if the fusion
                function or the output standardization fails.
        """
        if self._predict_fn is None:
            raise FusionError(
                message="Fusion model not loaded",
                details={"repo_id": self.repo_id}
            )

        try:
            result = self._predict_fn(submodel_outputs=submodel_outputs, **kwargs)
            return self._standardize_output(result)

        except FusionError:
            raise
        except Exception as e:
            logger.error(f"Fusion prediction failed for {self.repo_id}: {e}")
            raise FusionError(
                message=f"Fusion prediction failed: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            ) from e

    def _standardize_output(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """
        Standardize the fusion output to ensure consistent format.

        Args:
            result: Raw fusion output

        Returns:
            Dictionary with keys ``pred_int`` (int 0 or 1), ``pred``,
            ``prob_fake`` (float) and ``meta``. Missing ``pred`` and
            ``prob_fake`` are derived from ``pred_int``; a missing ``pred_int``
            defaults to 0 and a missing or ``None`` ``meta`` becomes ``{}``.
        """
        pred_int = result.get("pred_int", 0)

        if pred_int not in (0, 1):
            # Score-like value: threshold it at 0.5.
            pred_int = 1 if pred_int > 0.5 else 0
        else:
            # True / 1.0 and similar compare equal to 0 or 1; return a plain int.
            pred_int = int(pred_int)

        pred = result.get("pred")
        if pred is None:
            pred = "fake" if pred_int == 1 else "real"

        prob_fake = result.get("prob_fake")
        if prob_fake is None:
            prob_fake = float(pred_int)

        # Treat an explicit None like a missing key, as for pred / prob_fake.
        meta = result.get("meta")
        if meta is None:
            meta = {}

        return {
            "pred_int": pred_int,
            "pred": pred,
            "prob_fake": float(prob_fake),
            "meta": meta
        }
|
|