wuyiqun0718's picture
update
159e7fa
raw
history blame
No virus
10.7 kB
import { StreamingTextResponse } from 'ai';
// import { auth } from '@/auth';
import { MessageUI } from '@/lib/types';
import { logger, withLogging } from '@/lib/logger';
import { getPresignedUrl } from '@/lib/aws';
import { dbPostUpdateMessageResponse } from '@/lib/db/functions';
// export const runtime = 'edge';
// Route-level settings exported from this module (presumably consumed by the
// hosting framework as route segment config — TODO confirm).
export const dynamic = 'force-dynamic';
export const maxDuration = 300; // This function can run for a maximum of 5 minutes
// Inactivity threshold in milliseconds (2 minutes). POST compares it against
// the time elapsed since the accumulated stream buffer last contained a
// non-empty line.
const TIMEOUT_MILI_SECONDS = 2 * 60 * 1000;
// Terminal error message that POST records and sends to the client when the
// inactivity threshold above is exceeded. The minutes figure in `value` is
// derived from TIMEOUT_MILI_SECONDS so the text follows the constant.
const FINAL_TIMEOUT_ERROR: PrismaJson.FinalErrorBody = {
type: 'final_error',
status: 'failed',
payload: {
name: 'AgentTimeout',
value: `Haven't received any response in last ${TIMEOUT_MILI_SECONDS / 60000} minutes, agent timed out.`,
traceback_raw: [],
},
};
/**
 * Uploads one piece of generated media and returns its public URL.
 *
 * @param base64 URL handed straight to `fetch()`; the caller in this file
 *   supplies `data:` URIs built from base64 payloads.
 * @param messageId Message the media belongs to (second path segment).
 * @param chatId Chat the media belongs to (first path segment).
 * @param index Position of the result, used in the `answer-<index>.<ext>` name.
 * @param user User identifier forwarded to `getPresignedUrl`.
 * @returns The `publicUrl` reported by `getPresignedUrl`.
 * @throws Error('Upload failed') when the upload request is not OK.
 */
const uploadBase64 = async (
  base64: string,
  messageId: string,
  chatId: string,
  index: number,
  user: string,
) => {
  // Materialise the payload as a Blob so its MIME type can drive the file name.
  const media = await (await fetch(base64)).blob();
  const mimeType = media.type;
  const extension = mimeType.split('/')[1];

  const presigned = await getPresignedUrl(
    `answer-${index}.${extension}`,
    mimeType,
    `${chatId}/${messageId}`,
    user,
  );

  // The presigned form fields go first, the file itself last.
  const payload = new FormData();
  for (const [fieldName, fieldValue] of Object.entries(presigned.fields)) {
    payload.append(fieldName, fieldValue as string);
  }
  payload.append('file', media);

  const uploaded = await fetch(presigned.signedUrl, {
    method: 'POST',
    body: payload,
  });
  if (!uploaded.ok) {
    throw new Error('Upload failed');
  }
  return presigned.publicUrl;
};
/**
 * Post-processes the execution result carried by an agent message.
 *
 * - Messages that are neither `final_code` nor a finished `code` message
 *   (status other than `started`/`running`), or that carry no
 *   `payload.result`, are returned untouched.
 * - Finished `code` messages get the `png`/`mp4` fields of every result
 *   entry blanked out (set to `undefined`).
 * - `final_code` messages get every `png`/`mp4` base64 payload uploaded via
 *   `uploadBase64` and replaced by the returned URL.
 *
 * `payload.result` may arrive as a JSON string or as an object; a string is
 * parsed first. The message object is modified in place and returned.
 *
 * @param msg Message to process.
 * @param messageId Message id used for the upload location.
 * @param chatId Chat id used for the upload location.
 * @param user User identifier forwarded to the upload helper.
 * @returns The same `msg` object, possibly with a rewritten `payload.result`.
 * @throws When `payload.result` is a string that is not valid JSON, or when
 *   an upload fails.
 */
