File size: 9,733 Bytes
057dd29
 
 
f78b194
057dd29
9310be4
057dd29
 
 
401d88f
 
057dd29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401d88f
057dd29
 
 
 
 
 
 
401d88f
057dd29
 
 
401d88f
057dd29
 
f78b194
 
057dd29
 
 
 
24db732
057dd29
 
 
 
 
 
 
 
 
 
 
 
24db732
057dd29
 
 
 
24db732
057dd29
 
 
 
 
 
 
 
 
 
 
f78b194
057dd29
f78b194
057dd29
 
 
24db732
057dd29
 
 
f78b194
057dd29
 
 
 
 
f78b194
057dd29
 
f78b194
057dd29
 
 
 
f78b194
057dd29
 
f78b194
057dd29
 
24db732
057dd29
 
 
f78b194
057dd29
 
f78b194
24db732
057dd29
 
24db732
057dd29
 
 
 
f78b194
057dd29
24db732
057dd29
24db732
057dd29
 
 
 
f78b194
057dd29
24db732
057dd29
24db732
057dd29
 
 
 
f78b194
057dd29
 
f78b194
057dd29
 
 
 
 
 
 
f78b194
24db732
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import logging
import torch
from utils.registry import MODEL_REGISTRY # Import MODEL_REGISTRY
from utils.agent_logger import AgentLogger

# Module-level logger instance. The methods of the classes below (e.g.
# get_overrides, adjust_weights, _normalize_weights) log through this global name.
agent_logger = AgentLogger()

class ContextualWeightOverrideAgent:
    """Map context tags to per-model weight multipliers.

    Each known context tag (e.g. ``"outdoor"``) maps to a dict
    ``{model_id: multiplier}``: a multiplier below 1.0 penalizes that model
    in the given context, a multiplier above 1.0 boosts it.
    """

    def __init__(self):
        # Use the shared module-level logger instead of shadowing it with a
        # fresh local AgentLogger instance.
        agent_logger.log("weight_optimization", "info", "Initializing ContextualWeightOverrideAgent.")
        self.context_overrides = {
            # Example: when image is outdoor, model_X is penalized, model_Y is boosted
            "outdoor": {
                "model_1": 0.8,  # Example: Reduce weight of model_1 by 20% for outdoor scenes
                "model_5": 1.2,  # Example: Boost weight of model_5 by 20% for outdoor scenes
            },
            "low_light": {
                "model_2": 0.7,
                "model_7": 1.3,
            },
            "sunny": {
                "model_3": 0.9,
                "model_4": 1.1,
            },
            # Add more contexts and their specific model weight adjustments here
        }

    def get_overrides(self, context_tags: list[str]) -> dict:
        """Return the combined multiplier for every model affected by *context_tags*.

        Args:
            context_tags: Tags describing the current input. Unknown tags are
                ignored. ``None`` or an empty list yields ``{}``. A tag that
                appears more than once is applied only once, so repeating a
                tag does not compound its effect.

        Returns:
            Dict ``{model_id: multiplier}``. When several tags affect the same
            model, their multipliers are multiplied together.
        """
        agent_logger.log("weight_optimization", "info", f"Getting weight overrides for context tags: {context_tags}")
        combined_overrides = {}
        # dict.fromkeys drops duplicate tags while keeping the caller's order.
        for tag in dict.fromkeys(context_tags or ()):
            for model_id, multiplier in self.context_overrides.get(tag, {}).items():
                # A model affected by several contexts gets the product of
                # their multipliers (cumulative effect).
                combined_overrides[model_id] = combined_overrides.get(model_id, 1.0) * multiplier
        agent_logger.log("weight_optimization", "info", f"Combined context overrides: {combined_overrides}")
        return combined_overrides



