Spaces:
Running
Running
| /** | |
| * Copyright 2024 Google LLC | |
| * | |
| * Licensed under the Apache License, Version 2.0 (the "License"); | |
| * you may not use this file except in compliance with the License. | |
| * You may obtain a copy of the License at | |
| * | |
| * http://www.apache.org/licenses/LICENSE-2.0 | |
| * | |
| * Unless required by applicable law or agreed to in writing, software | |
| * distributed under the License is distributed on an "AS IS" BASIS, | |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| * See the License for the specific language governing permissions and | |
| * limitations under the License. | |
| */ | |
| import { Content, GenerativeContentBlob, Part } from "@google/generative-ai"; | |
| import { EventEmitter } from "eventemitter3"; | |
| import { difference } from "lodash"; | |
| import { | |
| ClientContentMessage, | |
| isInterrupted, | |
| isModelTurn, | |
| isServerContentMessage, | |
| isSetupCompleteMessage, | |
| isToolCallCancellationMessage, | |
| isToolCallMessage, | |
| isTurnComplete, | |
| LiveIncomingMessage, | |
| ModelTurn, | |
| RealtimeInputMessage, | |
| ServerContent, | |
| SetupMessage, | |
| StreamingLog, | |
| ToolCall, | |
| ToolCallCancellation, | |
| ToolResponseMessage, | |
| type LiveConfig, | |
| } from "../multimodal-live-types"; | |
| import { blobToJSON, base64ToArrayBuffer } from "./utils"; | |
/**
 * Event map for MultimodalLiveClient: each key is an event name and each value
 * is the listener signature that event is emitted with.
 */
interface MultimodalLiveClientEventTypes {
  // --- connection lifecycle ---

  /** Emitted when the WebSocket opens, before the setup message is sent. */
  open: () => void;
  /** Emitted with the socket's CloseEvent when an opened socket closes. */
  close: (event: CloseEvent) => void;
  /** Emitted when the server reports that setup is complete. */
  setupcomplete: () => void;

  // --- diagnostics ---

  /** Emitted for every entry recorded through the client's `log` method. */
  log: (log: StreamingLog) => void;

  // --- model output ---

  /** Emitted once per "audio/pcm" part, carrying the decoded bytes. */
  audio: (data: ArrayBuffer) => void;
  /** Emitted with the non-audio parts of a model turn. */
  content: (data: ServerContent) => void;
  /** Emitted when the server content is flagged as interrupted. */
  interrupted: () => void;
  /** Emitted when the server content is flagged as turn-complete. */
  turncomplete: () => void;

  // --- tool use ---

  /** Emitted with the tool call payload sent by the server. */
  toolcall: (toolCall: ToolCall) => void;
  /** Emitted with the cancellation payload for previously issued tool calls. */
  toolcallcancellation: (toolcallCancellation: ToolCallCancellation) => void;
}
/**
 * Options accepted by the MultimodalLiveClient constructor. Both fields are
 * optional.
 */
export type MultimodalLiveAPIClientConnection = {
  /**
   * WebSocket endpoint to connect to. When omitted (or empty) the client falls
   * back to `/ws` on the current page's host, using `wss:` for https pages and
   * `ws:` otherwise.
   */
  url?: string;
  /**
   * API key. NOTE(review): the constructor in this file accepts this field but
   * never reads it — confirm whether it is still needed.
   */
  apiKey?: string;
};
/**
 * An event-emitting class that manages the connection to the websocket and
 * emits events to the rest of the application.
 * It has no dependency on React, so it can be used from any front-end code.
 */
export class MultimodalLiveClient extends EventEmitter<MultimodalLiveClientEventTypes> {
  /** The currently active socket, or null when not connected. */
  public ws: WebSocket | null = null;
  /** The config passed to the most recent `connect` call. */
  protected config: LiveConfig | null = null;
  /** The WebSocket endpoint this client connects to. */
  public url: string;

  /**
   * @param connection Optional connection settings. When `url` is missing or
   *   empty, the endpoint defaults to `/ws` on the current page's host
   *   (`wss:` on https pages, `ws:` otherwise).
   */
  constructor({ url }: MultimodalLiveAPIClientConnection = {}) {
    super();
    // Resolve the endpoint once so the logged value and the stored value can
    // never drift apart.
    const resolvedUrl =
      url ||
      `${window.location.protocol === 'https:' ? 'wss:' : 'ws:'}//${window.location.host}/ws`;
    console.log('π§ Initializing MultimodalLiveClient with URL:', resolvedUrl);
    this.url = resolvedUrl;
    // Bound so `send` can be handed around as a bare callback.
    this.send = this.send.bind(this);
  }

