// Google Gemini client: builds requests, streams responses (with tool calls), and extracts reply text.
import { ApiPath, Google } from "@/app/constant";
import {
  ChatOptions,
  getHeaders,
  LLMApi,
  LLMModel,
  LLMUsage,
  SpeechOptions,
} from "../api";
import {
  useAccessStore,
  useAppConfig,
  useChatStore,
  usePluginStore,
  ChatMessageTool,
} from "@/app/store";
import { stream } from "@/app/utils/chat";
import { getClientConfig } from "@/app/config/client";
import { GEMINI_BASE_URL } from "@/app/constant";
import {
  getMessageTextContent,
  getMessageImages,
  isVisionModel,
  getTimeoutMSByModel,
} from "@/app/utils";
import { preProcessImageContent } from "@/app/utils/chat";
import { nanoid } from "nanoid";
import { RequestPayload } from "./openai";
import { fetch } from "@/app/utils/stream";
/**
 * `LLMApi` implementation that talks to Google's Gemini endpoints.
 *
 * Responsibilities visible in this class:
 * - resolve the endpoint URL from the access store / client config (`path`);
 * - convert chat messages into Gemini `contents` and send them, either as a
 *   single request or as an SSE stream with tool-call support (`chat`);
 * - pull the reply text (or error message) out of a response (`extractMessage`).
 */
export class GeminiProApi implements LLMApi {
  /**
   * Builds the full request URL for a Gemini API path.
   *
   * The base URL is the user's custom `googleUrl` when custom config is
   * enabled and non-empty; otherwise `GEMINI_BASE_URL` inside the packaged app
   * or the `ApiPath.Google` proxy route elsewhere.
   *
   * @param path API path to append to the base URL.
   * @param shouldStream When true, appends `alt=sse` as a query parameter.
   * @returns The assembled URL.
   */
  path(path: string, shouldStream = false): string {
    const accessStore = useAccessStore.getState();

    let baseUrl = "";
    if (accessStore.useCustomConfig) {
      baseUrl = accessStore.googleUrl;
    }

    const isApp = !!getClientConfig()?.isApp;
    if (baseUrl.length === 0) {
      baseUrl = isApp ? GEMINI_BASE_URL : ApiPath.Google;
    }

    // Drop one trailing slash so the join below does not double it.
    if (baseUrl.endsWith("/")) {
      baseUrl = baseUrl.slice(0, -1);
    }

    // A bare host (neither absolute URL nor the local proxy route) gets https.
    if (!baseUrl.startsWith("http") && !baseUrl.startsWith(ApiPath.Google)) {
      baseUrl = "https://" + baseUrl;
    }

    console.log("[Proxy Endpoint] ", baseUrl, path);

    let chatPath = [baseUrl, path].join("/");
    if (shouldStream) {
      chatPath += chatPath.includes("?") ? "&alt=sse" : "?alt=sse";
    }
    return chatPath;
  }

  /**
   * Extracts the reply text from a Gemini response.
   *
   * Accepts either a single response object or an array of response objects.
   * For each response, the text parts of the first candidate are joined with
   * a blank line; blank and non-string parts are skipped. For an array, the
   * per-item texts are concatenated without a separator.
   *
   * @param res Parsed response JSON (object or array).
   * @returns The extracted text, else `res.error.message`, else `""`.
   */
  extractMessage(res: any): string {
    console.log("[Response] gemini-pro response: ", res);

    const getTextFromParts = (parts: any[]): string => {
      if (!Array.isArray(parts)) return "";
      return parts
        .map((part) => (typeof part?.text === "string" ? part.text : ""))
        .filter((text) => text.trim() !== "")
        .join("\n\n");
    };

    const content = Array.isArray(res)
      ? res
          .map((item) =>
            getTextFromParts(item?.candidates?.at(0)?.content?.parts),
          )
          .join("")
      : "";

    return (
      getTextFromParts(res?.candidates?.at(0)?.content?.parts) ||
      content ||
      res?.error?.message ||
      ""
    );
  }

  /** Not supported by this client; always throws. */
  speech(options: SpeechOptions): Promise<ArrayBuffer> {
    throw new Error("Method not implemented.");
  }