class ModelWeightManager:
    """Derive and dynamically adjust per-model ensemble weights.

    Base weights are computed from ``MODEL_REGISTRY`` at construction time.
    :meth:`adjust_weights` rescales a copy of them using context-tag
    overrides, cross-model agreement and per-model confidence, then
    normalizes the result so it sums to 1.
    """

    def __init__(self, strongest_model_id: str = None):
        """Initialize base weights from ``MODEL_REGISTRY``.

        Args:
            strongest_model_id: Optional registry key of a model that receives
                half of the total weight, with the other half split equally
                among the remaining models. If it is ``None`` or not
                registered, all registered models get equal weight.
        """
        # Use the shared module-level logger instead of shadowing it locally.
        agent_logger.log("weight_optimization", "info", f"Initializing ModelWeightManager. Strongest model: {strongest_model_id}")
        self.base_weights = self._initial_base_weights(strongest_model_id)
        agent_logger.log("weight_optimization", "info", f"Base weights initialized: {self.base_weights}")

        # Multipliers applied by adjust_weights for each detected situation.
        self.situation_weights = {
            "high_confidence": 1.2,    # Boost weights for high confidence predictions
            "low_confidence": 0.8,     # Reduce weights for low confidence
            "conflict": 0.5,           # Reduce weights when models disagree
            "consensus": 1.5,          # Boost weights when models agree
        }
        self.context_override_agent = ContextualWeightOverrideAgent()

    def _initial_base_weights(self, strongest_model_id):
        """Return the starting weight of every registered model ({} if none)."""
        model_ids = list(MODEL_REGISTRY.keys())
        if not model_ids:
            return {}  # Handle case with no registered models

        if strongest_model_id and strongest_model_id in MODEL_REGISTRY:
            agent_logger.log("weight_optimization", "info", f"Designating '{strongest_model_id}' as the strongest model.")
            remaining_models = [mid for mid in model_ids if mid != strongest_model_id]
            if not remaining_models:
                # Only one model, which is the strongest: it gets all the weight.
                return {strongest_model_id: 1.0}
            # The strongest model gets 50%; the rest is split equally.
            strongest_weight_share = 0.5
            other_models_weight_share = (1.0 - strongest_weight_share) / len(remaining_models)
            base_weights = {strongest_model_id: strongest_weight_share}
            base_weights.update(dict.fromkeys(remaining_models, other_models_weight_share))
            return base_weights

        # Reaching here with a truthy id means it is not in the registry.
        if strongest_model_id:
            agent_logger.log("weight_optimization", "warning", f"Strongest model ID '{strongest_model_id}' not found in MODEL_REGISTRY. Distributing weights equally.")
        return dict.fromkeys(model_ids, 1.0 / len(model_ids))

    def adjust_weights(self, predictions, confidence_scores, context_tags: list[str] = None):
        """Dynamically adjust weights based on prediction patterns and optional context.

        Args:
            predictions: Mapping ``{model_id: prediction}``. Only dict
                predictions with a ``"Label"`` that is neither ``None`` nor
                ``"Error"`` take part in the consensus/conflict checks.
            confidence_scores: Mapping ``{model_id: confidence}``. Scores above
                0.8 boost that model's weight and scores below 0.5 reduce it.
                Entries for models without a base weight, or with a ``None``
                score, are skipped with a warning.
            context_tags: Optional tags forwarded to the contextual override agent.

        Returns:
            Dict ``{model_id: weight}`` normalized to sum to 1, or the
            fallback produced by :meth:`_normalize_weights`.
        """
        agent_logger.log("weight_optimization", "info", "Adjusting model weights.")
        adjusted_weights = self.base_weights.copy()
        agent_logger.log("weight_optimization", "info", f"Initial adjusted weights (copy of base): {adjusted_weights}")

        # 1. Apply contextual overrides first
        if context_tags:
            agent_logger.log("weight_optimization", "info", f"Applying contextual overrides for tags: {context_tags}")
            overrides = self.context_override_agent.get_overrides(context_tags)
            for model_id, multiplier in overrides.items():
                # Only rescale models that already have a weight; creating an
                # entry for an unknown model would leak a meaningless 0.0
                # weight into the returned mapping.
                if model_id in adjusted_weights:
                    adjusted_weights[model_id] *= multiplier
            agent_logger.log("weight_optimization", "info", f"Adjusted weights after context overrides: {adjusted_weights}")

        # 2. Apply situation-based adjustments (consensus, conflict, confidence)
        # NOTE(review): consensus and conflict scale *every* weight by the same
        # factor, so _normalize_weights cancels them out and they currently only
        # affect the logs. A per-model rule is needed for them to change results.
        if self._has_consensus(predictions):
            agent_logger.log("weight_optimization", "info", "Consensus detected. Boosting weights for consensus.")
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["consensus"]
            agent_logger.log("weight_optimization", "info", f"Adjusted weights after consensus boost: {adjusted_weights}")

        if self._has_conflicts(predictions):
            agent_logger.log("weight_optimization", "info", "Conflicts detected. Reducing weights for conflict.")
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["conflict"]
            agent_logger.log("weight_optimization", "info", f"Adjusted weights after conflict reduction: {adjusted_weights}")

        # Adjust based on confidence
        agent_logger.log("weight_optimization", "info", "Adjusting weights based on model confidence scores.")
        for model, confidence in confidence_scores.items():
            # Unknown models have no weight to scale (would be a KeyError), and
            # a None score cannot be compared against the thresholds.
            if model not in adjusted_weights or confidence is None:
                agent_logger.log("weight_optimization", "warning", f"Skipping confidence score for model '{model}': model has no weight or score is missing.")
                continue
            if confidence > 0.8:
                adjusted_weights[model] *= self.situation_weights["high_confidence"]
                agent_logger.log("weight_optimization", "info", f"Model '{model}' has high confidence ({confidence:.2f}). Weight boosted.")
            elif confidence < 0.5:
                adjusted_weights[model] *= self.situation_weights["low_confidence"]
                agent_logger.log("weight_optimization", "info", f"Model '{model}' has low confidence ({confidence:.2f}). Weight reduced.")
        agent_logger.log("weight_optimization", "info", f"Adjusted weights before normalization: {adjusted_weights}")

        normalized_weights = self._normalize_weights(adjusted_weights)
        agent_logger.log("weight_optimization", "info", f"Final normalized adjusted weights: {normalized_weights}")
        return normalized_weights

    @staticmethod
    def _valid_labels(predictions):
        """Return the usable ``"Label"`` values from *predictions*.

        Skips values that are not dicts (including ``None``), dicts without
        a label, and dicts whose label is the ``"Error"`` sentinel.
        """
        labels = []
        for prediction in predictions.values():
            if not isinstance(prediction, dict):
                continue
            label = prediction.get("Label")
            if label is not None and label != "Error":
                labels.append(label)
        return labels

    def _has_consensus(self, predictions):
        """Return True when there is at least one valid label and all valid labels agree."""
        agent_logger.log("weight_optimization", "info", "Checking for consensus among model predictions.")
        non_none_predictions = self._valid_labels(predictions)
        agent_logger.log("weight_optimization", "debug", f"Non-none predictions for consensus check: {non_none_predictions}")
        # Exactly one distinct label implies at least one valid prediction.
        result = len(set(non_none_predictions)) == 1
        agent_logger.log("weight_optimization", "info", f"Consensus detected: {result}")
        return result

    def _has_conflicts(self, predictions):
        """Return True when the valid labels contain at least two distinct values."""
        agent_logger.log("weight_optimization", "info", "Checking for conflicts among model predictions.")
        non_none_predictions = self._valid_labels(predictions)
        agent_logger.log("weight_optimization", "debug", f"Non-none predictions for conflict check: {non_none_predictions}")
        # Two or more distinct labels implies at least two valid predictions.
        result = len(set(non_none_predictions)) > 1
        agent_logger.log("weight_optimization", "info", f"Conflicts detected: {result}")
        return result

    def _normalize_weights(self, weights):
        """Normalize *weights* so they sum to 1.

        If the weights sum to zero, fall back to equal weights over every
        model in ``MODEL_REGISTRY`` (or ``{}`` if none are registered).
        """
        agent_logger.log("weight_optimization", "info", "Normalizing weights.")
        total = sum(weights.values())
        if total == 0:
            agent_logger.log("weight_optimization", "warning", "All weights became zero after adjustments. Reverting to equal base weights for registered models.")
            num_registered_models = len(MODEL_REGISTRY)
            if num_registered_models > 0:
                return {k: 1.0 / num_registered_models for k in MODEL_REGISTRY.keys()}
            return {}  # No models registered
        normalized = {k: v / total for k, v in weights.items()}
        agent_logger.log("weight_optimization", "info", f"Weights normalized. Total sum: {sum(normalized.values()):.2f}")
        return normalized