  /**
   * Builds a timestamped StreamingLog entry and emits it on the "log" event.
   *
   * @param type Category label for the entry (e.g. "client.send").
   * @param message Payload to record.
   */
  log(type: string, message: StreamingLog["message"]) {
    const log: StreamingLog = {
      date: new Date(),
      type,
      message,
    };
    this.emit("log", log);
  }

  /**
   * Opens a WebSocket to `this.url` and, once it is open, sends `config` as
   * the setup message.
   *
   * @param config Session configuration sent to the server as `setup`.
   * @returns A promise that resolves to `true` once the socket is open and the
   *   setup message has been sent. It rejects with an `Error` if the socket
   *   errors before opening or if no config is available when it opens.
   */
  connect(config: LiveConfig): Promise<boolean> {
    console.log('π Attempting WebSocket connection to:', this.url);
    this.config = config;
    console.log('π MultimodalLiveClient: Starting WebSocket connection to:', this.url);

    const ws = new WebSocket(this.url);

    ws.addEventListener("message", (evt: MessageEvent) => {
      console.log('π¨ Received WebSocket message:', evt.data instanceof Blob ? 'Blob data' : evt.data);
      if (evt.data instanceof Blob) {
        console.log('π© MultimodalLiveClient: Received blob message');
        // `receive` is async: without a catch, a malformed payload or a
        // throwing event handler would surface as an unhandled rejection.
        this.receive(evt.data).catch((err: unknown) => {
          const detail = err instanceof Error ? err.message : String(err);
          console.error("failed to process incoming message", err);
          this.log("client.error", `failed to process incoming message: ${detail}`);
        });
      } else {
        console.log("non blob message", evt);
      }
    });

    return new Promise((resolve, reject) => {
      // Only relevant while connecting; removed again once the socket opens.
      const onError = (ev: Event) => {
        this.disconnect(ws);
        const message = `Could not connect to "${this.url}"`;
        this.log(`server.${ev.type}`, message);
        reject(new Error(message));
      };
      ws.addEventListener("error", onError);

      ws.addEventListener("open", (ev: Event) => {
        console.log('β WebSocket connection opened successfully');
        if (!this.config) {
          // Release the socket we just opened instead of leaking it, and
          // reject with a real Error so callers get a stack trace.
          ws.close();
          reject(new Error("Invalid config sent to `connect(config)`"));
          return;
        }
        console.log('β¨ MultimodalLiveClient: WebSocket connection established');
        this.log(`client.${ev.type}`, `connected to socket`);
        this.emit("open");

        this.ws = ws;

        const setupMessage: SetupMessage = {
          setup: this.config,
        };
        this._sendDirect(setupMessage);
        this.log("client.send", "setup");

        ws.removeEventListener("error", onError);
        ws.addEventListener("close", (ev: CloseEvent) => {
          console.log(ev);
          this.disconnect(ws);
          let reason = ev.reason || "";
          if (reason.toLowerCase().includes("error")) {
            // Trim everything up to and including the "ERROR]" marker and the
            // character that follows it.
            const prelude = "ERROR]";
            const preludeIndex = reason.indexOf(prelude);
            if (preludeIndex > 0) {
              reason = reason.slice(preludeIndex + prelude.length + 1);
            }
          }
          console.log('π Close reason:', reason || 'No reason provided');
          this.log(
            `server.${ev.type}`,
            `disconnected ${reason ? `with reason: ${reason}` : ``}`,
          );
          this.emit("close", ev);
        });
        resolve(true);
      });
    });
  }

  /**
   * Closes the active socket.
   *
   * @param ws When given, the active socket is only closed if it is this exact
   *   instance; this stops a stale socket's handlers from closing a newer one.
   * @returns `true` if a socket was closed, `false` otherwise.
   */
  disconnect(ws?: WebSocket) {
    console.log('π Attempting to disconnect WebSocket');
    // could be that this is an old websocket and there's already a new instance
    // only close it if it's still the correct reference
    if ((!ws || this.ws === ws) && this.ws) {
      console.log('π Closing WebSocket connection');
      this.ws.close();
      this.ws = null;
      this.log("client.close", `Disconnected`);
      return true;
    }
    console.log('β οΈ No active WebSocket to disconnect');
    return false;
  }