  /**
   * Sends a chat request to Gemini.
   *
   * Streaming requests are delegated to `stream()`, with callbacks that parse
   * SSE chunks and append tool-call turns to the payload. Non-streaming
   * requests are fetched directly and reported through `options.onFinish`.
   * Failures inside the request phase are reported through `options.onError`.
   *
   * @param options Messages, model config and lifecycle callbacks.
   */
  async chat(options: ChatOptions): Promise<void> {
    // Resolve cached image URLs into their processed content first.
    const _messages: ChatOptions["messages"] = [];
    for (const v of options.messages) {
      const content = await preProcessImageContent(v.content);
      _messages.push({ role: v.role, content });
    }

    const visionEnabled = isVisionModel(options.config.model);
    const messages = _messages.map((v) => {
      let parts: any[] = [{ text: getMessageTextContent(v) }];
      if (visionEnabled) {
        const images = getMessageImages(v);
        if (images.length > 0) {
          parts = parts.concat(
            images.map((image) => {
              // Split "data:<mime>;base64,<data>" into its mime type and data.
              const imageType = image.split(";")[0].split(":")[1];
              const imageData = image.split(",")[1];
              return {
                inline_data: {
                  mime_type: imageType,
                  data: imageData,
                },
              };
            }),
          );
        }
      }
      return {
        role: v.role.replace("assistant", "model").replace("system", "user"),
        parts: parts,
      };
    });

    // Google requires that neighboring messages do not share a role, so any
    // run of same-role messages is folded into the first one of the run.
    for (let i = 0; i < messages.length - 1; ) {
      if (messages[i].role === messages[i + 1].role) {
        messages[i].parts = messages[i].parts.concat(messages[i + 1].parts);
        messages.splice(i + 1, 1);
      } else {
        i++;
      }
    }

    const accessStore = useAccessStore.getState();
    const modelConfig = {
      ...useAppConfig.getState().modelConfig,
      ...useChatStore.getState().currentSession().mask.modelConfig,
      ...{
        model: options.config.model,
      },
    };

    const requestPayload = {
      contents: messages,
      generationConfig: {
        temperature: modelConfig.temperature,
        maxOutputTokens: modelConfig.max_tokens,
        topP: modelConfig.top_p,
      },
      // The same user-selected threshold is applied to every harm category.
      safetySettings: [
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
      ].map((category) => ({
        category,
        threshold: accessStore.googleSafetySettings,
      })),
    };

    const shouldStream = !!options.config.stream;
    const controller = new AbortController();
    options.onController?.(controller);

    try {
      // https://github.com/google-gemini/cookbook/blob/main/quickstarts/rest/Streaming_REST.ipynb
      const chatPath = this.path(
        Google.ChatPath(modelConfig.model),
        shouldStream,
      );

      if (shouldStream) {
        // No abort timer is armed here: a timer created in this method could
        // never be cleared once control passes to stream(), and would abort
        // the shared controller mid-stream once it fired.
        // NOTE(review): stream() receives the controller and is presumed to
        // enforce its own connection timeout — confirm.
        const [tools, funcs] = usePluginStore
          .getState()
          .getAsTools(
            useChatStore.getState().currentSession().mask?.plugin || [],
          );
        return stream(
          chatPath,
          requestPayload,
          getHeaders(),
          // @ts-ignore
          tools.length > 0
            ? // @ts-ignore
              [{ functionDeclarations: tools.map((tool) => tool.function) }]
            : [],
          funcs,
          controller,
          // parseSSE: collect function calls and return the chunk's text.
          (text: string, runTools: ChatMessageTool[]) => {
            const chunkJson = JSON.parse(text);
            // A candidate may arrive without content/parts; treat that as an
            // empty chunk instead of throwing.
            const parts: any[] | undefined =
              chunkJson?.candidates?.at(0)?.content?.parts;

            // Function calls may appear in any part, not only the first.
            for (const part of parts ?? []) {
              const functionCall = part?.functionCall;
              if (!functionCall) continue;
              const { name, args } = functionCall;
              runTools.push({
                id: nanoid(),
                type: "function",
                function: {
                  name,
                  arguments: JSON.stringify(args), // utils.chat call function, using JSON.parse
                },
              });
            }

            return parts
              ?.map((part: { text: string }) => part?.text)
              .join("\n\n");
          },
          // processToolMessage: append the model's tool-call turn followed by
          // one "function" turn per tool result.
          (
            requestPayload: RequestPayload,
            toolCallMessage: any,
            toolCallResult: any[],
          ) => {
            // @ts-ignore RequestPayload is the OpenAI shape; `contents` is read from the Gemini payload here.
            requestPayload?.contents?.push(
              {
                role: "model",
                parts: toolCallMessage.tool_calls.map(
                  (tool: ChatMessageTool) => ({
                    functionCall: {
                      name: tool?.function?.name,
                      args: JSON.parse(tool?.function?.arguments as string),
                    },
                  }),
                ),
              },
              // @ts-ignore
              ...toolCallResult.map((result) => ({
                role: "function",
                parts: [
                  {
                    functionResponse: {
                      name: result.name,
                      response: {
                        name: result.name,
                        content: result.content, // TODO just text content...
                      },
                    },
                  },
                ],
              })),
            );
          },
          options,
        );
      }

      const chatPayload = {
        method: "POST",
        body: JSON.stringify(requestPayload),
        signal: controller.signal,
        headers: getHeaders(),
      };

      // Abort the request if no response arrives within the model's timeout.
      const requestTimeoutId = setTimeout(
        () => controller.abort(),
        getTimeoutMSByModel(options.config.model),
      );
      let res: Awaited<ReturnType<typeof fetch>>;
      try {
        res = await fetch(chatPath, chatPayload);
      } finally {
        // Cleared on failure too, so a rejected fetch does not leave a
        // pending timer behind.
        clearTimeout(requestTimeoutId);
      }

      const resJson = await res.json();
      if (resJson?.promptFeedback?.blockReason) {
        // The prompt was blocked; onFinish below is still invoked afterwards.
        options.onError?.(
          new Error(
            "Message is being blocked for reason: " +
              resJson.promptFeedback.blockReason,
          ),
        );
      }
      const message = this.extractMessage(resJson);
      options.onFinish(message, res);
    } catch (e) {
      console.log("[Request] failed to make a chat request", e);
      options.onError?.(e as Error);
    }
  }

  /** Not supported by this client; always throws. */
  usage(): Promise<LLMUsage> {
    throw new Error("Method not implemented.");
  }

  /** Returns an empty model list. */
  async models(): Promise<LLMModel[]> {
    return [];
  }
}