# Business_Chatbot / src / intent_classifier.py
# Ancastal's picture
# Upload folder using huggingface_hub
# 401b16c verified
import os
import re
from enum import Enum
from typing import Dict, Any, Optional, Tuple

import dirtyjson as json
import openai
from pydantic import BaseModel
class IntentType(str, Enum):
    """Closed set of intents a user message can be classified into.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase wire value (the same strings the classifier prompt asks the
    model to return).
    """

    # Recording a business transaction (purchase or sale).
    TRANSACTION = "transaction"
    # Retrieving or analyzing data from the database.
    QUERY = "query"
    # Finding similar events or contextual information.
    SEMANTIC_SEARCH = "semantic_search"
    # Storing general business information or notes.
    GENERAL_INFO = "general_info"
class IntentResult(BaseModel):
    """Outcome of classifying a single user message."""

    # The intent the message was assigned to.
    intent: IntentType
    # Score attached to the classification; the keyword fallback in this
    # file reports 0.5 or 0.6.
    confidence: float
    # Short explanation of why this intent was chosen (or why the fallback
    # path was taken).
    reasoning: str
    # Optional free-text note about detected entities; None when absent.
    entities_hint: Optional[str] = None
class IntentClassifier:
    """Classify chatbot messages into an ``IntentType`` using the OpenAI chat API.

    When the API call fails, or its reply cannot be parsed into a valid
    result, a simple keyword heuristic is used instead, so
    ``classify_intent`` always returns an ``IntentResult``.
    """

    # "from" / "to" are only meaningful as standalone words. Matching them as
    # raw substrings fires inside words such as "total", "customers", "stock"
    # or "tomorrow", which would send almost every message down the
    # transaction branch of the fallback.
    _PREPOSITION_PATTERN = re.compile(r"\b(?:from|to)\b")

    # Instructions sent as the system message on every classification call.
    _SYSTEM_PROMPT = """You are an expert intent classifier for a business chatbot that handles sales, purchases, and general information storage.
Given a user message, classify it into one of these intents:
1. **QUERY**: User wants to retrieve or analyze STRUCTURED data from SQL database tables
- Examples: "How many USB drives did we buy?" (counts from purchases table)
- Examples: "What's the total value of all sales?" (sum from sales table)
- Examples: "Show me recent transactions" (list from transactions table)
- Examples: "List all customers" (data from customers table)
- Key indicators: Asking for counts, totals, lists, recent data from business transactions
- Must be answerable from structured database tables (purchases, sales, customers, suppliers, products)
2. **SEMANTIC_SEARCH**: User wants to find contextual information, tasks, or unstructured data
- Examples: "What does Mark need to do?" (searching for task/context info)
- Examples: "Find events related to supplier meetings" (contextual search)
- Examples: "When do I have the meeting with George?" (calendar/scheduling info)
- Examples: "Show me similar purchases to this one" (similarity search)
- Examples: "What did we discuss in the last meeting?" (meeting notes/context)
- Key indicators: Questions about tasks, meetings, discussions, or contextual information
- Information that would NOT be in structured database tables
3. **TRANSACTION**: User wants to record a business transaction (purchase or sale)
- Examples: "Add a purchase of 20 USB drives from TechMart at €5 each"
- Examples: "Sold 10 laptops to John Smith at €800 each"
- Contains: product names, quantities, suppliers/customers, prices
- Action: Recording new business data
4. **GENERAL_INFO**: User wants to store general business information or notes
- It cannot be a question.
- Examples: "Meeting with new supplier scheduled for next week"
- Examples: "Remember to check inventory levels before next order"
- Examples: "Mark needs to call the supplier tomorrow"
- Contains: notes, reminders, general business information, task assignments
Return your response in this exact JSON format:
{
"intent": "transaction|query|semantic_search|general_info",
"confidence": 0.0-1.0,
"reasoning": "Brief explanation of why you chose this intent",
"entities_hint": "Optional: Key entities you detected (for transaction intent)"
}
Be precise and consider context carefully. If unsure, choose the most likely intent and indicate lower confidence."""

    def __init__(self, api_key: Optional[str] = None, model: str = "gpt-4o-mini"):
        """Initialize the OpenAI client used for intent classification.

        Args:
            api_key: OpenAI API key. When omitted (or empty), the
                ``OPENAI_API_KEY`` environment variable is used.
            model: Chat model name passed to the completions endpoint.
        """
        self.model = model
        self.client = openai.OpenAI(
            api_key=api_key or os.getenv('OPENAI_API_KEY')
        )

    @staticmethod
    def _strip_code_fences(text: str) -> str:
        """Remove a surrounding Markdown code fence (``` or ```json) from *text*."""
        text = text.strip()
        text = text.removeprefix("```json").removeprefix("```").removesuffix("```")
        return text.strip()

    def classify_intent(self, user_message: str) -> IntentResult:
        """Classify a user message with the OpenAI API.

        Args:
            user_message: Raw text typed by the user.

        Returns:
            An ``IntentResult`` holding the intent type, a confidence clamped
            to [0.0, 1.0], the model's reasoning and an optional entities
            hint. If the API call fails, the reply is not parseable, or the
            reply names an unknown intent, the result of the keyword-based
            fallback is returned instead; this method does not propagate
            those errors.
        """
        user_prompt = f'Classify the intent of this user message: "{user_message}"'

        # Stage 1: call the API. Any failure here (network, auth, empty
        # content) routes to the keyword fallback.
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": self._SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=0.1,
                max_tokens=300,
            )
            response_text = self._strip_code_fences(
                response.choices[0].message.content
            )
        except Exception as e:
            print(f"Error in intent classification: {e}")
            return self._fallback_classification(user_message, str(e))

        # Stage 2: parse and validate the reply. The broad except is
        # deliberate: malformed JSON, a non-dict payload, a non-numeric
        # confidence or a model validation error all mean "unusable reply".
        try:
            result_dict = json.loads(response_text)

            intent_value = result_dict.get("intent", "").strip().lower()
            if intent_value not in {member.value for member in IntentType}:
                print(f"Invalid intent value: {intent_value}")
                return self._fallback_classification(user_message, f"Invalid intent: {intent_value}")

            # The prompt asks for 0.0-1.0; clamp in case the model strays.
            confidence = float(result_dict.get("confidence", 0.5))
            confidence = min(1.0, max(0.0, confidence))

            return IntentResult(
                intent=IntentType(intent_value),
                confidence=confidence,
                reasoning=result_dict.get("reasoning", "No reasoning provided"),
                entities_hint=result_dict.get("entities_hint"),
            )
        except Exception as e:
            print(f"JSON parsing error: {e}")
            print(f"Raw response: {response_text}")
            return self._fallback_classification(user_message, f"JSON parsing failed: {str(e)}")

    def _fallback_classification(self, user_message: str, error_info: str) -> IntentResult:
        """Classify by keyword when the API result is unavailable or unusable.

        Keyword groups are checked in order (transaction, query, semantic
        search); the first group with a hit wins, and a message with no hit
        is treated as general info.

        Args:
            user_message: Raw text typed by the user.
            error_info: Description of what went wrong; its first 100
                characters are embedded in the returned reasoning.

        Returns:
            An ``IntentResult`` with confidence 0.6 for a keyword hit or 0.5
            for the general-info default, and no entities hint.
        """
        message_lower = user_message.lower()

        # Substring keywords (so "sales" or "purchased" still hit).
        transaction_keywords = ("purchase", "buy", "sold", "sale", "€", "$")
        query_keywords = ("how many", "total", "list all", "recent transactions", "count")
        search_keywords = ("similar", "like", "related", "about", "need to do", "meeting", "discuss", "task")

        def mentions(keywords) -> bool:
            return any(keyword in message_lower for keyword in keywords)

        # NOTE(review): a standalone "to"/"from" still takes precedence over
        # query/search phrases such as "need to do" because the transaction
        # group is checked first; consider reordering if that proves noisy.
        looks_like_transaction = (
            mentions(transaction_keywords)
            or self._PREPOSITION_PATTERN.search(message_lower) is not None
        )

        if looks_like_transaction:
            intent, confidence = IntentType.TRANSACTION, 0.6
        elif mentions(query_keywords):
            intent, confidence = IntentType.QUERY, 0.6
        elif mentions(search_keywords):
            intent, confidence = IntentType.SEMANTIC_SEARCH, 0.6
        else:
            intent, confidence = IntentType.GENERAL_INFO, 0.5

        return IntentResult(
            intent=intent,
            confidence=confidence,
            reasoning=f"Fallback classification due to API error: {error_info[:100]}",
            entities_hint=None,
        )

    def get_intent_description(self, intent: IntentType) -> str:
        """Return a human-readable description of *intent*.

        Unrecognized values yield ``"Unknown intent type"``.
        """
        descriptions = {
            IntentType.TRANSACTION: "Recording a business transaction (purchase or sale)",
            IntentType.QUERY: "Retrieving or analyzing data from the database",
            IntentType.SEMANTIC_SEARCH: "Finding similar events or information",
            IntentType.GENERAL_INFO: "Storing general business information or notes",
        }
        return descriptions.get(intent, "Unknown intent type")

    def batch_classify(self, messages: list[str]) -> list[IntentResult]:
        """Classify several messages, returning results in input order.

        Messages are processed one at a time; each one goes through
        ``classify_intent`` (one API request per message).
        """
        return [self.classify_intent(message) for message in messages]