|
|
import { useCallback } from 'react' |
|
|
|
|
|
import { APIRoutes } from '@/api/routes' |
|
|
|
|
|
import useChatActions from '@/hooks/useChatActions' |
|
|
import { useStore } from '../store' |
|
|
import { RunEvent, RunResponseContent, type RunResponse } from '@/types/os' |
|
|
import { constructEndpointUrl } from '@/lib/constructEndpointUrl' |
|
|
import useAIResponseStream from './useAIResponseStream' |
|
|
import { ToolCall } from '@/types/os' |
|
|
import { useQueryState } from 'nuqs' |
|
|
import { getJsonMarkdown } from '@/lib/utils' |
|
|
|
|
|
/** Chunk shapes that may carry tool-call information (`tool` and/or `tools`). */
type ToolCallCarrier = RunResponseContent | RunResponse

/**
 * Merges a single tool call into a list of previously seen tool calls.
 *
 * A tool call is matched by `tool_call_id` when the stored entry has one;
 * stored entries without an id are matched on the synthetic key
 * `${tool_name}-${created_at}`. A match is shallow-merged (incoming fields
 * win); otherwise the tool call is appended.
 *
 * @param toolCall - The incoming tool call.
 * @param prevToolCalls - Tool calls collected so far (never mutated).
 * @returns A new array containing the merged or appended tool call.
 */
const processToolCall = (
  toolCall: ToolCall,
  prevToolCalls: ToolCall[] = []
): ToolCall[] => {
  const toolCallId =
    toolCall.tool_call_id || `${toolCall.tool_name}-${toolCall.created_at}`

  const existingIndex = prevToolCalls.findIndex(
    (tc) =>
      (tc.tool_call_id && tc.tool_call_id === toolCall.tool_call_id) ||
      (!tc.tool_call_id &&
        toolCall.tool_name &&
        toolCall.created_at &&
        `${tc.tool_name}-${tc.created_at}` === toolCallId)
  )

  if (existingIndex < 0) {
    return [...prevToolCalls, toolCall]
  }

  const merged = [...prevToolCalls]
  merged[existingIndex] = { ...merged[existingIndex], ...toolCall }
  return merged
}

/**
 * Folds every tool call carried by a chunk (`chunk.tool` first, then each
 * entry of `chunk.tools`) into the existing list.
 *
 * @param chunk - The streamed chunk to read tool calls from.
 * @param existingToolCalls - Tool calls collected so far (never mutated).
 * @returns A new array, even when the chunk carries no tool calls.
 */
const processChunkToolCalls = (
  chunk: ToolCallCarrier,
  existingToolCalls: ToolCall[] = []
): ToolCall[] => {
  let toolCalls = [...existingToolCalls]

  if (chunk.tool) {
    toolCalls = processToolCall(chunk.tool, toolCalls)
  }
  for (const toolCall of chunk.tools ?? []) {
    toolCalls = processToolCall(toolCall, toolCalls)
  }

  return toolCalls
}

/**
 * Returns the part of `content` that was not already delivered as `previous`.
 *
 * Only a leading repeat is removed. Removing the first occurrence anywhere
 * (as `String.prototype.replace` does) corrupts chunks that merely contain
 * the previous text, e.g. "the" followed by " other".
 */
const stripStreamedPrefix = (content: string, previous: string): string =>
  content.startsWith(previous) ? content.slice(previous.length) : content

/**
 * Resolves the final message text for a completed run.
 *
 * @param content - The `content` of the completion chunk.
 * @param fallback - Text to keep when the chunk has no usable content.
 * @returns The string content as-is, the JSON serialization of non-string
 *   content, `fallback` when content is null/undefined or not serializable
 *   to a string, or `'Error parsing response'` when serialization throws.
 */
const resolveCompletedContent = (content: unknown, fallback: string): string => {
  if (typeof content === 'string') return content
  if (content == null) return fallback
  try {
    // JSON.stringify returns undefined at runtime for e.g. functions/symbols.
    const serialized: string | undefined = JSON.stringify(content)
    return serialized ?? fallback
  } catch {
    return 'Error parsing response'
  }
}

