LPX55
commited on
Commit
·
ce372d3
1
Parent(s):
c378b41
refactor(weights): update strongest model ID and adjust queue concurrency limit
Browse files- agents/ensemble_weights.py +5 -7
- app.py +2 -2
agents/ensemble_weights.py
CHANGED
|
@@ -38,8 +38,6 @@ class ContextualWeightOverrideAgent:
|
|
| 38 |
agent_logger.log("weight_optimization", "info", f"Combined context overrides: {combined_overrides}")
|
| 39 |
return combined_overrides
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
| 43 |
class ModelWeightManager:
|
| 44 |
def __init__(self, strongest_model_id: str = None):
|
| 45 |
agent_logger = AgentLogger()
|
|
@@ -49,8 +47,8 @@ class ModelWeightManager:
|
|
| 49 |
if num_models > 0:
|
| 50 |
if strongest_model_id and strongest_model_id in MODEL_REGISTRY:
|
| 51 |
agent_logger.log("weight_optimization", "info", f"Designating '{strongest_model_id}' as the strongest model.")
|
| 52 |
-
# Assign a high weight to the strongest model (e.g.,
|
| 53 |
-
strongest_weight_share = 0.
|
| 54 |
self.base_weights = {strongest_model_id: strongest_weight_share}
|
| 55 |
remaining_models = [mid for mid in MODEL_REGISTRY.keys() if mid != strongest_model_id]
|
| 56 |
if remaining_models:
|
|
@@ -126,7 +124,7 @@ class ModelWeightManager:
|
|
| 126 |
"""Check if models agree on prediction"""
|
| 127 |
agent_logger.log("weight_optimization", "info", "Checking for consensus among model predictions.")
|
| 128 |
non_none_predictions = [p.get("Label") for p in predictions.values() if p is not None and isinstance(p, dict) and p.get("Label") is not None and p.get("Label") != "Error"]
|
| 129 |
-
agent_logger.
|
| 130 |
result = len(non_none_predictions) > 0 and len(set(non_none_predictions)) == 1
|
| 131 |
agent_logger.log("weight_optimization", "info", f"Consensus detected: {result}")
|
| 132 |
return result
|
|
@@ -135,7 +133,7 @@ class ModelWeightManager:
|
|
| 135 |
"""Check if models have conflicting predictions"""
|
| 136 |
agent_logger.log("weight_optimization", "info", "Checking for conflicts among model predictions.")
|
| 137 |
non_none_predictions = [p.get("Label") for p in predictions.values() if p is not None and isinstance(p, dict) and p.get("Label") is not None and p.get("Label") != "Error"]
|
| 138 |
-
agent_logger.
|
| 139 |
result = len(non_none_predictions) > 1 and len(set(non_none_predictions)) > 1
|
| 140 |
agent_logger.log("weight_optimization", "info", f"Conflicts detected: {result}")
|
| 141 |
return result
|
|
@@ -154,4 +152,4 @@ class ModelWeightManager:
|
|
| 154 |
return {} # No models registered
|
| 155 |
normalized = {k: v/total for k, v in weights.items()}
|
| 156 |
agent_logger.log("weight_optimization", "info", f"Weights normalized. Total sum: {sum(normalized.values()):.2f}")
|
| 157 |
-
return normalized
|
|
|
|
| 38 |
agent_logger.log("weight_optimization", "info", f"Combined context overrides: {combined_overrides}")
|
| 39 |
return combined_overrides
|
| 40 |
|
|
|
|
|
|
|
| 41 |
class ModelWeightManager:
|
| 42 |
def __init__(self, strongest_model_id: str = None):
|
| 43 |
agent_logger = AgentLogger()
|
|
|
|
| 47 |
if num_models > 0:
|
| 48 |
if strongest_model_id and strongest_model_id in MODEL_REGISTRY:
|
| 49 |
agent_logger.log("weight_optimization", "info", f"Designating '{strongest_model_id}' as the strongest model.")
|
| 50 |
+
# Assign a high weight to the strongest model (e.g., 40%)
|
| 51 |
+
strongest_weight_share = 0.4
|
| 52 |
self.base_weights = {strongest_model_id: strongest_weight_share}
|
| 53 |
remaining_models = [mid for mid in MODEL_REGISTRY.keys() if mid != strongest_model_id]
|
| 54 |
if remaining_models:
|
|
|
|
| 124 |
"""Check if models agree on prediction"""
|
| 125 |
agent_logger.log("weight_optimization", "info", "Checking for consensus among model predictions.")
|
| 126 |
non_none_predictions = [p.get("Label") for p in predictions.values() if p is not None and isinstance(p, dict) and p.get("Label") is not None and p.get("Label") != "Error"]
|
| 127 |
+
agent_logger.debug("weight_optimization", "info", f"Non-none predictions for consensus check: {non_none_predictions}")
|
| 128 |
result = len(non_none_predictions) > 0 and len(set(non_none_predictions)) == 1
|
| 129 |
agent_logger.log("weight_optimization", "info", f"Consensus detected: {result}")
|
| 130 |
return result
|
|
|
|
| 133 |
"""Check if models have conflicting predictions"""
|
| 134 |
agent_logger.log("weight_optimization", "info", "Checking for conflicts among model predictions.")
|
| 135 |
non_none_predictions = [p.get("Label") for p in predictions.values() if p is not None and isinstance(p, dict) and p.get("Label") is not None and p.get("Label") != "Error"]
|
| 136 |
+
agent_logger.debug("weight_optimization", "info", f"Non-none predictions for conflict check: {non_none_predictions}")
|
| 137 |
result = len(non_none_predictions) > 1 and len(set(non_none_predictions)) > 1
|
| 138 |
agent_logger.log("weight_optimization", "info", f"Conflicts detected: {result}")
|
| 139 |
return result
|
|
|
|
| 152 |
return {} # No models registered
|
| 153 |
normalized = {k: v/total for k, v in weights.items()}
|
| 154 |
agent_logger.log("weight_optimization", "info", f"Weights normalized. Total sum: {sum(normalized.values()):.2f}")
|
| 155 |
+
return normalized
|
app.py
CHANGED
|
@@ -188,7 +188,7 @@ def full_prediction(img, confidence_threshold, rotate_degrees, noise_level, shar
|
|
| 188 |
img = img.convert('RGB')
|
| 189 |
|
| 190 |
monitor_agent = EnsembleMonitorAgent()
|
| 191 |
-
weight_manager = ModelWeightManager(strongest_model_id="
|
| 192 |
optimization_agent = WeightOptimizationAgent(weight_manager)
|
| 193 |
health_agent = SystemHealthAgent()
|
| 194 |
context_agent = ContextualIntelligenceAgent()
|
|
@@ -679,4 +679,4 @@ with gr.Blocks() as app:
|
|
| 679 |
footer.render()
|
| 680 |
|
| 681 |
|
| 682 |
-
app.queue(max_size=10, default_concurrency_limit=
|
|
|
|
| 188 |
img = img.convert('RGB')
|
| 189 |
|
| 190 |
monitor_agent = EnsembleMonitorAgent()
|
| 191 |
+
weight_manager = ModelWeightManager(strongest_model_id="model_8")
|
| 192 |
optimization_agent = WeightOptimizationAgent(weight_manager)
|
| 193 |
health_agent = SystemHealthAgent()
|
| 194 |
context_agent = ContextualIntelligenceAgent()
|
|
|
|
| 679 |
footer.render()
|
| 680 |
|
| 681 |
|
| 682 |
+
app.queue(max_size=10, default_concurrency_limit=1).launch(mcp_server=True)
|