Spaces:
Sleeping
Sleeping
| """ | |
| DevOps Agent — LLM-based terminal troubleshooting agent. | |
| Wraps a fine-tunable LLM (or rule-based fallback) to generate shell | |
| commands from error observations. Supports both Unsloth/HuggingFace | |
| models and a deterministic rule-based baseline for testing. | |
| """ | |
| from __future__ import annotations | |
| import re | |
| from typing import Any, Dict, List, Optional | |
| from agent.prompts import format_chat_messages, format_prompt | |
class DevOpsAgent:
    """LLM-powered DevOps troubleshooting agent.

    Generates shell commands to fix broken environments based on
    error logs and command history. Supports fine-tuned LLM mode
    and rule-based fallback mode.

    Usage:
        # Rule-based mode (no GPU needed)
        agent = DevOpsAgent(model_name="rule-based")
        cmd = agent.act(observation)

        # LLM mode
        agent = DevOpsAgent(model_name="unsloth/llama-3.2-3b-instruct")
        cmd = agent.act(observation)
    """

    # Placeholder returned when no usable command can be extracted from
    # the model output.
    _NO_COMMAND = "echo 'no command generated'"

    def __init__(
        self,
        model_name: str = "rule-based",
        use_lora: bool = True,
        max_new_tokens: int = 64,
        temperature: float = 0.7,
        device: str = "auto",
        model: Optional[Any] = None,
        tokenizer: Optional[Any] = None,
        auto_load: bool = True,
    ) -> None:
        """Initialize the agent.

        Args:
            model_name: HuggingFace model ID or 'rule-based' for baseline.
            use_lora: Whether to use LoRA adapters.
            max_new_tokens: Maximum tokens to generate.
            temperature: Sampling temperature. A non-positive value
                selects greedy decoding.
            device: Device to load model on ('auto', 'cuda', 'cpu').
            model: Optional preloaded model instance.
            tokenizer: Optional preloaded tokenizer instance.
            auto_load: Whether to auto-load model when model_name is not
                rule-based.
        """
        self.model_name = model_name
        self.use_lora = use_lora
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.device = device
        self._model = model
        self._tokenizer = tokenizer
        # Only a complete (model, tokenizer) pair counts as loaded.
        self._is_loaded = self._model is not None and self._tokenizer is not None

        if model_name != "rule-based" and auto_load and not self._is_loaded:
            self._load_model()

    def _load_model(self) -> None:
        """Load the LLM model and tokenizer.

        Tries Unsloth first. If Unsloth cannot be imported, falls back to
        plain transformers; if that fails too, the agent switches itself
        to rule-based mode by resetting ``model_name``.
        """
        # NOTE(review): only ImportError triggers the fallback below. Any
        # other error raised while loading through Unsloth propagates to
        # the caller -- confirm that this asymmetry is intended.
        try:
            from unsloth import FastLanguageModel

            self._model, self._tokenizer = FastLanguageModel.from_pretrained(
                model_name=self.model_name,
                max_seq_length=2048,
                load_in_4bit=True,
                dtype=None,
            )
            if self.use_lora:
                self._model = FastLanguageModel.get_peft_model(
                    self._model,
                    r=16,
                    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                                    "gate_proj", "up_proj", "down_proj"],
                    lora_alpha=16,
                    lora_dropout=0,
                    bias="none",
                    use_gradient_checkpointing="unsloth",
                )
            FastLanguageModel.for_inference(self._model)
            self._is_loaded = True
        except ImportError:
            print("[DevOpsAgent] Unsloth not available. Falling back to transformers.")
            try:
                from transformers import AutoModelForCausalLM, AutoTokenizer

                self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
                self._model = AutoModelForCausalLM.from_pretrained(
                    self.model_name, device_map=self.device,
                )
                self._is_loaded = True
            except Exception as e:
                # Deliberate best-effort: degrade to the rule-based policy
                # instead of failing construction.
                print(f"[DevOpsAgent] Failed to load model: {e}. Using rule-based fallback.")
                self.model_name = "rule-based"

    def act(self, observation: Dict[str, Any]) -> str:
        """Generate a shell command from the current observation.

        Args:
            observation: Dict with error_log, command_history, error_type, etc.

        Returns:
            Shell command string.
        """
        if self.model_name == "rule-based":
            return self._rule_based_act(observation)
        return self._llm_act(observation)

    def _llm_act(self, observation: Dict[str, Any]) -> str:
        """Generate command using the LLM.

        Falls back to the rule-based policy when no model/tokenizer pair
        is available.

        Args:
            observation: Dict with error_log, command_history, error_type.

        Returns:
            Shell command string extracted from the model output.
        """
        # Check availability first so the prompt is not built for nothing.
        if self._tokenizer is None or self._model is None:
            return self._rule_based_act(observation)

        messages = format_chat_messages(
            error_log=observation.get("error_log", ""),
            error_type=observation.get("error_type", "unknown"),
            command_history=observation.get("command_history", []),
        )
        inputs = self._tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True,
            return_tensors="pt",
        ).to(self._model.device)

        generate_kwargs: Dict[str, Any] = {
            "input_ids": inputs,
            "max_new_tokens": self.max_new_tokens,
        }
        if self.temperature and self.temperature > 0:
            generate_kwargs.update(
                do_sample=True, temperature=self.temperature, top_p=0.9,
            )
        else:
            # Sampling requires a strictly positive temperature; treat
            # temperature <= 0 as a request for greedy decoding.
            generate_kwargs["do_sample"] = False
        outputs = self._model.generate(**generate_kwargs)

        # Decode only the newly generated tokens, not the echoed prompt.
        response = self._tokenizer.decode(
            outputs[0][inputs.shape[-1]:], skip_special_tokens=True,
        ).strip()
        return self._extract_command(response)

    def _extract_command(self, response: str) -> str:
        """Extract a clean shell command from LLM output.

        Strips markdown formatting and explanations, then returns the
        first line that still contains a command after cleaning.

        Args:
            response: Raw LLM output.

        Returns:
            Clean shell command string, or a placeholder ``echo`` command
            when nothing usable is found.
        """
        # Remove markdown code fences (with an optional language tag).
        text = re.sub(r'```\w*\n?', '', response or "")
        # A pure lead-in line ("Here is the fix:") cleans down to nothing,
        # so keep scanning instead of giving up on the first line.
        for line in text.splitlines():
            command = self._clean_command_line(line)
            if command:
                return command
        return self._NO_COMMAND

    @staticmethod
    def _clean_command_line(line: str) -> str:
        """Strip prompt markers, numbering, labels and comments from a line."""
        command = line.strip().strip('`').strip()
        previous = None
        # Prefixes can be stacked ("1. $ cmd", "Command: `$ cmd`"), so
        # repeat until nothing more comes off.
        while command != previous:
            previous = command
            command = re.sub(r'^[\$#>\s]+', '', command)
            command = re.sub(r'^\d+[\.)]\s*', '', command)
            # Only treat "words:" as a label when the colon is followed by
            # whitespace or ends the line; otherwise "curl http://host"
            # would lose its "curl http:" prefix.
            command = re.sub(r'^[A-Za-z][A-Za-z0-9\s]*:(?:\s+|$)', '', command)
            command = command.lstrip('`')
        # Drop a trailing shell comment.
        command = re.sub(r'\s+#.*$', '', command)
        return command.strip().strip('`').strip()

    def _rule_based_act(self, observation: Dict[str, Any]) -> str:
        """Generate command using rule-based heuristics.

        This serves as both a baseline for comparison and a fallback
        when no LLM is available.

        Args:
            observation: Dict with error_log, command_history, error_type.
                Missing keys and explicit ``None`` values are treated as
                empty.

        Returns:
            Shell command string.
        """
        # `or` (not a .get default) also covers keys present with None.
        error_log = observation.get("error_log") or ""
        error_type = observation.get("error_type") or "unknown"
        history = observation.get("command_history") or []

        handlers = {
            "missing_package": self._handle_missing_package,
            "port_conflict": self._handle_port_conflict,
            "missing_env": self._handle_missing_env,
            "version_conflict": self._handle_version_conflict,
            "syntax_error": self._handle_syntax_error,
            "config_error": self._handle_config_error,
            "file_not_found": self._handle_file_not_found,
            "service_not_running": self._handle_service_not_running,
        }
        # Non-string error types cannot name a handler (and may be
        # unhashable), so route them to the generic handler.
        if isinstance(error_type, str):
            handler = handlers.get(error_type, self._handle_unknown)
        else:
            handler = self._handle_unknown
        return handler(error_log, history)

    def _handle_missing_package(self, error_log: str, history: List[str]) -> str:
        """Handle missing package errors.

        Installs the module named in the log with ``pip``, retrying with
        ``pip3`` when the exact ``pip`` command is already in the history.
        """
        match = re.search(r"No module named ['\"]?(\w+)", error_log)
        if match:
            module = match.group(1)
            cmd = f"pip install {module}"
            if cmd not in history:
                return cmd
            return f"pip3 install {module}"
        match = re.search(r"ModuleNotFoundError.*?['\"](\w+)", error_log)
        if match:
            return f"pip install {match.group(1)}"
        return "pip install -r requirements.txt"

    def _handle_port_conflict(self, error_log: str, history: List[str]) -> str:
        """Handle port conflict errors.

        First kills whatever holds the port (default 5000 when the log
        names none), then restarts the server.
        """
        match = re.search(r"port\s+(\d+)", error_log, re.IGNORECASE)
        port = match.group(1) if match else "5000"
        if not any("lsof" in cmd or "kill" in cmd for cmd in history):
            return f"lsof -t -i:{port} | xargs kill -9"
        return "python /app/server.py &"

    def _handle_missing_env(self, error_log: str, history: List[str]) -> str:
        """Handle missing environment variable errors.

        Exports the variable named in the ``KeyError`` (with a default
        value), then re-runs the application once it has been exported.
        """
        match = re.search(r"KeyError:\s*['\"](\w+)['\"]", error_log)
        if not match:
            return "env"
        var_name = match.group(1)
        # Look for an export of THIS variable; an export of some other
        # variable must not suppress it.
        export_prefix = f"export {var_name}="
        if any(export_prefix in cmd for cmd in history):
            return "python /app/db_app.py"
        defaults = {
            "DATABASE_URL": "postgresql://localhost:5432/mydb",
            "SECRET_KEY": "dev-secret-key-12345",
            "API_KEY": "test-api-key",
        }
        value = defaults.get(var_name, "placeholder_value")
        return f"export {var_name}={value}"

    def _handle_version_conflict(self, error_log: str, history: List[str]) -> str:
        """Handle version conflict errors.

        Relaxes the pinned requirement with ``sed`` first, then
        reinstalls the requirements file.
        """
        if any("sed" in cmd for cmd in history):
            return "pip install -r /app/requirements.txt"
        match = re.search(r"requested\s+(\w+)==(\S+)", error_log)
        if match:
            pkg = match.group(1)
            return f"sed -i 's/{pkg}==.*/{pkg}>=0/' /app/requirements.txt"
        return "sed -i 's/werkzeug==1.0.0/werkzeug>=2.3.0/' /app/requirements.txt"

    def _handle_syntax_error(self, error_log: str, history: List[str]) -> str:
        """Handle Python syntax errors.

        Re-runs the failing file with ``python3`` when the log points at
        a Python 2 / shebang problem; otherwise runs the main script.
        """
        if "python2" in error_log or "python3 shebang" in error_log.lower():
            match = re.search(r'File "([^"]+)"', error_log)
            if match:
                return f"python3 {match.group(1)}"
        return "python3 /app/main.py"

    def _handle_config_error(self, error_log: str, history: List[str]) -> str:
        """Handle configuration errors.

        For bind-address problems: patch the config, free port 8080,
        then restart the server. For NameError/INVALID problems: inspect
        the file named in the traceback.
        """
        if "127.0.0.1" in error_log or "binding" in error_log.lower():
            if not any("sed" in cmd for cmd in history):
                return "sed -i 's/127.0.0.1/0.0.0.0/' /app/config.py"
            if not any("kill" in cmd for cmd in history):
                return "kill $(lsof -t -i:8080) 2>/dev/null; true"
            return "python /app/server.py &"
        if "NameError" in error_log or "INVALID" in error_log:
            match = re.search(r'File "([^"]+)"', error_log)
            if match:
                filepath = match.group(1)
                # NOTE(review): this waits for a "cat >" (file rewrite) in
                # the history, which this rule-based policy never emits
                # itself -- confirm the intended progression.
                if not any("cat >" in cmd for cmd in history):
                    return f"cat {filepath}"
                return "python /app/migrate.py"
        return "cat /app/config.py"

    def _handle_file_not_found(self, error_log: str, history: List[str]) -> str:
        """Handle file not found errors.

        For broken virtualenvs: remove it, recreate it, then reinstall.
        Otherwise list the missing path named in the log, or the app
        directory when no path can be found.
        """
        if "venv" in error_log or "bad interpreter" in error_log:
            # Whole-word match so commands that merely contain the
            # letters "rm" do not count as a removal.
            if not any(re.search(r"\brm\b", cmd) for cmd in history):
                return "rm -rf /app/venv"
            if not any("venv" in cmd and "python3" in cmd for cmd in history):
                return "python3 -m venv /app/venv"
            return "source /app/venv/bin/activate && pip install flask"
        # Exclude quotes from the captured path: the usual message is
        # "No such file or directory: '/path'", and capturing the closing
        # quote would yield a shell command with an unterminated quote.
        match = re.search(r"No such file.*?['\"]?(/[^\s'\"]+)", error_log)
        if match:
            return f"ls -la {match.group(1)}"
        return "ls -la /app/"

    def _handle_service_not_running(self, error_log: str, history: List[str]) -> str:
        """Handle service not running errors.

        Starts the server on the refused port (default 8080); otherwise
        lists running Python processes.
        """
        if "Connection refused" in error_log:
            match = re.search(r"port\s+(\d+)", error_log, re.IGNORECASE)
            port = match.group(1) if match else "8080"
            return f"python /app/server.py --port {port} &"
        return "ps aux | grep python"

    def _handle_unknown(self, error_log: str, history: List[str]) -> str:
        """Handle unclassified errors.

        Inspects the application sources on the first step; afterwards
        emits a no-op echo.
        """
        if not history:
            return "cat /app/*.py 2>/dev/null || ls -la /app/"
        return "echo 'Analyzing error...'"

    def format_prompt(self, observation: Dict[str, Any]) -> str:
        """Build the prompt string from an observation dict.

        Args:
            observation: Environment observation dict.

        Returns:
            Formatted prompt string for the LLM.
        """
        # Resolves to the module-level format_prompt import, not this
        # method: method names are not in scope inside a method body.
        return format_prompt(
            error_log=observation.get("error_log", ""),
            error_type=observation.get("error_type", "unknown"),
            command_history=observation.get("command_history", []),
        )

    def load_checkpoint(self, checkpoint_path: str) -> None:
        """Load a fine-tuned model checkpoint.

        Failures are reported on stdout rather than raised.

        Args:
            checkpoint_path: Path to the saved model/adapter.
        """
        if self.model_name == "rule-based":
            print("[DevOpsAgent] Cannot load checkpoint for rule-based agent.")
            return
        if self._model is None:
            # Without a base model there is nothing to attach the adapter
            # to; do not report a load that never happened.
            print("[DevOpsAgent] Cannot load checkpoint: no base model is loaded.")
            return
        try:
            from peft import PeftModel

            self._model = PeftModel.from_pretrained(self._model, checkpoint_path)
            print(f"[DevOpsAgent] Loaded checkpoint from {checkpoint_path}")
        except Exception as e:
            print(f"[DevOpsAgent] Failed to load checkpoint: {e}")