  /**
   * Parses an incoming blob as JSON and dispatches it to the matching event.
   *
   * Tool calls, tool call cancellations, setup-complete and interrupted
   * messages each emit their event and stop. A turn-complete flag emits
   * "turncomplete" and processing continues, because the same message may
   * also carry a model turn.
   *
   * @param blob Raw message payload received from the socket.
   */
  protected async receive(blob: Blob) {
    const response: LiveIncomingMessage = (await blobToJSON(
      blob,
    )) as LiveIncomingMessage;
    console.log('π₯ Received message:', response);

    if (isToolCallMessage(response)) {
      console.log('π οΈ MultimodalLiveClient: Received tool call');
      this.log("server.toolCall", response);
      this.emit("toolcall", response.toolCall);
      return;
    }
    if (isToolCallCancellationMessage(response)) {
      console.log('π« MultimodalLiveClient: Received tool call cancellation');
      this.log("receive.toolCallCancellation", response);
      this.emit("toolcallcancellation", response.toolCallCancellation);
      return;
    }
    if (isSetupCompleteMessage(response)) {
      console.log('π MultimodalLiveClient: Setup complete received');
      this.log("server.send", "setupComplete");
      this.emit("setupcomplete");
      return;
    }

    // this json also might be `contentUpdate { interrupted: true }`
    // or contentUpdate { end_of_turn: true }
    if (isServerContentMessage(response)) {
      const { serverContent } = response;
      if (isInterrupted(serverContent)) {
        this.log("receive.serverContent", "interrupted");
        this.emit("interrupted");
        return;
      }
      if (isTurnComplete(serverContent)) {
        this.log("server.send", "turnComplete");
        this.emit("turncomplete");
        // plausible there's more to the message, continue
      }

      if (isModelTurn(serverContent)) {
        let parts: Part[] = serverContent.modelTurn.parts;

        // when it's audio that is returned for modelTurn
        const audioParts = parts.filter(
          (p) => p.inlineData && p.inlineData.mimeType.startsWith("audio/pcm"),
        );
        const base64s = audioParts.map((p) => p.inlineData?.data);

        // strip the audio parts out of the modelTurn
        const otherParts = difference(parts, audioParts);

        base64s.forEach((b64) => {
          if (b64) {
            const data = base64ToArrayBuffer(b64);
            this.emit("audio", data);
            this.log(`server.audio`, `buffer (${data.byteLength})`);
          }
        });
        // Audio-only turns are fully handled by the "audio" events above.
        if (!otherParts.length) {
          return;
        }

        parts = otherParts;

        const content: ModelTurn = { modelTurn: { parts } };
        this.emit("content", content);
        this.log(`server.content`, response);
      }
    } else {
      console.log("received unmatched message", response);
    }
  }

  /**
   * send realtimeInput, this is base64 chunks of "audio/pcm" and/or "image/jpg"
   *
   * The chunks are sent as-is; they are only inspected to label the log entry
   * as "audio", "video", "audio + video" or "unknown".
   *
   * @param chunks Media chunks to forward to the server.
   * @throws Error if the socket is not connected.
   */
  sendRealtimeInput(chunks: GenerativeContentBlob[]) {
    let hasAudio = false;
    let hasVideo = false;
    for (let i = 0; i < chunks.length; i++) {
      const ch = chunks[i];
      if (ch.mimeType.includes("audio")) {
        hasAudio = true;
      }
      if (ch.mimeType.includes("image")) {
        hasVideo = true;
      }
      if (hasAudio && hasVideo) {
        break;
      }
    }
    const message =
      hasAudio && hasVideo
        ? "audio + video"
        : hasAudio
          ? "audio"
          : hasVideo
            ? "video"
            : "unknown";

    const data: RealtimeInputMessage = {
      realtimeInput: {
        mediaChunks: chunks,
      },
    };
    this._sendDirect(data);
    this.log(`client.realtimeInput`, message);
  }

  /**
   * send a response to a function call and provide the id of the functions you are responding to
   *
   * @param toolResponse Response payload wrapped into a `toolResponse` message.
   * @throws Error if the socket is not connected.
   */
  sendToolResponse(toolResponse: ToolResponseMessage["toolResponse"]) {
    const message: ToolResponseMessage = {
      toolResponse,
    };

    this._sendDirect(message);
    this.log(`client.toolResponse`, message);
  }

  /**
   * send normal content parts such as { text }
   *
   * @param parts A single part or a list of parts, sent as one "user" turn.
   * @param turnComplete Value of the message's `turnComplete` flag.
   * @throws Error if the socket is not connected.
   */
  send(parts: Part | Part[], turnComplete: boolean = true) {
    // Normalise into a local instead of reassigning the parameter.
    const partList: Part[] = Array.isArray(parts) ? parts : [parts];
    const content: Content = {
      role: "user",
      parts: partList,
    };

    const clientContentRequest: ClientContentMessage = {
      clientContent: {
        turns: [content],
        turnComplete,
      },
    };

    this._sendDirect(clientContentRequest);
    this.log(`client.send`, clientContentRequest);
  }

  /**
   * used internally to send all messages
   * don't use directly unless trying to send an unsupported message type
   *
   * @param request Object serialised with JSON.stringify and written to the socket.
   * @throws Error if the socket is not connected.
   */
  _sendDirect(request: object) {
    if (!this.ws) {
      throw new Error("WebSocket is not connected");
    }
    const str = JSON.stringify(request);
    this.ws.send(str);
  }
}