| """
|
| JaneGPT v2 Intent Classifier — Inference Wrapper
|
|
|
| Simple interface for intent classification.
|
|
|
| Usage:
|
| from model.classifier import JaneGPTClassifier
|
|
|
| classifier = JaneGPTClassifier()
|
| intent, confidence = classifier.predict("turn up the volume")
|
|
|
| Created by Ravindu Senanayake
|
| """
|
|
|
| from pathlib import Path
|
| from typing import Optional, Dict, Tuple, List
|
|
|
| import torch
|
|
|
| from model.architecture import JaneGPTv2Classifier, ID_TO_INTENT, INTENT_LABELS
|
|
|
|
|
class JaneGPTClassifier:
    """
    Ready-to-use intent classifier.

    Loads the trained model and tokenizer, provides simple
    predict() interface for intent classification.

    Args:
        model_path: Path to trained checkpoint (.pt file)
        tokenizer_path: Path to BPE tokenizer (.json file)
        device: "auto", "cuda", or "cpu"

    Raises:
        FileNotFoundError: If the checkpoint or tokenizer file is missing.
    """

    # Every input is truncated/right-padded to exactly this many tokens.
    MAX_LEN = 128
    # Token id used to right-pad short sequences.
    PAD_ID = 0

    def __init__(
        self,
        model_path: str = "weights/janegpt_v2_classifier.pt",
        tokenizer_path: str = "weights/tokenizer.json",
        device: str = "auto",
    ):
        self.model_path = Path(model_path)
        self.tokenizer_path = Path(tokenizer_path)
        self.is_ready = False

        # "auto" prefers GPU when available; otherwise honor the explicit choice.
        if device == "auto":
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        self.tokenizer = None
        self.model = None
        self.id_to_intent = ID_TO_INTENT

        self._load()

    def _load(self):
        """Load the BPE tokenizer and the model checkpoint onto self.device."""
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model not found: {self.model_path}")

        if not self.tokenizer_path.exists():
            raise FileNotFoundError(f"Tokenizer not found: {self.tokenizer_path}")

        # Imported lazily so the `tokenizers` dependency is only required
        # once a classifier is actually constructed.
        from tokenizers import Tokenizer
        self.tokenizer = Tokenizer.from_file(str(self.tokenizer_path))

        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects (needed here for the embedded config dict) — only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(
            self.model_path, map_location=self.device, weights_only=False
        )

        # Fall back to the architecture's defaults for any hyperparameter
        # missing from the checkpoint's saved config.
        config = checkpoint.get('config', {})

        self.model = JaneGPTv2Classifier(
            vocab_size=config.get('vocab_size', 8192),
            embed_dim=config.get('embed_dim', 256),
            num_heads=config.get('num_heads', 8),
            num_kv_heads=config.get('num_kv_heads', 4),
            num_layers=config.get('num_layers', 8),
            ff_hidden=config.get('ff_hidden', 672),
            max_seq_len=config.get('max_seq_len', 256),
            dropout=config.get('dropout', 0.1),
            rope_theta=config.get('rope_theta', 10000.0),
        )

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()  # inference mode: disables dropout
        self.is_ready = True

    def _format_input(self, text: str, context: Optional[Dict] = None) -> str:
        """Format input into the prompt template the model was trained on.

        Args:
            text: Raw user utterance.
            context: Optional dict; only the 'last_intent' key is used.

        Returns:
            Prompt string "user: ...\\ncontext: ...\\njane:".
        """
        if context and context.get('last_intent'):
            ctx_str = f"last_action={context['last_intent']}"
        else:
            ctx_str = "none"

        return f"user: {text}\ncontext: {ctx_str}\njane:"

    def _tokenize(self, text: str) -> torch.Tensor:
        """Tokenize and truncate/pad to MAX_LEN.

        Returns:
            Long tensor of shape (1, MAX_LEN) on self.device.
        """
        ids = self.tokenizer.encode(text).ids

        if len(ids) > self.MAX_LEN:
            # Keep the prefix; anything past MAX_LEN is dropped.
            ids = ids[:self.MAX_LEN]
        else:
            ids = ids + [self.PAD_ID] * (self.MAX_LEN - len(ids))

        return torch.tensor([ids], dtype=torch.long, device=self.device)

    def predict(
        self,
        text: str,
        context: Optional[Dict] = None
    ) -> Tuple[str, float]:
        """
        Predict intent for given text.

        Args:
            text: User utterance (e.g., "turn up the volume")
            context: Optional dict with 'last_intent' key

        Returns:
            Tuple of (intent_label, confidence)

        Raises:
            RuntimeError: If the model failed to load.

        Example:
            >>> classifier.predict("open chrome")
            ('app_launch', 0.981)
        """
        if not self.is_ready:
            raise RuntimeError("Model not loaded")

        formatted = self._format_input(text, context)
        input_ids = self._tokenize(formatted)

        # model.predict is assumed to return (class_index, confidence) —
        # defined in model.architecture; confirm the contract there.
        predicted_idx, confidence = self.model.predict(input_ids)
        # Unknown indices default to 'chat' as a safe catch-all intent.
        intent = self.id_to_intent.get(predicted_idx, 'chat')

        return intent, confidence

    def predict_top_k(
        self,
        text: str,
        context: Optional[Dict] = None,
        k: int = 3
    ) -> List[Tuple[str, float]]:
        """
        Get top-k predictions with confidences.

        Args:
            text: User utterance
            context: Optional context dict
            k: Number of top predictions to return (clamped to the number
               of intent classes)

        Returns:
            List of (intent_label, confidence) tuples

        Raises:
            RuntimeError: If the model failed to load.

        Example:
            >>> classifier.predict_top_k("play something", k=3)
            [('media_play', 0.85), ('browser_search', 0.08), ('chat', 0.03)]
        """
        if not self.is_ready:
            raise RuntimeError("Model not loaded")

        formatted = self._format_input(text, context)
        input_ids = self._tokenize(formatted)

        with torch.no_grad():
            logits, _ = self.model(input_ids)
            probs = torch.softmax(logits, dim=-1)
            # Clamp k: topk() raises a RuntimeError if asked for more
            # entries than there are classes.
            k = min(k, probs.size(-1))
            top_probs, top_indices = probs.topk(k, dim=-1)

        return [
            (self.id_to_intent.get(idx.item(), 'chat'), prob.item())
            for prob, idx in zip(top_probs[0], top_indices[0])
        ]

    @staticmethod
    def get_supported_intents() -> List[str]:
        """Get list of all supported intent labels (defensive copy)."""
        return INTENT_LABELS.copy()

    def __repr__(self):
        return (
            f"JaneGPTClassifier("
            f"ready={self.is_ready}, "
            f"device={self.device}, "
            f"intents={len(INTENT_LABELS)})"
        )