"""Callback handlers used in the app.""" from typing import Any, Dict, List from langchain.callbacks.base import AsyncCallbackHandler from schemas import ChatResponse class StreamingLLMCallbackHandler(AsyncCallbackHandler): """Callback handler for streaming LLM responses.""" def __init__(self, websocket): self.websocket = websocket async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: resp = ChatResponse(sender="bot", message=token, type="stream") await self.websocket.send_json(resp.dict()) class QuestionGenCallbackHandler(AsyncCallbackHandler): """Callback handler for question generation.""" def __init__(self, websocket): self.websocket = websocket async def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any ) -> None: """Run when LLM starts running.""" resp = ChatResponse( sender="bot", message="Synthesizing question...", type="info" ) await self.websocket.send_json(resp.dict())