File size: 8,587 Bytes
8f7f87a
 
1146644
8f7f87a
 
 
 
 
e1eac06
8f7f87a
 
 
 
 
 
 
 
 
 
39558cb
8f7f87a
 
 
 
 
 
 
e1eac06
8f7f87a
 
 
 
 
 
 
e1eac06
8f7f87a
 
 
 
1146644
e1eac06
1146644
 
 
 
e1eac06
1146644
 
 
 
 
 
 
 
 
 
 
e1eac06
 
1146644
 
 
 
e1eac06
1146644
8f7f87a
 
 
39558cb
 
8f7f87a
 
 
 
 
e1eac06
8f7f87a
e1eac06
8f7f87a
 
 
e1eac06
8f7f87a
 
 
e1eac06
8f7f87a
 
 
e1eac06
 
 
8f7f87a
 
e1eac06
8f7f87a
 
e1eac06
 
 
8f7f87a
 
e1eac06
8f7f87a
 
e1eac06
8f7f87a
 
 
e1eac06
8f7f87a
 
e1eac06
 
8f7f87a
e1eac06
 
 
8f7f87a
 
 
e1eac06
73783f9
e1eac06
 
 
 
8f7f87a
 
 
e1eac06
73783f9
e1eac06
 
 
 
8f7f87a
 
 
e1eac06
8f7f87a
 
e1eac06
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import logging
import torch
from utils.registry import MODEL_REGISTRY # Import MODEL_REGISTRY

logger = logging.getLogger(__name__)

class ContextualWeightOverrideAgent:
    """Map context tags (e.g. ``"outdoor"``) to per-model weight multipliers.

    A multiplier below 1.0 penalizes a model in that context; above 1.0 boosts it.
    """

    def __init__(self, context_overrides: dict = None):
        """Create the agent.

        Args:
            context_overrides: Optional mapping ``{tag: {model_id: multiplier}}``.
                When omitted, the built-in example table below is used.
        """
        logger.info("Initializing ContextualWeightOverrideAgent.")
        if context_overrides is None:
            context_overrides = {
                # Example: when image is outdoor, model_1 is penalized, model_5 is boosted
                "outdoor": {
                    "model_1": 0.8,  # Reduce weight of model_1 by 20% for outdoor scenes
                    "model_5": 1.2,  # Boost weight of model_5 by 20% for outdoor scenes
                },
                "low_light": {
                    "model_2": 0.7,
                    "model_7": 1.3,
                },
                "sunny": {
                    "model_3": 0.9,
                    "model_4": 1.1,
                },
                # Add more contexts and their specific model weight adjustments here
            }
        # Copy the inner mappings so later mutation of the caller's dict cannot
        # change this agent's behavior.
        self.context_overrides = {
            tag: dict(model_multipliers)
            for tag, model_multipliers in context_overrides.items()
        }

    def get_overrides(self, context_tags: list[str]) -> dict:
        """Combine the multipliers of every known tag in ``context_tags``.

        Unknown tags are ignored. A tag repeated in the input is applied only once.
        When several distinct tags mention the same model, their multipliers are
        multiplied together.

        Args:
            context_tags: Tags describing the input; ``None`` or empty yields ``{}``.

        Returns:
            dict mapping model id to its combined multiplier.
        """
        logger.info("Getting weight overrides for context tags: %s", context_tags)
        combined_overrides = {}
        # dict.fromkeys de-duplicates while preserving order, so a duplicated tag
        # does not compound its multiplier.
        for tag in dict.fromkeys(context_tags or ()):
            for model_id, multiplier in self.context_overrides.get(tag, {}).items():
                combined_overrides[model_id] = combined_overrides.get(model_id, 1.0) * multiplier
        logger.info("Combined context overrides: %s", combined_overrides)
        return combined_overrides


