|
from __future__ import annotations |
|
|
|
import asyncio |
|
from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast |
|
|
|
from langchain_core.callbacks import AsyncCallbackHandler |
|
from langchain_core.outputs import LLMResult |
|
|
|
|
|
|
|
|
|
class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
    """Callback handler that exposes streamed LLM tokens as an async iterator.

    Tokens received through ``on_llm_new_token`` are pushed onto ``queue``;
    ``done`` is set when the LLM run ends or errors. ``aiter`` consumes the
    queue and stops once ``done`` is set and the queue has been drained.
    """

    # Buffer of tokens that have been produced but not yet consumed by aiter().
    queue: asyncio.Queue[str]

    # Set by on_llm_end / on_llm_error; cleared again by on_llm_start.
    done: asyncio.Event

    @property
    def always_verbose(self) -> bool:
        """Always report ``True``.

        NOTE(review): presumably this asks the callback manager to invoke the
        handler even when verbosity is off -- confirm against langchain_core.
        """
        return True

    def __init__(self) -> None:
        """Create an empty token queue and an unset completion event."""
        self.queue = asyncio.Queue()
        self.done = asyncio.Event()

    async def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Clear the completion flag so a new run can be streamed."""
        self.done.clear()

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Enqueue ``token`` unless it is ``None`` or the empty string."""
        if token is not None and token != "":
            self.queue.put_nowait(token)

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Mark the stream as finished."""
        self.done.set()

    async def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Mark the stream as finished; the error itself is not re-raised here."""
        self.done.set()

    async def aiter(self) -> AsyncIterator[str]:
        """Yield queued tokens until the run is done and the queue is empty.

        Each iteration waits for whichever happens first: a token becomes
        available, or ``done`` is set. A token is always preferred over the
        completion signal, so tokens that were queued before (or together
        with) completion are never dropped.

        Yields:
            Each token in the order it was enqueued.
        """
        while not self.queue.empty() or not self.done.is_set():
            token_task = asyncio.ensure_future(self.queue.get())
            done_task = asyncio.ensure_future(self.done.wait())
            try:
                await asyncio.wait(
                    [token_task, done_task],
                    return_when=asyncio.FIRST_COMPLETED,
                )
                # Both tasks may have finished in the same loop turn; check
                # the token task explicitly instead of popping an arbitrary
                # member of the "done" set, which could discard a token.
                got_token = token_task.done()
            finally:
                # Never leave helper tasks behind, including when the consumer
                # is cancelled while waiting. Cancelling a pending Queue.get()
                # does not remove an item from the queue.
                for task in (token_task, done_task):
                    if not task.done():
                        task.cancel()

            if got_token:
                yield token_task.result()
            # Otherwise ``done`` was set first; the loop condition decides
            # whether tokens are still queued and must be drained.
|
|