| """ |
| Deploy Model Graph - FIXED |
| |
| This module implements the model deployment workflow graph for the ComputeAgent. |
| |
| KEY FIX: DeployModelState now correctly inherits from AgentState (TypedDict) |
| instead of StateGraph. |
| |
| Author: ComputeAgent Team |
| License: Private |
| """ |
|
|
import logging
import os
import sys
from typing import Any, Dict, Optional

from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.graph import END, StateGraph
from langgraph.graph.state import CompiledStateGraph

from ComputeAgent.graph.graph_ReAct import ReactWorkflow
from ComputeAgent.graph.state import AgentState
from ComputeAgent.models.model_manager import ModelManager
from ComputeAgent.nodes.ReAct_DeployModel.capacity_approval import (
    auto_capacity_approval_node,
    capacity_approval_node,
)
from ComputeAgent.nodes.ReAct_DeployModel.capacity_estimation import capacity_estimation_node
from ComputeAgent.nodes.ReAct_DeployModel.extract_model_info import extract_model_info_node
from ComputeAgent.nodes.ReAct_DeployModel.generate_additional_info import generate_additional_info_node
from constant import Constants
|
|
| |
# Shared model manager, used by DeployModelAgent.create() to lazily load the
# default function-calling LLM when the caller does not supply one.
model_manager = ModelManager()

# NOTE: the original code created a "ComputeAgent" logger here and immediately
# shadowed it with "DeployModelGraph" below; the dead assignment was removed.
logger = logging.getLogger("DeployModelGraph")

# Resolve the MCP server entry point relative to the project root so the
# stdio transport can spawn it with the same interpreter running this module.
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
mcp_server_path = os.path.join(project_root, "Compute_MCP", "main.py")

python_executable = sys.executable

# MCP client that launches the Hive Compute tool server as a stdio subprocess.
mcp_client = MultiServerMCPClient(
    {
        "hivecompute": {
            "command": python_executable,
            "args": [mcp_server_path],
            "transport": "stdio",
            "env": {
                # Credentials / endpoint for the Hive Compute API.
                "HIVE_COMPUTE_DEFAULT_API_TOKEN": os.getenv("HIVE_COMPUTE_DEFAULT_API_TOKEN", ""),
                "HIVE_COMPUTE_BASE_API_URL": os.getenv("HIVE_COMPUTE_BASE_API_URL", "https://api.hivecompute.ai"),
                # Propagate PATH/PYTHONPATH so the subprocess can locate its
                # interpreter dependencies.
                "PATH": os.getenv("PATH", ""),
                "PYTHONPATH": os.getenv("PYTHONPATH", ""),
            },
        }
    }
)
|
|
|
|
| |
class DeployModelState(AgentState):
    """
    State schema for the deploy-model workflow graph.

    This subclass declares no fields of its own: every field it needs
    (query, response, current_step, messages, agent_decision,
    deployment_approved, model_name, llm, model_card, model_info,
    capacity_estimate, deployment_result, react_results, tool_calls,
    tool_results) is inherited from the ``AgentState`` TypedDict. It exists
    so the deployment ``StateGraph`` can be parameterized with a
    deployment-specific state type while staying structurally compatible
    with the parent agent's state.
    """
