import logging
from typing import Any, Dict, List, Optional, Text
from urllib.parse import urlsplit, urlunsplit

import requests

from rasa.nlu.classifiers.classifier import IntentClassifier
from rasa.nlu.config import RasaNLUModelConfig
from rasa.nlu.model import Metadata
from rasa.shared.nlu.constants import INTENT, TEXT
from rasa.shared.nlu.training_data.message import Message
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
# Module-scoped logger; its name follows this module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
class LlmIntentClassifier(IntentClassifier):
    """Delegates intent classification to an external HTTP micro-service.

    On construction the component posts its model configuration
    (``model_name``, ``base_url``, ``class_set``, ``prompt_template``) to a
    ``/config`` endpoint derived from ``classifier_url``. At inference time
    each message text is posted to ``classifier_url`` and the string in the
    response's ``result`` field becomes the predicted intent.

    All network interaction is best effort: failures are logged as warnings
    and never raised into the pipeline.
    """

    name = "LlmIntentClassifier"
    defaults = {
        # Endpoint that classifies one message; the "/config" URL is derived from it.
        "classifier_url": "http://classifier:8000/classify",
        # Per-request timeout (seconds) passed to ``requests.post``.
        "timeout": 5.0,
        # The four options below are required (validated in ``__init__``) and
        # forwarded to the remote service.
        "model_name": None,
        "base_url": None,
        "class_set": [],
        "prompt_template": None,
    }

    def __init__(
        self,
        component_config: Optional[Dict[Text, Any]] = None,
    ) -> None:
        """Read and validate the configuration, then configure the service.

        Args:
            component_config: Component options; see ``defaults``.

        Raises:
            ValueError: If any of ``model_name``, ``base_url``, ``class_set``
                or ``prompt_template`` is missing or empty.
        """
        super().__init__(component_config or {})

        self.url: Text = self.component_config.get("classifier_url")
        self.timeout: float = float(self.component_config.get("timeout"))
        self.model_name: Optional[Text] = self.component_config.get("model_name")
        self.base_url: Optional[Text] = self.component_config.get("base_url")
        self.class_set: List[Text] = self.component_config.get("class_set", [])
        self.prompt_template: Optional[Text] = self.component_config.get(
            "prompt_template"
        )

        # Insertion order keeps the error message in a stable, documented order.
        required: Dict[Text, Any] = {
            "model_name": self.model_name,
            "base_url": self.base_url,
            "class_set": self.class_set,
            "prompt_template": self.prompt_template,
        }
        missing: List[Text] = [key for key, value in required.items() if not value]
        if missing:
            raise ValueError(
                f"Missing configuration for {', '.join(missing)} in LlmIntentClassifier"
            )

        self._configure_remote_classifier()

    @staticmethod
    def _derive_config_url(classify_url: Optional[Text]) -> Optional[Text]:
        """Return the ``/config`` endpoint that sits next to ``classify_url``.

        Only the last ``/classify`` occurrence in the URL *path* is replaced,
        so hosts, ports or query strings that happen to contain that text are
        left untouched. Returns ``None`` when the URL is empty, unparsable, or
        its path contains no ``/classify``.
        """
        if not isinstance(classify_url, str) or not classify_url:
            return None
        try:
            parts = urlsplit(classify_url)
        except ValueError:
            # e.g. malformed IPv6 host brackets
            return None
        head, sep, tail = parts.path.rpartition("/classify")
        if not sep:
            return None
        return urlunsplit(parts._replace(path=f"{head}/config{tail}"))

    def _configure_remote_classifier(self) -> None:
        """Send configuration to the classifier backend to initialize the model.

        Best effort: every failure is logged as a warning and swallowed, so an
        unavailable service does not prevent the component from being built.
        """
        config_url = self._derive_config_url(self.url)
        if config_url is None:
            # Without a "/classify" path segment there is no sibling endpoint
            # to target; posting the config to the classify URL would be wrong.
            logger.warning(
                "Failed to initialize remote classifier: cannot derive a "
                "'/config' endpoint from classifier_url %r",
                self.url,
            )
            return

        payload: Dict[Text, Any] = {
            "model_name": self.model_name,
            "base_url": self.base_url,
            "class_set": self.class_set,
            "prompt_template": self.prompt_template,
        }
        logger.debug("Sending classifier config to: %s", config_url)
        try:
            response = requests.post(config_url, json=payload, timeout=self.timeout)
            response.raise_for_status()
        except (requests.RequestException, TypeError, ValueError) as err:
            # TypeError/ValueError keep the best-effort contract if a config
            # value cannot be JSON-encoded.
            logger.warning("Failed to initialize remote classifier: %s", err)
            return
        logger.info("Remote classifier initialized successfully.")

    def train(
        self,
        training_data: TrainingData,
        config: Optional[RasaNLUModelConfig] = None,
        **kwargs: Any,
    ) -> None:
        """No-op: classification is performed entirely by the remote service."""

    def process(self, message: Message, **kwargs: Any) -> None:
        """Classify ``message`` remotely and set its ``intent`` attribute.

        The intent is ``{"name": <result>, "confidence": 1.0}`` on success and
        ``{"name": None, "confidence": 0.0}`` when the message has no text or
        the service gives no usable answer.
        """
        intent_name = self._request_intent(message.get(TEXT))
        confidence = 1.0 if intent_name is not None else 0.0
        message.set(
            INTENT,
            {"name": intent_name, "confidence": confidence},
            add_to_output=True,
        )

    def _request_intent(self, text: Optional[Text]) -> Optional[Text]:
        """Return the intent name predicted for ``text``, or ``None``.

        ``None`` is returned when ``text`` is empty, the HTTP call fails, the
        body is not a JSON object, or its ``result`` is not a non-blank string.
        """
        if not text:
            return None
        try:
            response = requests.post(
                self.url, json={"message": text}, timeout=self.timeout
            )
            response.raise_for_status()
            body = response.json()
        except (requests.RequestException, ValueError) as err:
            # ValueError covers malformed JSON bodies from ``response.json()``.
            logger.warning("LlmIntentClassifier HTTP error: %s", err)
            return None

        result = body.get("result") if isinstance(body, dict) else None
        if not isinstance(result, str) or not result.strip():
            logger.warning(
                "LlmIntentClassifier: response has no usable 'result' field"
            )
            return None
        # Surrounding whitespace (e.g. a trailing newline from an LLM) would
        # never match an intent name.
        return result.strip()

    def persist(
        self,
        file_name: Text,
        model_dir: Text,
    ) -> Optional[Dict[Text, Any]]:
        """Return the settings needed to rebuild this component.

        Nothing is written to ``model_dir``. The returned dict is presumably
        merged into the component metadata by Rasa and handed back to
        ``load`` as ``meta`` — confirm against the Rasa version in use.
        """
        return {
            "classifier_url": self.url,
            "timeout": self.timeout,
            "model_name": self.model_name,
            "base_url": self.base_url,
            "class_set": self.class_set,
            "prompt_template": self.prompt_template,
        }

    @classmethod
    def load(
        cls,
        meta: Dict[Text, Any],
        model_dir: Text,
        model_metadata: Optional[Metadata] = None,
        cached_component: Optional["LlmIntentClassifier"] = None,
        **kwargs: Any,
    ) -> "LlmIntentClassifier":
        """Rebuild the component from ``meta``, reusing a cached instance.

        Reusing ``cached_component`` avoids re-sending the remote
        configuration for an already initialised component.
        """
        if cached_component is not None:
            return cached_component
        return cls(meta)
|
|