/**
 * Applies `update` to the last message when it is an agent message.
 *
 * The input array and its elements are never mutated: the last message is
 * replaced by whatever `update` returns in a copy of the array. When the
 * list is empty or the last message is not an agent message, the original
 * array is returned unchanged.
 */
const mapLastAgentMessage = <T extends { role?: unknown }>(
  messages: T[],
  update: (message: T) => T
): T[] => {
  const lastIndex = messages.length - 1
  const lastMessage = messages[lastIndex]
  if (!lastMessage || lastMessage.role !== 'agent') return messages

  const nextMessages = [...messages]
  nextMessages[lastIndex] = update(lastMessage)
  return nextMessages
}

/**
 * Hook exposing `handleStreamResponse`, which sends a chat message to the
 * selected agent or team run endpoint and applies the streamed run events to
 * the chat store (messages, sessions, streaming/error flags).
 *
 * @returns An object with the `handleStreamResponse` callback.
 */
const useAIChatStreamHandler = () => {
  const setMessages = useStore((state) => state.setMessages)
  const { addMessage, focusChatInput } = useChatActions()
  const [agentId] = useQueryState('agent')
  const [teamId] = useQueryState('team')
  const [sessionId, setSessionId] = useQueryState('session')
  const selectedEndpoint = useStore((state) => state.selectedEndpoint)
  const mode = useStore((state) => state.mode)
  const setStreamingErrorMessage = useStore(
    (state) => state.setStreamingErrorMessage
  )
  const setIsStreaming = useStore((state) => state.setIsStreaming)
  const setSessionsData = useStore((state) => state.setSessionsData)
  const { streamResponse } = useAIResponseStream()

  /** Flags the trailing agent message (if any) as having failed to stream. */
  const updateMessagesWithErrorState = useCallback(() => {
    setMessages((prevMessages) =>
      mapLastAgentMessage(prevMessages, (message) => ({
        ...message,
        streamingError: true
      }))
    )
  }, [setMessages])

  /**
   * Sends `input` to the run endpoint and streams the response into the store.
   *
   * @param input - Either the message text, or a FormData whose `message`
   *   entry holds the text. `stream` and `session_id` entries are appended to
   *   the FormData before sending (a caller-supplied FormData is modified).
   */
  const handleStreamResponse = useCallback(
    async (input: string | FormData) => {
      setIsStreaming(true)

      const formData = input instanceof FormData ? input : new FormData()
      if (typeof input === 'string') {
        formData.append('message', input)
      }
      // FormData.get may return null or a File; only a string is usable text.
      const rawMessage = formData.get('message')
      const messageText = typeof rawMessage === 'string' ? rawMessage : ''

      // Drop a previous failed exchange (user message + errored agent reply)
      // so that a retry does not leave the failed pair in the transcript.
      setMessages((prevMessages) => {
        if (prevMessages.length < 2) return prevMessages
        const lastMessage = prevMessages[prevMessages.length - 1]
        const secondLastMessage = prevMessages[prevMessages.length - 2]
        const isFailedExchange =
          lastMessage?.role === 'agent' &&
          lastMessage.streamingError &&
          secondLastMessage?.role === 'user'
        return isFailedExchange ? prevMessages.slice(0, -2) : prevMessages
      })

      addMessage({
        role: 'user',
        content: messageText,
        created_at: Math.floor(Date.now() / 1000)
      })

      // Empty agent placeholder that the streamed chunks fill in.
      addMessage({
        role: 'agent',
        content: '',
        tool_calls: [],
        streamingError: false,
        created_at: Math.floor(Date.now() / 1000) + 1
      })

      // Content of the most recently applied content chunk, used to strip a
      // repeated leading part from the next one.
      let lastContent = ''
      let newSessionId = sessionId

      // NOTE(review): newSessionId is seeded with the current session id, so a
      // failure inside an already existing session also removes that session
      // from the list — confirm this is intended.
      const discardFailedSession = () => {
        if (!newSessionId) return
        setSessionsData(
          (prevSessionsData) =>
            prevSessionsData?.filter(
              (session) => session.session_id !== newSessionId
            ) ?? null
        )
      }

      /** Shared failure path: flag the message, surface the error, drop the session. */
      const failRun = (message: string) => {
        updateMessagesWithErrorState()
        setStreamingErrorMessage(message)
        discardFailedSession()
      }

      try {
        const endpointUrl = constructEndpointUrl(selectedEndpoint)

        let runUrl: string | null = null
        if (mode === 'team' && teamId) {
          runUrl = APIRoutes.TeamRun(endpointUrl, teamId)
        } else if (mode === 'agent' && agentId) {
          // Function replacer: a plain string replacement would interpret
          // `$&`, `$1`, ... sequences inside the id.
          runUrl = APIRoutes.AgentRun(endpointUrl).replace(
            '{agent_id}',
            () => agentId
          )
        }

        if (!runUrl) {
          updateMessagesWithErrorState()
          setStreamingErrorMessage('Please select an agent or team first.')
          // The `finally` block below resets the streaming flag.
          return
        }

        formData.append('stream', 'true')
        formData.append('session_id', sessionId ?? '')

        await streamResponse({
          apiUrl: runUrl,
          requestBody: formData,
          onChunk: (chunk: RunResponse) => {
            switch (chunk.event) {
              case RunEvent.RunStarted:
              case RunEvent.TeamRunStarted:
              case RunEvent.ReasoningStarted:
              case RunEvent.TeamReasoningStarted: {
                const startedSessionId = chunk.session_id as string | undefined
                // Without a session id there is nothing to track; keep the
                // current session rather than overwriting it with undefined.
                if (!startedSessionId) break

                newSessionId = startedSessionId
                setSessionId(startedSessionId)

                if (sessionId !== startedSessionId) {
                  const sessionData = {
                    session_id: startedSessionId,
                    session_name: messageText,
                    created_at: chunk.created_at
                  }
                  setSessionsData((prevSessionsData) => {
                    const sessionExists = prevSessionsData?.some(
                      (session) => session.session_id === startedSessionId
                    )
                    if (sessionExists) {
                      return prevSessionsData
                    }
                    return [sessionData, ...(prevSessionsData ?? [])]
                  })
                }
                break
              }

              case RunEvent.ToolCallStarted:
              case RunEvent.TeamToolCallStarted:
              case RunEvent.ToolCallCompleted:
              case RunEvent.TeamToolCallCompleted: {
                setMessages((prevMessages) =>
                  mapLastAgentMessage(prevMessages, (message) => ({
                    ...message,
                    tool_calls: processChunkToolCalls(chunk, message.tool_calls)
                  }))
                )
                break
              }

              case RunEvent.RunContent:
              case RunEvent.TeamRunContent: {
                // `lastContent` is advanced inside the updater, which assumes
                // the store invokes it synchronously and once per call.
                setMessages((prevMessages) =>
                  mapLastAgentMessage(prevMessages, (message) => {
                    const next = { ...message }

                    if (typeof chunk.content === 'string') {
                      next.content += stripStreamedPrefix(
                        chunk.content,
                        lastContent
                      )
                      lastContent = chunk.content

                      next.tool_calls = processChunkToolCalls(
                        chunk,
                        next.tool_calls
                      )
                      if (chunk.extra_data?.reasoning_steps) {
                        next.extra_data = {
                          ...next.extra_data,
                          reasoning_steps: chunk.extra_data.reasoning_steps
                        }
                      }
                      if (chunk.extra_data?.references) {
                        next.extra_data = {
                          ...next.extra_data,
                          references: chunk.extra_data.references
                        }
                      }

                      next.created_at = chunk.created_at ?? next.created_at
                      if (chunk.images) {
                        next.images = chunk.images
                      }
                      if (chunk.videos) {
                        next.videos = chunk.videos
                      }
                      if (chunk.audio) {
                        next.audio = chunk.audio
                      }
                    } else if (chunk.content != null) {
                      // Structured (non-string) content is rendered as a JSON
                      // markdown block. Absent content (null/undefined) is not
                      // a payload and falls through to the transcript check.
                      const jsonBlock = getJsonMarkdown(chunk.content)
                      next.content += jsonBlock
                      lastContent = jsonBlock
                    } else {
                      const transcript = chunk.response_audio?.transcript
                      if (typeof transcript === 'string' && transcript !== '') {
                        next.response_audio = {
                          ...next.response_audio,
                          // Start from '' so the first piece is not prefixed
                          // with the string "undefined".
                          transcript:
                            (next.response_audio?.transcript ?? '') + transcript
                        }
                      }
                    }

                    return next
                  })
                )
                break
              }

              case RunEvent.ReasoningStep:
              case RunEvent.TeamReasoningStep: {
                setMessages((prevMessages) =>
                  mapLastAgentMessage(prevMessages, (message) => {
                    const existingSteps =
                      message.extra_data?.reasoning_steps ?? []
                    const incomingSteps =
                      chunk.extra_data?.reasoning_steps ?? []
                    return {
                      ...message,
                      extra_data: {
                        ...message.extra_data,
                        reasoning_steps: [...existingSteps, ...incomingSteps]
                      }
                    }
                  })
                )
                break
              }

              case RunEvent.ReasoningCompleted:
              case RunEvent.TeamReasoningCompleted: {
                setMessages((prevMessages) =>
                  mapLastAgentMessage(prevMessages, (message) => {
                    const next = { ...message }
                    if (chunk.extra_data?.reasoning_steps) {
                      // The completed event replaces the accumulated steps.
                      next.extra_data = {
                        ...next.extra_data,
                        reasoning_steps: chunk.extra_data.reasoning_steps
                      }
                    }
                    return next
                  })
                )
                break
              }

              case RunEvent.RunError:
              case RunEvent.TeamRunError:
              case RunEvent.TeamRunCancelled: {
                const fallbackMessage =
                  chunk.event === RunEvent.TeamRunCancelled
                    ? 'Run cancelled'
                    : 'Error during run'
                const reportedMessage =
                  typeof chunk.content === 'string' && chunk.content !== ''
                    ? chunk.content
                    : fallbackMessage
                failRun(reportedMessage)
                break
              }

              case RunEvent.UpdatingMemory:
              case RunEvent.TeamMemoryUpdateStarted:
              case RunEvent.TeamMemoryUpdateCompleted:
                // Memory update events are intentionally ignored.
                break

              case RunEvent.RunCompleted:
              case RunEvent.TeamRunCompleted: {
                setMessages((prevMessages) =>
                  mapLastAgentMessage(prevMessages, (message) => ({
                    ...message,
                    // Keep the streamed text when the completion chunk has
                    // no content of its own.
                    content: resolveCompletedContent(
                      chunk.content,
                      message.content
                    ),
                    tool_calls: processChunkToolCalls(chunk, message.tool_calls),
                    images: chunk.images ?? message.images,
                    videos: chunk.videos ?? message.videos,
                    response_audio:
                      chunk.response_audio ?? message.response_audio,
                    created_at: chunk.created_at ?? message.created_at,
                    extra_data: {
                      ...message.extra_data,
                      reasoning_steps:
                        chunk.extra_data?.reasoning_steps ??
                        message.extra_data?.reasoning_steps,
                      references:
                        chunk.extra_data?.references ??
                        message.extra_data?.references
                    }
                  }))
                )
                break
              }

              default:
                break
            }
          },
          onError: (error) => {
            failRun(error.message)
          },
          onComplete: () => {}
        })
      } catch (error) {
        failRun(error instanceof Error ? error.message : String(error))
      } finally {
        focusChatInput()
        setIsStreaming(false)
      }
    },
    [
      setMessages,
      addMessage,
      updateMessagesWithErrorState,
      selectedEndpoint,
      streamResponse,
      agentId,
      teamId,
      mode,
      setStreamingErrorMessage,
      setIsStreaming,
      focusChatInput,
      setSessionsData,
      sessionId,
      setSessionId
    ]
  )

  return { handleStreamResponse }
}
|
|
|
|
|
export default useAIChatStreamHandler |
|
|
|