|
|
|
|
class DeployModelAgent:
    """
    Standalone Deploy Model Agent with a nested ReAct deployment subgraph.

    Workflow:
        extract_model_info
            -> generate_model_name (model missing/invalid) -> END
            -> capacity_estimation
                -> capacity_approval (human) | auto_capacity_approval | END
                    -> react_deployment -> END

    The agent exposes both async (``ainvoke``) and sync (``invoke``) entry
    points and can render its graph via ``draw_graph``.
    """

    def __init__(self, llm, react_tools):
        """
        Args:
            llm: Chat model used by the workflow nodes and the ReAct subgraph.
            react_tools: Tools made available to the nested ReactWorkflow.
        """
        self.llm = llm
        self.react_tools = react_tools
        self.react_subgraph = ReactWorkflow(llm=self.llm, tools=self.react_tools)
        # BUGFIX: ainvoke() reads self.checkpointer, but it was never assigned,
        # causing an AttributeError on every call. Default to no persistence.
        self.checkpointer = None
        self.graph = self._create_graph()

    @classmethod
    async def create(cls, llm=None, custom_tools=None):
        """
        Async factory method for DeployModelAgent.

        Args:
            llm: Optional pre-loaded LLM; the default function-calling model
                is loaded via the module-level ModelManager when omitted.
            custom_tools: Optional pre-loaded tools for the nested
                ReactWorkflow; fetched from the MCP client when omitted.

        Returns:
            A fully initialized DeployModelAgent instance.
        """
        if llm is None:
            llm = await model_manager.load_llm_model(Constants.DEFAULT_LLM_FC)

        if custom_tools is None:
            custom_tools = await mcp_client.get_tools()

        return cls(llm=llm, react_tools=custom_tools)

    def _create_graph(self) -> "CompiledStateGraph":
        """
        Create and compile the deploy-model workflow graph.

        Builds a StateGraph over DeployModelState (a TypedDict subclass of
        AgentState), wires the nodes and conditional routing, and returns
        the compiled graph.
        """
        workflow = StateGraph(DeployModelState)

        # Nodes; the ReAct subgraph is embedded as a compiled child graph.
        workflow.add_node("extract_model_info", extract_model_info_node)
        workflow.add_node("generate_model_name", generate_additional_info_node)
        workflow.add_node("capacity_estimation", capacity_estimation_node)
        workflow.add_node("capacity_approval", capacity_approval_node)
        workflow.add_node("auto_capacity_approval", auto_capacity_approval_node)
        workflow.add_node("react_deployment", self.react_subgraph.get_compiled_graph())

        workflow.set_entry_point("extract_model_info")

        # Valid model -> estimate capacity; otherwise produce suggestions.
        workflow.add_conditional_edges(
            "extract_model_info",
            self.should_validate_or_generate,
            {
                "generate_model_name": "generate_model_name",
                "capacity_estimation": "capacity_estimation",
            },
        )

        # Successful estimate -> human or automatic approval; failure -> END.
        workflow.add_conditional_edges(
            "capacity_estimation",
            self.should_continue_to_capacity_approval,
            {
                "capacity_approval": "capacity_approval",
                "auto_capacity_approval": "auto_capacity_approval",
                "end": END,
            },
        )

        # Approval outcome: deploy, re-estimate, or stop.
        workflow.add_conditional_edges(
            "capacity_approval",
            self.should_continue_after_capacity_approval,
            {
                "react_deployment": "react_deployment",
                "capacity_estimation": "capacity_estimation",
                "end": END,
            },
        )

        workflow.add_edge("auto_capacity_approval", "react_deployment")

        workflow.add_edge("generate_model_name", END)
        workflow.add_edge("react_deployment", END)

        return workflow.compile()

    def get_compiled_graph(self):
        """Return the compiled graph for embedding in a parent graph."""
        return self.graph

    def should_validate_or_generate(self, state: Dict[str, Any]) -> str:
        """
        Route after model extraction.

        Returns:
            "capacity_estimation" when a model name was extracted and its
            model_info payload is present and error-free;
            "generate_model_name" otherwise (produces a helpful response
            with suggestions).
        """
        model_info = state.get("model_info") or {}
        if state.get("model_name") and model_info and not model_info.get("error"):
            return "capacity_estimation"
        return "generate_model_name"

    def should_continue_to_capacity_approval(self, state: Dict[str, Any]) -> str:
        """
        Route after capacity estimation.

        Routing rules:
            - estimation did not succeed -> "end"
            - Constants.HUMAN_APPROVAL_CAPACITY == "true" -> "capacity_approval"
            - otherwise -> "auto_capacity_approval"

        Args:
            state: Current workflow state containing capacity estimation results.

        Returns:
            "capacity_approval", "auto_capacity_approval", or "end".
        """
        if state.get("capacity_estimation_status") != "success":
            logger.info("Capacity estimation failed - routing to end")
            return "end"

        # Constants.HUMAN_APPROVAL_CAPACITY is a string flag ("true"/"false").
        if Constants.HUMAN_APPROVAL_CAPACITY == "true":
            logger.info("HUMAN_APPROVAL_CAPACITY enabled - routing to human approval")
            return "capacity_approval"

        logger.info("HUMAN_APPROVAL_CAPACITY disabled - routing to auto-approval")
        return "auto_capacity_approval"

    def should_continue_after_capacity_approval(self, state: Dict[str, Any]) -> str:
        """
        Route after human capacity approval.

        Precedence: re-estimation request > explicit approval > explicit
        rejection > unexpected state (logged, treated as end).

        Returns:
            "react_deployment", "capacity_estimation", or "end".
        """
        logger.info("Routing after capacity approval:")
        logger.info(" - capacity_approved: %s", state.get("capacity_approved"))
        logger.info(" - needs_re_estimation: %s", state.get("needs_re_estimation"))
        logger.info(" - capacity_approval_status: %s", state.get("capacity_approval_status"))

        needs_re_estimation = state.get("needs_re_estimation")
        if needs_re_estimation is True:
            logger.info("Re-estimation requested - routing to capacity_estimation")
            return "capacity_estimation"

        capacity_approved = state.get("capacity_approved")
        if capacity_approved is True:
            logger.info("Capacity approved - proceeding to react_deployment")
            return "react_deployment"

        if capacity_approved is False:
            logger.info("Capacity rejected - ending workflow")
            return "end"

        # Identity checks above are deliberate: a missing/None flag falls
        # through here instead of being coerced by truthiness.
        logger.warning("Unexpected state in capacity approval routing")
        logger.warning(" capacity_approved: %s (type: %s)", capacity_approved, type(capacity_approved))
        logger.warning(" needs_re_estimation: %s (type: %s)", needs_re_estimation, type(needs_re_estimation))
        logger.warning(" Full state keys: %s", list(state.keys()))

        return "end"

    async def ainvoke(self,
                      query: str,
                      user_id: str = "default_user",
                      session_id: str = "default_session",
                      enable_memory: bool = False,
                      config: Optional[Dict] = None) -> Dict[str, Any]:
        """
        Asynchronously invoke the Deploy Model Agent workflow.

        Args:
            query: User's model deployment query.
            user_id: User identifier for memory management.
            session_id: Session identifier for memory management.
            enable_memory: Reserved flag; memory is only active when a
                checkpointer has been configured on the instance.
            config: Optional LangGraph config dict (may carry
                ``configurable.capacity_approved`` to pre-approve capacity).

        Returns:
            Final workflow state with deployment results.
        """
        initial_state = {
            # Base agent fields
            "query": query,
            "response": "",
            "current_step": "initialized",
            "messages": [],
            # Decision fields
            "agent_decision": "",
            "deployment_approved": False,
            # Deployment fields
            "model_name": "",
            "llm": None,
            "model_card": {},
            "model_info": {},
            "capacity_estimate": {},
            "deployment_result": {},
            # ReAct subgraph fields
            "react_results": {},
            "tool_calls": [],
            "tool_results": [],
        }

        # Allow a parent graph to pass a pre-made approval decision through.
        if config and "configurable" in config:
            if "capacity_approved" in config["configurable"]:
                initial_state["deployment_approved"] = config["configurable"]["capacity_approved"]
                logger.info("DeployModelAgent received approval: %s",
                            config["configurable"]["capacity_approved"])

        # Thread-scoped memory config, only when persistence is configured.
        memory_config = None
        if self.checkpointer:
            thread_id = f"{user_id}:{session_id}"
            memory_config = {"configurable": {"thread_id": thread_id}}

        # Merge the caller's configurable keys into the memory config; when
        # there is no memory config, the caller's config is used verbatim.
        final_config = memory_config or {}
        if config:
            if "configurable" in final_config:
                final_config["configurable"].update(config.get("configurable", {}))
            else:
                final_config = config

        logger.info("Starting Deploy Model workflow")

        if final_config:
            result = await self.graph.ainvoke(initial_state, final_config)
        else:
            result = await self.graph.ainvoke(initial_state)

        return result

    def invoke(self, query: str, user_id: str = "default_user", session_id: str = "default_session", enable_memory: bool = False) -> Dict[str, Any]:
        """
        Synchronously invoke the Deploy Model Agent workflow.

        Thin wrapper over ``ainvoke`` via ``asyncio.run``; do not call from
        inside an already-running event loop.

        Args:
            query: User's model deployment query.
            user_id: User identifier for memory management.
            session_id: Session identifier for memory management.
            enable_memory: Reserved flag (see ``ainvoke``).

        Returns:
            Final workflow state with deployment results.
        """
        import asyncio
        return asyncio.run(self.ainvoke(query, user_id, session_id, enable_memory))

    def draw_graph(self, output_file_path: str = "deploy_model_graph.png"):
        """
        Generate and save a visual representation of the workflow graph.

        Args:
            output_file_path: Path where the graph PNG file is written.
        """
        try:
            self.graph.get_graph().draw_mermaid_png(output_file_path=output_file_path)
            logger.info("Graph visualization saved to: %s", output_file_path)
        except Exception as e:
            # Rendering requires optional tooling; failure is non-fatal.
            logger.error("Failed to generate graph visualization: %s", e)
            print(f"Error generating graph: {e}")