sifa-classification-agentic-rag / src /hooks /useAIStreamHandler.tsx
Programmer-RD-AI
Update Agent UI Version
bbe4eea
raw
history blame
16.2 kB
import { useCallback } from 'react'
import { APIRoutes } from '@/api/routes'
import useChatActions from '@/hooks/useChatActions'
import { useStore } from '../store'
import { RunEvent, RunResponseContent, type RunResponse } from '@/types/os'
import { constructEndpointUrl } from '@/lib/constructEndpointUrl'
import useAIResponseStream from './useAIResponseStream'
import { ToolCall } from '@/types/os'
import { useQueryState } from 'nuqs'
import { getJsonMarkdown } from '@/lib/utils'
/**
 * React hook that exposes `handleStreamResponse`, the function used to send a
 * chat message to an agent or team run endpoint and apply the streamed run
 * events to the chat store.
 *
 * The hook itself only gathers dependencies: store setters and values selected
 * from `useStore`, chat helpers from `useChatActions`, the `agent` / `team` /
 * `session` URL query parameters, and the low-level `streamResponse` function.
 */
const useAIChatStreamHandler = () => {
// Store setter used below with updater callbacks to rewrite the message list
const setMessages = useStore((state) => state.setMessages)
const { addMessage, focusChatInput } = useChatActions()
// Current selection and session are read from URL query parameters
const [agentId] = useQueryState('agent')
const [teamId] = useQueryState('team')
const [sessionId, setSessionId] = useQueryState('session')
const selectedEndpoint = useStore((state) => state.selectedEndpoint)
// `mode` is compared against 'team' and 'agent' when choosing the run URL
const mode = useStore((state) => state.mode)
const setStreamingErrorMessage = useStore(
(state) => state.setStreamingErrorMessage
)
const setIsStreaming = useStore((state) => state.setIsStreaming)
const setSessionsData = useStore((state) => state.setSessionsData)
const { streamResponse } = useAIResponseStream()
/**
 * Marks the most recent message as failed when it is an agent message.
 *
 * The flagged message is replaced by a copy with `streamingError: true`
 * rather than mutated in place, so the object previously held in state keeps
 * its old value and consumers comparing by reference can detect the change.
 * A new array is always produced, even when nothing is flagged.
 */
const updateMessagesWithErrorState = useCallback(() => {
  setMessages((prevMessages) => {
    const lastIndex = prevMessages.length - 1
    return prevMessages.map((message, index) =>
      index === lastIndex && message.role === 'agent'
        ? { ...message, streamingError: true }
        : message
    )
  })
}, [setMessages])
/**
 * Merges a single tool call into a list of already known tool calls.
 *
 * A known entry matches the incoming call when both share the same
 * `tool_call_id`, or, for a known entry without an id, when its
 * `tool_name`-`created_at` pair equals the incoming call's identity key.
 * A match is overwritten field by field with the incoming call; otherwise the
 * incoming call is appended. The input array is never modified.
 *
 * @param toolCall - The tool call to merge in
 * @param prevToolCalls - The tool calls collected so far
 * @returns A new array containing the merged result
 */
const processToolCall = useCallback(
  (toolCall: ToolCall, prevToolCalls: ToolCall[] = []) => {
    // Identity of the incoming call: its id, or a name/timestamp composite
    const incomingKey =
      toolCall.tool_call_id || `${toolCall.tool_name}-${toolCall.created_at}`

    const matchIndex = prevToolCalls.findIndex((known) => {
      if (known.tool_call_id) {
        return known.tool_call_id === toolCall.tool_call_id
      }
      // Entries without an id can only be matched through the composite key
      return Boolean(
        toolCall.tool_name &&
          toolCall.created_at &&
          `${known.tool_name}-${known.created_at}` === incomingKey
      )
    })

    if (matchIndex < 0) {
      return [...prevToolCalls, toolCall]
    }

    return prevToolCalls.map((known, index) =>
      index === matchIndex ? { ...known, ...toolCall } : known
    )
  },
  []
)
/**
 * Applies every tool call carried by a chunk to an existing tool call list.
 *
 * A chunk may carry a single `tool` object, a `tools` array, or both; the
 * single object is merged first, followed by the array entries in order.
 *
 * @param chunk - The streamed chunk that may contain tool call data
 * @param existingToolCalls - The tool calls collected so far
 * @returns A new array with all of the chunk's tool calls merged in
 */
const processChunkToolCalls = useCallback(
  (
    chunk: RunResponseContent | RunResponse,
    existingToolCalls: ToolCall[] = []
  ) => {
    // Gather the incoming calls in the order they must be applied
    const incomingCalls: ToolCall[] = [
      ...(chunk.tool ? [chunk.tool] : []),
      ...(chunk.tools && chunk.tools.length > 0 ? chunk.tools : [])
    ]

    // Start from a copy so the caller's array is left untouched
    return incomingCalls.reduce<ToolCall[]>(
      (merged, incoming) => processToolCall(incoming, merged),
      [...existingToolCalls]
    )
  },
  [processToolCall]
)
/**
 * Sends one chat turn to the selected agent or team run endpoint and applies
 * the streamed run events to the chat store.
 *
 * Flow:
 * 1. Drops a previously failed user/agent message pair (retry case).
 * 2. Adds the user message and an empty agent placeholder message.
 * 3. Streams the run and, per event, updates the last agent message
 *    (content, tool calls, reasoning steps, references, media) or records the
 *    session in the sessions list.
 * 4. On any failure flags the agent message, stores the error text and removes
 *    the session tracked for this run from the sessions list.
 *
 * All message updates build new message objects instead of mutating the ones
 * held in state.
 *
 * @param input - The message text, or a `FormData` whose `message` field holds
 *   the text. `stream` and `session_id` fields are appended to the form data.
 * @returns A promise that settles when the stream has finished or failed.
 *   Errors are reported through the store and are not rethrown.
 */
const handleStreamResponse = useCallback(
  async (input: string | FormData) => {
    setIsStreaming(true)

    const formData = input instanceof FormData ? input : new FormData()
    if (typeof input === 'string') {
      formData.append('message', input)
    }

    // Retry after a failed run: discard the failed agent reply together with
    // the user message that triggered it, since both are added again below.
    setMessages((prevMessages) => {
      if (prevMessages.length >= 2) {
        const lastMessage = prevMessages[prevMessages.length - 1]
        const secondLastMessage = prevMessages[prevMessages.length - 2]
        if (
          lastMessage.role === 'agent' &&
          lastMessage.streamingError &&
          secondLastMessage.role === 'user'
        ) {
          return prevMessages.slice(0, -2)
        }
      }
      return prevMessages
    })

    addMessage({
      role: 'user',
      content: formData.get('message') as string,
      created_at: Math.floor(Date.now() / 1000)
    })
    // Placeholder that the streamed events fill in; the +1 keeps its
    // timestamp after the user message's
    addMessage({
      role: 'agent',
      content: '',
      tool_calls: [],
      streamingError: false,
      created_at: Math.floor(Date.now() / 1000) + 1
    })

    // Content of the previous RunContent chunk, used to extract new text
    let lastContent = ''
    // Session tracked for this run; replaced once the run reports its own id
    let newSessionId = sessionId

    // Single failure path shared by error events, the stream's onError
    // callback and thrown exceptions.
    const reportStreamFailure = (errorMessage: string) => {
      updateMessagesWithErrorState()
      setStreamingErrorMessage(errorMessage)
      if (newSessionId) {
        setSessionsData(
          (prevSessionsData) =>
            prevSessionsData?.filter(
              (session) => session.session_id !== newSessionId
            ) ?? null
        )
      }
    }

    try {
      const endpointUrl = constructEndpointUrl(selectedEndpoint)
      let RunUrl: string | null = null
      if (mode === 'team' && teamId) {
        RunUrl = APIRoutes.TeamRun(endpointUrl, teamId)
      } else if (mode === 'agent' && agentId) {
        RunUrl = APIRoutes.AgentRun(endpointUrl).replace(
          '{agent_id}',
          agentId
        )
      }

      if (!RunUrl) {
        // Nothing was sent, so there is no session to clean up here
        updateMessagesWithErrorState()
        setStreamingErrorMessage('Please select an agent or team first.')
        setIsStreaming(false)
        return
      }

      formData.append('stream', 'true')
      formData.append('session_id', sessionId ?? '')

      await streamResponse({
        apiUrl: RunUrl,
        requestBody: formData,
        onChunk: (chunk: RunResponse) => {
          switch (chunk.event) {
            case RunEvent.RunStarted:
            case RunEvent.TeamRunStarted:
            case RunEvent.ReasoningStarted:
            case RunEvent.TeamReasoningStarted: {
              newSessionId = chunk.session_id as string
              setSessionId(chunk.session_id as string)
              if (
                (!sessionId || sessionId !== chunk.session_id) &&
                chunk.session_id
              ) {
                const sessionData = {
                  session_id: chunk.session_id as string,
                  session_name: formData.get('message') as string,
                  created_at: chunk.created_at
                }
                // Put the session at the top of the list unless already listed
                setSessionsData((prevSessionsData) => {
                  const sessionExists = prevSessionsData?.some(
                    (session) => session.session_id === chunk.session_id
                  )
                  if (sessionExists) {
                    return prevSessionsData
                  }
                  return [sessionData, ...(prevSessionsData ?? [])]
                })
              }
              break
            }

            case RunEvent.ToolCallStarted:
            case RunEvent.TeamToolCallStarted:
            case RunEvent.ToolCallCompleted:
            case RunEvent.TeamToolCallCompleted: {
              setMessages((prevMessages) => {
                const lastIndex = prevMessages.length - 1
                return prevMessages.map((message, index) =>
                  index === lastIndex && message.role === 'agent'
                    ? {
                        ...message,
                        tool_calls: processChunkToolCalls(
                          chunk,
                          message.tool_calls
                        )
                      }
                    : message
                )
              })
              break
            }

            case RunEvent.RunContent:
            case RunEvent.TeamRunContent: {
              setMessages((prevMessages) => {
                const lastIndex = prevMessages.length - 1
                const lastMessage = prevMessages[lastIndex]
                // Only an agent message can receive streamed content. This
                // guard also covers the audio transcript path, which
                // previously dereferenced a possibly missing message.
                if (!lastMessage || lastMessage.role !== 'agent') {
                  return [...prevMessages]
                }

                const updatedMessage = { ...lastMessage }

                if (typeof chunk.content === 'string') {
                  // Strip the previous chunk only when it is a real prefix of
                  // this one (presumably cumulative chunks). A plain
                  // replace() would also delete a match in the middle.
                  const uniqueContent = chunk.content.startsWith(lastContent)
                    ? chunk.content.slice(lastContent.length)
                    : chunk.content
                  updatedMessage.content += uniqueContent
                  lastContent = chunk.content

                  // Handle tool calls streaming
                  updatedMessage.tool_calls = processChunkToolCalls(
                    chunk,
                    updatedMessage.tool_calls
                  )
                  if (chunk.extra_data?.reasoning_steps) {
                    updatedMessage.extra_data = {
                      ...updatedMessage.extra_data,
                      reasoning_steps: chunk.extra_data.reasoning_steps
                    }
                  }
                  if (chunk.extra_data?.references) {
                    updatedMessage.extra_data = {
                      ...updatedMessage.extra_data,
                      references: chunk.extra_data.references
                    }
                  }
                  updatedMessage.created_at =
                    chunk.created_at ?? updatedMessage.created_at
                  if (chunk.images) {
                    updatedMessage.images = chunk.images
                  }
                  if (chunk.videos) {
                    updatedMessage.videos = chunk.videos
                  }
                  if (chunk.audio) {
                    updatedMessage.audio = chunk.audio
                  }
                } else if (chunk.content !== null) {
                  // Structured content is rendered as a JSON markdown block.
                  // NOTE(review): `undefined` content also takes this path,
                  // as it did before; confirm that is intended.
                  const jsonBlock = getJsonMarkdown(chunk.content)
                  updatedMessage.content += jsonBlock
                  lastContent = jsonBlock
                } else {
                  const transcript = chunk.response_audio?.transcript
                  if (transcript && typeof transcript === 'string') {
                    updatedMessage.response_audio = {
                      ...updatedMessage.response_audio,
                      // Default to '' so the first transcript chunk is not
                      // prefixed with the text "undefined"
                      transcript:
                        (updatedMessage.response_audio?.transcript ?? '') +
                        transcript
                    }
                  }
                }

                return [...prevMessages.slice(0, lastIndex), updatedMessage]
              })
              break
            }

            case RunEvent.ReasoningStep:
            case RunEvent.TeamReasoningStep: {
              setMessages((prevMessages) => {
                const lastIndex = prevMessages.length - 1
                return prevMessages.map((message, index) => {
                  if (index !== lastIndex || message.role !== 'agent') {
                    return message
                  }
                  // Steps arrive incrementally, so append rather than replace
                  const existingSteps =
                    message.extra_data?.reasoning_steps ?? []
                  const incomingSteps =
                    chunk.extra_data?.reasoning_steps ?? []
                  return {
                    ...message,
                    extra_data: {
                      ...message.extra_data,
                      reasoning_steps: [...existingSteps, ...incomingSteps]
                    }
                  }
                })
              })
              break
            }

            case RunEvent.ReasoningCompleted:
            case RunEvent.TeamReasoningCompleted: {
              const reasoningSteps = chunk.extra_data?.reasoning_steps
              setMessages((prevMessages) => {
                const lastIndex = prevMessages.length - 1
                return prevMessages.map((message, index) =>
                  index === lastIndex &&
                  message.role === 'agent' &&
                  reasoningSteps
                    ? {
                        ...message,
                        // The completion event's step list replaces the
                        // accumulated one
                        extra_data: {
                          ...message.extra_data,
                          reasoning_steps: reasoningSteps
                        }
                      }
                    : message
                )
              })
              break
            }

            case RunEvent.RunError:
            case RunEvent.TeamRunError:
            case RunEvent.TeamRunCancelled: {
              reportStreamFailure(
                (chunk.content as string) ||
                  (chunk.event === RunEvent.TeamRunCancelled
                    ? 'Run cancelled'
                    : 'Error during run')
              )
              break
            }

            case RunEvent.UpdatingMemory:
            case RunEvent.TeamMemoryUpdateStarted:
            case RunEvent.TeamMemoryUpdateCompleted: {
              // No-op for now; could surface a lightweight UI indicator in the future
              break
            }

            case RunEvent.RunCompleted:
            case RunEvent.TeamRunCompleted: {
              setMessages((prevMessages) => {
                const newMessages = prevMessages.map((message, index) => {
                  if (
                    index === prevMessages.length - 1 &&
                    message.role === 'agent'
                  ) {
                    let updatedContent: string
                    if (typeof chunk.content === 'string') {
                      updatedContent = chunk.content
                    } else if (chunk.content == null) {
                      // No final content: keep the streamed text instead of
                      // replacing it with undefined or the literal "null"
                      updatedContent = message.content
                    } else {
                      try {
                        updatedContent = JSON.stringify(chunk.content)
                      } catch {
                        updatedContent = 'Error parsing response'
                      }
                    }
                    return {
                      ...message,
                      content: updatedContent,
                      tool_calls: processChunkToolCalls(
                        chunk,
                        message.tool_calls
                      ),
                      images: chunk.images ?? message.images,
                      videos: chunk.videos ?? message.videos,
                      response_audio: chunk.response_audio,
                      created_at: chunk.created_at ?? message.created_at,
                      extra_data: {
                        reasoning_steps:
                          chunk.extra_data?.reasoning_steps ??
                          message.extra_data?.reasoning_steps,
                        references:
                          chunk.extra_data?.references ??
                          message.extra_data?.references
                      }
                    }
                  }
                  return message
                })
                return newMessages
              })
              break
            }
          }
        },
        onError: (error) => {
          reportStreamFailure(error.message)
        },
        onComplete: () => {}
      })
    } catch (error) {
      reportStreamFailure(
        error instanceof Error ? error.message : String(error)
      )
    } finally {
      // Runs on success, failure and the early return above
      focusChatInput()
      setIsStreaming(false)
    }
  },
  [
    setMessages,
    addMessage,
    updateMessagesWithErrorState,
    selectedEndpoint,
    streamResponse,
    agentId,
    teamId,
    mode,
    setStreamingErrorMessage,
    setIsStreaming,
    focusChatInput,
    setSessionsData,
    sessionId,
    setSessionId,
    processChunkToolCalls
  ]
)
// Only the streaming entry point is exposed; the tool-call and error helpers
// defined above are not part of the returned object.
return { handleStreamResponse }
}
export default useAIChatStreamHandler