const modifyCodePayload = async (
  msg: PrismaJson.MessageBody,
  messageId: string,
  chatId: string,
  user: string,
): Promise<PrismaJson.MessageBody> => {
  if (
    (msg.type !== 'final_code' &&
      (msg.type !== 'code' ||
        msg.status === 'started' ||
        msg.status === 'running')) ||
    !msg.payload?.result
  ) {
    return msg;
  }

  const result = (
    typeof msg.payload.result === 'string'
      ? JSON.parse(msg.payload.result)
      : msg.payload.result
  ) as PrismaJson.StructuredResult;

  // A JSON string such as "null" parses to a non-object; nothing to rewrite.
  if (!result) {
    return msg;
  }

  if (msg.type === 'code') {
    if (result.results) {
      msg.payload.result = {
        ...result,
        results: result.results.map((entry: any) => ({
          ...entry,
          png: undefined,
          mp4: undefined,
        })),
      };
    }
    return msg;
  }

  // final_code: a result without `results` is valid and simply has no media.
  const entries = result.results ?? [];
  for (let index = 0; index < entries.length; index++) {
    const entry = entries[index];
    if (!entry) continue;
    // Upload each medium on its own so an entry carrying both a PNG and an
    // MP4 ends up with two distinct URLs instead of the PNG URL in both.
    if (entry.png) {
      entry.png = await uploadBase64(
        'data:image/png;base64,' + entry.png,
        messageId,
        chatId,
        index,
        user,
      );
    }
    if (entry.mp4) {
      entry.mp4 = await uploadBase64(
        'data:video/mp4;base64,' + entry.mp4,
        messageId,
        chatId,
        index,
        user,
      );
    }
  }
  msg.payload.result = result;
  return msg;
};
/**
 * POST handler: forwards a chat request to the vision agent API and relays
 * the agent's newline-delimited JSON output to the client.
 *
 * Request body: `apiMessages` (JSON-encoded `MessageUI[]`), `id` (chat id)
 * and `mediaUrl`. The message id is the part of the last message's id before
 * the first `-`.
 *
 * Each parsed agent message is passed through `modifyCodePayload`, collected
 * in `results`, and sent to the client as a `2:<json array>\n` data part.
 * Streaming ends on a `final_code` / `final_error` message, a line that
 * fails to parse or process, an inactivity timeout, or the end of the
 * upstream body. At that point the collected messages are persisted with
 * `dbPostUpdateMessageResponse` and the client stream is closed.
 *
 * When the agent API answers with a non-OK status and a body, the first body
 * chunk is logged and surfaced as a stream error on a 400 response. A
 * response without a body is returned as-is.
 */