class ModelWeightManager:
    """Compute per-model ensemble weights.

    Base weights are derived from ``MODEL_REGISTRY`` at construction time. Each call
    to :meth:`adjust_weights` starts from a copy of those base weights, applies
    contextual multipliers, situation multipliers (consensus / conflict) and
    per-model confidence multipliers, and finally normalizes the result to sum to 1.
    """

    def __init__(self, strongest_model_id: str = None, *, strongest_weight_share: float = 0.5):
        """Initialise base weights from ``MODEL_REGISTRY``.

        Args:
            strongest_model_id: Optional registry key of a model that receives
                ``strongest_weight_share`` of the base weight. The remainder is split
                equally among the other registered models. If the id is not
                registered, a warning is logged and all models get equal weight.
            strongest_weight_share: Fraction of the base weight given to the
                strongest model (default 0.5). Must be in ``(0, 1]``.

        Raises:
            ValueError: If ``strongest_weight_share`` is outside ``(0, 1]``.
        """
        logger.info("Initializing ModelWeightManager with strongest_model_id: %s", strongest_model_id)
        if not 0.0 < strongest_weight_share <= 1.0:
            raise ValueError(
                f"strongest_weight_share must be in (0, 1], got {strongest_weight_share}"
            )
        self.base_weights = self._initial_weights(strongest_model_id, strongest_weight_share)
        logger.info("Base weights initialized: %s", self.base_weights)

        self.situation_weights = {
            "high_confidence": 1.2,    # Boost weights for high confidence predictions
            "low_confidence": 0.8,     # Reduce weights for low confidence
            "conflict": 0.5,           # Reduce weights when models disagree
            "consensus": 1.5           # Boost weights when models agree
        }
        # Confidence thresholds used by adjust_weights(): a score strictly above
        # `high` is boosted, one strictly below `low` is reduced, and anything in
        # between is left unchanged.
        self.high_confidence_threshold = 0.8
        self.low_confidence_threshold = 0.5
        self.context_override_agent = ContextualWeightOverrideAgent()

    @staticmethod
    def _initial_weights(strongest_model_id, strongest_weight_share):
        """Return the base weight for each registered model ({} if none are registered)."""
        model_ids = list(MODEL_REGISTRY.keys())
        if not model_ids:
            return {}  # Handle case with no registered models
        if strongest_model_id and strongest_model_id in MODEL_REGISTRY:
            logger.info("Designating '%s' as the strongest model.", strongest_model_id)
            others = [mid for mid in model_ids if mid != strongest_model_id]
            if not others:  # Only one model, which is the strongest
                return {strongest_model_id: 1.0}
            other_share = (1.0 - strongest_weight_share) / len(others)
            weights = {strongest_model_id: strongest_weight_share}
            weights.update(dict.fromkeys(others, other_share))
            return weights
        if strongest_model_id:
            logger.warning(
                "Strongest model ID '%s' not found in MODEL_REGISTRY. Distributing weights equally.",
                strongest_model_id,
            )
        return dict.fromkeys(model_ids, 1.0 / len(model_ids))

    def adjust_weights(self, predictions, confidence_scores, context_tags: list[str] = None):
        """Return normalized per-model weights adjusted for context, agreement and confidence.

        Args:
            predictions: Mapping of model id to prediction. Only dict entries with a
                ``"Label"`` other than ``None``/``"Error"`` take part in the consensus
                and conflict checks.
            confidence_scores: Mapping of model id to a numeric confidence. Scores
                above ``high_confidence_threshold`` boost that model. Scores below
                ``low_confidence_threshold`` reduce it. ``None`` scores and ids that
                have no base weight are skipped.
            context_tags: Optional context tags passed to the contextual override agent.

        Returns:
            dict mapping model id to weight, summing to 1 (empty if no models are registered).
        """
        logger.info("Adjusting model weights.")
        adjusted_weights = self.base_weights.copy()
        logger.info("Initial adjusted weights (copy of base): %s", adjusted_weights)

        # 1. Apply contextual overrides first
        if context_tags:
            logger.info("Applying contextual overrides for tags: %s", context_tags)
            overrides = self.context_override_agent.get_overrides(context_tags)
            for model_id, multiplier in overrides.items():
                # Overrides may name models that are not registered. Skip them rather
                # than adding zero-weight entries to the result.
                if model_id in adjusted_weights:
                    adjusted_weights[model_id] *= multiplier
                else:
                    logger.debug("Context override for unregistered model '%s' ignored.", model_id)
            logger.info("Adjusted weights after context overrides: %s", adjusted_weights)

        # 2. Apply situation-based adjustments (consensus, conflict).
        # NOTE(review): both factors scale every weight by the same amount. Because
        # _normalize_weights() divides by the total, they cancel out and do not
        # currently change the returned weights. Model-specific scaling (e.g.
        # penalizing models that disagree with the majority) would be needed for
        # them to have an effect.
        if self._has_consensus(predictions):
            logger.info("Consensus detected. Boosting weights for consensus.")
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["consensus"]
            logger.info("Adjusted weights after consensus boost: %s", adjusted_weights)

        if self._has_conflicts(predictions):
            logger.info("Conflicts detected. Reducing weights for conflict.")
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["conflict"]
            logger.info("Adjusted weights after conflict reduction: %s", adjusted_weights)

        # 3. Adjust each model based on its own confidence
        logger.info("Adjusting weights based on model confidence scores.")
        for model, confidence in (confidence_scores or {}).items():
            if model not in adjusted_weights:
                logger.warning("Ignoring confidence score for unregistered model '%s'.", model)
                continue
            if confidence is None:
                continue  # No score available; leave this model's weight unchanged.
            if confidence > self.high_confidence_threshold:
                adjusted_weights[model] *= self.situation_weights["high_confidence"]
                logger.info("Model '%s' has high confidence (%.2f). Weight boosted.", model, confidence)
            elif confidence < self.low_confidence_threshold:
                adjusted_weights[model] *= self.situation_weights["low_confidence"]
                logger.info("Model '%s' has low confidence (%.2f). Weight reduced.", model, confidence)
        logger.info("Adjusted weights before normalization: %s", adjusted_weights)

        normalized_weights = self._normalize_weights(adjusted_weights)
        logger.info("Final normalized adjusted weights: %s", normalized_weights)
        return normalized_weights

    @staticmethod
    def _valid_labels(predictions):
        """Return the usable ``"Label"`` values from ``predictions``.

        Entries that are not dicts, lack a ``"Label"``, or are labelled ``"Error"``
        are skipped. ``predictions`` of ``None`` is treated as empty.
        """
        labels = []
        for prediction in (predictions or {}).values():
            if not isinstance(prediction, dict):
                continue
            label = prediction.get("Label")
            if label is not None and label != "Error":
                labels.append(label)
        return labels

    def _has_consensus(self, predictions):
        """Return True if at least one valid label exists and all valid labels are equal."""
        logger.info("Checking for consensus among model predictions.")
        labels = self._valid_labels(predictions)
        logger.debug("Non-none predictions for consensus check: %s", labels)
        result = len(set(labels)) == 1
        logger.info("Consensus detected: %s", result)
        return result

    def _has_conflicts(self, predictions):
        """Return True if the valid labels contain at least two distinct values."""
        logger.info("Checking for conflicts among model predictions.")
        labels = self._valid_labels(predictions)
        logger.debug("Non-none predictions for conflict check: %s", labels)
        result = len(set(labels)) > 1
        logger.info("Conflicts detected: %s", result)
        return result

    def _normalize_weights(self, weights):
        """Scale ``weights`` so they sum to 1.

        If the total is not positive, the method falls back to equal weights over
        all registered models, or ``{}`` if none are registered.
        """
        logger.info("Normalizing weights.")
        total = sum(weights.values())
        if total <= 0:
            logger.warning("All weights became zero after adjustments. Reverting to equal base weights for registered models.")
            registered = list(MODEL_REGISTRY.keys())
            if not registered:
                return {}  # No models registered
            return dict.fromkeys(registered, 1.0 / len(registered))
        normalized = {k: v / total for k, v in weights.items()}
        logger.info("Weights normalized. Total sum: %.2f", sum(normalized.values()))
        return normalized