| """Service for topic extraction from text using LangChain Groq""" |
|
|
| import logging |
| from typing import Optional, List |
| from langchain_core.messages import HumanMessage, SystemMessage |
| from langchain_groq import ChatGroq |
| from pydantic import BaseModel, Field |
| from langsmith import traceable |
|
|
| from config import GROQ_API_KEY |
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
class TopicOutput(BaseModel):
    """Pydantic schema for topic extraction output"""

    # NOTE(review): the docstring above and the field description below are
    # deliberately left word-for-word. Pydantic includes both in the model's
    # generated schema, and this model is handed to `with_structured_output`,
    # so they are presumably part of what the LLM is shown - confirm before
    # rewording either of them.
    topic: str = Field(
        ...,
        description="A specific, detailed topic description",
    )
|
|
|
|
class TopicService:
    """Extract topic descriptions from text arguments using a Groq chat model.

    The LLM client is created lazily: constructing the service only sets
    defaults, and the Groq client is built on the first call to
    ``initialize`` (which ``extract_topic`` and ``batch_extract_topics``
    trigger automatically when needed).
    """

    # System prompt sent with every extraction request. This is runtime text
    # consumed by the model; edit with care.
    _SYSTEM_PROMPT = """You are an information extraction model.
Extract a topic from the user text. The topic should be a single sentence that captures the main idea of the text in simple english.

Examples:
- Text: "Governments should subsidize electric cars to encourage adoption."
Output: topic="government subsidies for electric vehicle adoption"

- Text: "Raising the minimum wage will hurt small businesses and cost jobs."
Output: topic="raising the minimum wage and its economic impact on small businesses"
"""

    def __init__(self):
        # Structured-output runnable; stays None until initialize() succeeds.
        self.llm = None
        # Default model; can be overridden by the first initialize() call.
        self.model_name = "openai/gpt-oss-safeguard-20b"
        self.initialized = False

    def initialize(self, model_name: Optional[str] = None):
        """Initialize the Groq LLM with structured output.

        Calling this on an already-initialized service is a no-op, so a
        ``model_name`` passed at that point is ignored.

        Args:
            model_name: Optional model identifier overriding the default.

        Raises:
            ValueError: If ``GROQ_API_KEY`` is empty or missing.
            RuntimeError: If the client cannot be constructed; the original
                exception is chained as ``__cause__``.
        """
        if self.initialized:
            logger.info("Topic service already initialized")
            return

        if not GROQ_API_KEY:
            raise ValueError("GROQ_API_KEY not found in environment variables")

        if model_name:
            self.model_name = model_name

        try:
            logger.info(
                "Initializing topic extraction service with model: %s",
                self.model_name,
            )

            llm = ChatGroq(
                model=self.model_name,
                api_key=GROQ_API_KEY,
                temperature=0.0,
                max_tokens=512,
            )

            # Bind the schema so invoke() returns a TopicOutput instance.
            self.llm = llm.with_structured_output(TopicOutput)
            self.initialized = True

            logger.info("✓ Topic extraction service initialized successfully")

        except Exception as e:
            logger.error("Error initializing topic service: %s", e)
            # Chain the cause so the underlying traceback is preserved.
            raise RuntimeError(f"Failed to initialize topic service: {str(e)}") from e

    @traceable(name="extract_topic")
    def extract_topic(self, text: str) -> str:
        """Extract a topic from the given text/argument.

        Args:
            text: The input text/argument to extract topic from.

        Returns:
            The extracted topic string.

        Raises:
            ValueError: If ``text`` is not a string, or is empty or
                whitespace-only.
            RuntimeError: If the LLM call fails; the original exception is
                chained as ``__cause__``.
        """
        if not self.initialized:
            self.initialize()

        if not text or not isinstance(text, str):
            raise ValueError("Text must be a non-empty string")

        text = text.strip()
        if not text:
            raise ValueError("Text cannot be empty")

        try:
            result = self.llm.invoke(
                [
                    SystemMessage(content=self._SYSTEM_PROMPT),
                    HumanMessage(content=text),
                ]
            )
            return result.topic

        except Exception as e:
            logger.error("Error extracting topic: %s", e)
            raise RuntimeError(f"Topic extraction failed: {str(e)}") from e

    def batch_extract_topics(self, texts: List[str]) -> List[Optional[str]]:
        """Extract topics from multiple texts, one request per text.

        Extraction is best-effort: a failure on one item is logged and
        recorded as ``None`` so the result stays aligned with ``texts``.

        Args:
            texts: List of input texts/arguments.

        Returns:
            List with one entry per input: the extracted topic, or ``None``
            where extraction failed for that input.

        Raises:
            ValueError: If ``texts`` is not a list or is empty.
        """
        if not self.initialized:
            self.initialize()

        if not texts or not isinstance(texts, list):
            raise ValueError("Texts must be a non-empty list")

        results: List[Optional[str]] = []
        for text in texts:
            try:
                results.append(self.extract_topic(text))
            except Exception as e:
                # str() first: the item that failed may not be a string
                # (e.g. None), and slicing it directly would raise inside
                # this handler and abort the whole batch.
                preview = str(text)[:50]
                logger.error(
                    "Error extracting topic for text '%s...': %s", preview, e
                )
                results.append(None)

        return results
|
|
|
|
| |
# Shared module-level service instance. Constructing it only sets defaults;
# the Groq client is built later, on the first initialize() call.
topic_service = TopicService()
|
|
|
|