export const POST = withLogging(
  async (
    session,
    json: {
      apiMessages: string;
      id: string;
      mediaUrl: string;
    },
    request,
  ) => {
    const { apiMessages, mediaUrl, id: chatId } = json;
    const messages: MessageUI[] = JSON.parse(apiMessages);
    const messageId = messages[messages.length - 1].id.split('-')[0];
    const user = session?.user?.email ?? 'anonymous';

    const formData = new FormData();
    formData.append('input', apiMessages);
    formData.append('image', mediaUrl);

    const agentHost = process.env.LND_TIER
      ? 'http://publicrestapi-app-lndsvc.publicrestapi.svc.cluster.local:5000'
      : 'https://api.dev.landing.ai';

    const fetchResponse = await fetch(
      `${agentHost}/v1/agent/chat?agent_class=vision_agent&self_reflection=false`,
      // `https://api.dev.landing.ai/v1/agent/chat?agent_class=vision_agent&self_reflection=false`,
      // `http://localhost:5001/v1/agent/chat?agent_class=vision_agent&self_reflection=false`,
      {
        method: 'POST',
        headers: {
          // default to dev apikey
          // NOTE(review): these API keys are committed in source. They should
          // be rotated and loaded from the environment / a secret store.
          apikey:
            process.env.LND_TIER === 'production'
              ? 'land_sk_nMnUf8xiJJUjyw1l5QaIJJ4ZyrvPthzVmPAIG7TtJY7F9CW6lu' // prod key
              : 'land_sk_DKeoYtaZZrYqJ9TMMiXe4BIQgJcZ0s3XAoB0JT3jv73FFqnr6k', // dev key
        },
        body: formData,
      },
    );

    if (!fetchResponse.ok && fetchResponse.body) {
      const reader = fetchResponse.body.getReader();
      return new StreamingTextResponse(
        new ReadableStream({
          async start(controller) {
            try {
              const { done, value } = await reader.read();
              if (done) {
                // Empty error body: still terminate the stream so the client
                // is not left waiting.
                controller.close();
                return;
              }
              const errorText = new TextDecoder().decode(value);
              logger.error(session, { message: errorText }, request);
              controller.error(new Error(`Response error: ${errorText}`));
            } catch (e) {
              logger.error(session, (e as Error).message, request);
              // Terminate the stream on read failures as well.
              controller.error(e);
            }
          },
        }),
        {
          status: 400,
        },
      );
    }

    // const streamData = new experimental_StreamData();
    if (!fetchResponse.body) {
      return fetchResponse;
    }

    const encoder = new TextEncoder();
    const decoder = new TextDecoder('utf-8');
    let maxChunkSize = 0;
    let buffer = '';
    let time = Date.now();
    const results: PrismaJson.MessageBody[] = [];

    type ChunkOutcome = { done: boolean; reason?: string; error?: unknown };

    const stream = new ReadableStream({
      async start(controller) {
        // Parses one line of agent output and post-processes its payload.
        // With `ignoreParsingError` a JSON syntax error yields an empty
        // result (used for the possibly incomplete trailing line).
        const parseLine = async (
          line: string,
          ignoreParsingError = false,
        ): Promise<{ data?: PrismaJson.MessageBody; error?: Error }> => {
          let msg = null;
          try {
            msg = JSON.parse(line);
          } catch (e) {
            if (ignoreParsingError) return {};
            return { error: e as Error };
          }
          if (!msg) return {};
          try {
            const modifiedMsg = await modifyCodePayload(
              {
                ...msg,
                timestamp: new Date(),
              },
              messageId,
              chatId,
              user,
            );
            return { data: modifiedMsg };
          } catch (e) {
            return { error: e as Error };
          }
        };

        // Sends one message as a data part: code `2`, value a JSON array.
        // https://github.com/vercel/ai/blob/f7002ad2c5aa58ce6ed83e8d31fe22f71ebdb7d7/packages/ui-utils/src/stream-parts.ts#L62
        const enqueueMessage = (body: PrismaJson.MessageBody) => {
          controller.enqueue(
            encoder.encode('2:' + JSON.stringify([body]) + '\n'),
          );
        };

        // Error record stored when a line cannot be parsed or processed.
        const parseErrorBody = (raw: string): PrismaJson.MessageBody => ({
          type: 'final_error',
          status: 'failed',
          payload: {
            name: 'ParseError',
            value: raw,
            traceback_raw: [],
          },
        });

        // Records and forwards a parsed message; reports whether it is terminal.
        const handleMessage = (parsed: PrismaJson.MessageBody): ChunkOutcome => {
          results.push(parsed);
          enqueueMessage(parsed);
          if (parsed.type === 'final_code') {
            return { done: true, reason: 'agent_concluded' };
          }
          if (parsed.type === 'final_error') {
            return {
              done: true,
              reason: 'agent_error',
              error: parsed.payload,
            };
          }
          return { done: false };
        };

        const processChunk = async (lines: string[]): Promise<ChunkOutcome> => {
          if (lines.length === 0) {
            // NOTE(review): this check only runs when a chunk arrives and the
            // buffer holds no non-empty line; complete upstream silence is
            // not detected here.
            if (Date.now() - time > TIMEOUT_MILI_SECONDS) {
              results.push(FINAL_TIMEOUT_ERROR);
              enqueueMessage(FINAL_TIMEOUT_ERROR);
              return { done: true, reason: 'timeout' };
            }
          } else {
            time = Date.now();
          }

          buffer = lines.pop() ?? ''; // Save the last incomplete line back to the buffer
          for (const line of lines) {
            const { data: parsedMsg, error } = await parseLine(line);
            if (error) {
              results.push(parseErrorBody(line));
              return { done: true, reason: 'api_error', error };
            }
            if (parsedMsg) {
              const outcome = handleMessage(parsedMsg);
              if (outcome.done) return outcome;
            } else {
              controller.enqueue(encoder.encode(''));
            }
          }

          // The trailing line may already be complete; try it leniently.
          if (buffer) {
            const { data: parsedBuffer, error } = await parseLine(buffer, true);
            if (error) {
              results.push(parseErrorBody(buffer));
              return { done: true, reason: 'api_error', error };
            }
            if (parsedBuffer) {
              buffer = '';
              const outcome = handleMessage(parsedBuffer);
              if (outcome.done) return outcome;
            } else {
              controller.enqueue(encoder.encode(''));
            }
          }
          return { done: false };
        };

        // Persists what was collected, logs the outcome and closes the
        // client stream. Must run at most once.
        const finish = async (reason: string | undefined, error: unknown) => {
          const processMsgs = results.filter(
            res => res.type !== 'final_code',
          ) as PrismaJson.AgentResponseBodies;
          await dbPostUpdateMessageResponse(messageId, {
            response: processMsgs.map(res => JSON.stringify(res)).join('\n'),
            result: results.find(
              res => res.type === 'final_code',
            ) as PrismaJson.FinalCodeBody,
            responseBody: processMsgs,
          });
          logger.info(
            session,
            {
              message: 'Streaming ended',
              maxChunkSize,
              reason,
              error,
            },
            request,
            '__AGENT_DONE',
          );
          controller.close();
        };

        // const parser = createParser(streamParser);
        for await (const chunk of fetchResponse.body as any) {
          // `stream: true` keeps multi-byte characters that straddle two
          // chunks intact.
          const data = decoder.decode(chunk, { stream: true });
          buffer += data;
          maxChunkSize = Math.max(data.length, maxChunkSize);
          const lines = buffer
            .split('\n')
            .filter(line => line.trim().length > 0);
          const { done, reason, error } = await processChunk(lines);
          if (done) {
            await finish(reason, error);
            // Stop reading: the controller is closed, so any further chunk
            // would enqueue on a closed stream and persist a second time.
            return;
          }
        }

        // Upstream ended without a terminal message: still persist and close
        // so the client response does not hang.
        await finish('stream_ended', undefined);
      },
    });

    return new Response(stream, {
      headers: {
        'Content-Type': 'application/x-ndjson',
      },
    });
  },
);