# Author: Praneeth Yerrapragada
# Commit: 2636575 ("feat: repo setup")
from pydantic import BaseModel
from typing import List, Any, Optional, Dict, Tuple
from fastapi import APIRouter, Depends, HTTPException, Request, status
from llama_index.core.chat_engine.types import BaseChatEngine
from llama_index.core.schema import NodeWithScore
from llama_index.core.llms import ChatMessage, MessageRole
from app.engine import get_chat_engine
from app.api.routers.vercel_response import VercelStreamResponse
from app.api.routers.messaging import EventCallbackHandler
from aiostream import stream
# Router holding the chat endpoints. It is exported as `chat_router`, and `r`
# is a short alias for the same object, used by the route decorators.
chat_router = APIRouter()
r = chat_router
class _Message(BaseModel):
    """One chat message, used both in request bodies and in the non-streaming reply."""

    # Author of the message (a llama_index MessageRole value).
    role: MessageRole
    # Text of the message.
    content: str
class _ChatData(BaseModel):
    """Request body for the chat endpoints: the conversation so far."""

    # Ordered conversation; parse_chat_data requires the final entry to come
    # from the user.
    messages: List[_Message]

    class Config:
        # Example payload added to the model's JSON schema (e.g. shown in the
        # generated API docs).
        json_schema_extra = dict(
            example=dict(
                messages=[
                    dict(
                        role="user",
                        content="What standards for letters exist?",
                    ),
                ],
            ),
        )
class _SourceNodes(BaseModel):
    """Serializable view of a retrieved node (``NodeWithScore``) sent to the client."""

    id: str
    metadata: Dict[str, Any]
    # Optional because the score carried by a NodeWithScore can be None.
    score: Optional[float]
    text: str

    @classmethod
    def from_source_node(cls, source_node: NodeWithScore):
        """Build one instance from a single ``NodeWithScore``."""
        inner = source_node.node
        return cls(
            id=inner.node_id,
            metadata=inner.metadata,
            score=source_node.score,
            text=inner.text,  # type: ignore
        )

    @classmethod
    def from_source_nodes(cls, source_nodes: List[NodeWithScore]):
        """Convert every node in ``source_nodes`` (order preserved) into a list of instances."""
        return list(map(cls.from_source_node, source_nodes))
class _Result(BaseModel):
    """Response body of the non-streaming chat endpoint."""

    # The assistant's reply message.
    result: _Message
    # Source nodes returned by the chat engine alongside the reply.
    nodes: List[_SourceNodes]
async def parse_chat_data(data: _ChatData) -> Tuple[str, List[ChatMessage]]:
    """Split a chat request into the newest user message and the prior history.

    Args:
        data: The validated request body. It is not modified.

    Returns:
        A ``(last_message_content, history)`` tuple. ``history`` holds every
        message before the last one, converted to ``ChatMessage``.

    Raises:
        HTTPException: 400 if no messages were provided, or if the last
            message is not from the user.
    """
    if not data.messages:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No messages provided",
        )

    # Unpack into new names instead of calling ``pop()``, so parsing does not
    # mutate the caller's request model as a side effect.
    *history, last_message = data.messages
    if last_message.role != MessageRole.USER:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Last message must be from user",
        )

    # Convert the request messages into llama_index ChatMessage objects.
    messages = [ChatMessage(role=m.role, content=m.content) for m in history]
    return last_message.content, messages
# streaming endpoint - delete if not needed
@r.post("")
async def chat(
    request: Request,
    data: _ChatData,
    chat_engine: BaseChatEngine = Depends(get_chat_engine),
):
    """Stream the assistant's answer to the client as a VercelStreamResponse.

    Text tokens from the chat engine are merged with events gathered by an
    ``EventCallbackHandler``. A final ``"sources"`` data part then lists the
    source nodes of the response.

    Raises:
        HTTPException: 400 if the request messages are invalid (raised by
            ``parse_chat_data``).
    """
    last_message_content, messages = await parse_chat_data(data)

    event_handler = EventCallbackHandler()
    handlers = chat_engine.callback_manager.handlers  # type: ignore
    handlers.append(event_handler)

    def _detach_event_handler():
        # The callback manager may outlive this request (e.g. if it is shared
        # between engines). Remove our per-request handler once we are done,
        # so handlers and their undelivered events do not accumulate.
        if event_handler in handlers:
            handlers.remove(event_handler)

    try:
        response = await chat_engine.astream_chat(last_message_content, messages)
    except BaseException:
        _detach_event_handler()
        raise

    async def _text_generator():
        # Yield the text response.
        try:
            async for token in response.async_response_gen():
                yield VercelStreamResponse.convert_text(token)
        finally:
            # The text stream leads. Once it ends (normally, by error or by
            # cancellation), signal the event stream to finish as well.
            event_handler.is_done = True

    async def _event_generator():
        # Yield the events collected by the event handler.
        async for event in event_handler.async_event_gen():
            event_response = event.to_response()
            if event_response is not None:
                yield VercelStreamResponse.convert_data(event_response)

    async def content_generator():
        try:
            combine = stream.merge(_text_generator(), _event_generator())
            async with combine.stream() as streamer:
                async for item in streamer:
                    if await request.is_disconnected():
                        # The client is gone; there is no point sending the
                        # sources either.
                        return
                    yield item

            # Yield the source nodes.
            yield VercelStreamResponse.convert_data(
                {
                    "type": "sources",
                    "data": {
                        "nodes": [
                            _SourceNodes.from_source_node(node).dict()
                            for node in response.source_nodes
                        ]
                    },
                }
            )
        finally:
            _detach_event_handler()

    return VercelStreamResponse(content=content_generator())
# non-streaming endpoint - delete if not needed
@r.post("/request")
async def chat_request(
    data: _ChatData,
    chat_engine: BaseChatEngine = Depends(get_chat_engine),
) -> _Result:
    """Answer the conversation in one call; return the reply and its source nodes.

    Raises:
        HTTPException: 400 if the request messages are invalid (raised by
            ``parse_chat_data``).
    """
    prompt, history = await parse_chat_data(data)
    answer = await chat_engine.achat(prompt, history)

    reply = _Message(role=MessageRole.ASSISTANT, content=answer.response)
    sources = _SourceNodes.from_source_nodes(answer.source_nodes)
    return _Result(result=reply, nodes=sources)