|
|
|
import type { |
|
Status, |
|
Payload, |
|
EventType, |
|
ListenerMap, |
|
SubmitReturn, |
|
EventListener, |
|
Event, |
|
JsApiData, |
|
EndpointInfo, |
|
ApiInfo, |
|
Config, |
|
Dependency |
|
} from "../types"; |
|
|
|
import { skip_queue, post_message } from "../helpers/data"; |
|
import { resolve_root } from "../helpers/init_helpers"; |
|
import { |
|
handle_message, |
|
map_data_to_params, |
|
process_endpoint |
|
} from "../helpers/api_info"; |
|
import semiver from "semiver"; |
|
import { BROKEN_CONNECTION_MSG, QUEUE_FULL_MSG } from "../constants"; |
|
import { apply_diff_stream, close_stream } from "./stream"; |
|
import { Client } from "../client"; |
|
|
|
/**
 * Submit a job to an app endpoint and return a handle for observing it.
 *
 * The endpoint is resolved with `get_endpoint_info`, inputs are mapped to
 * parameters with `map_data_to_params` and passed through `this.handle_blob`.
 * The resulting payload is then sent with one of these transports:
 * - a direct POST to `<root>/run/<endpoint>` when `skip_queue` returns true;
 * - a WebSocket to `/queue/join` when `config.protocol` is `"ws"` (the default);
 * - a per-job `EventSource` (via `this.stream`) when the protocol is `"sse"`;
 * - a POST to `/queue/join` plus a callback registered in `event_callbacks`
 *   (served by `this.open_stream`) for `"sse_v1"`, `"sse_v2"`, `"sse_v2.1"`
 *   and `"sse_v3"`.
 *
 * Progress is delivered to listeners registered with `on` as `status`,
 * `data`, `log` and `render` events. Failures in the asynchronous setup are
 * reported as a `status` event with `stage: "error"`.
 *
 * @param endpoint - Endpoint name (with or without a leading `/`) or numeric fn_index.
 * @param data - Inputs, either positional (array) or keyed (object).
 * @param event_data - Forwarded in the payload and echoed on `data` events.
 * @param trigger_id - Forwarded in the payload and echoed on `data` events.
 * @returns A handle exposing `on`, `off`, `cancel` and `destroy`.
 * @throws Error synchronously when API info or config is missing, or the
 *   endpoint cannot be resolved.
 */
export function submit(
	this: Client,
	endpoint: string | number,
	data: unknown[] | Record<string, unknown>,
	event_data?: unknown,
	trigger_id?: number | null
): SubmitReturn {
	try {
		const { hf_token } = this.options;
		const {
			fetch,
			app_reference,
			config,
			session_hash,
			api_info,
			api_map,
			stream_status,
			pending_stream_messages,
			pending_diff_streams,
			event_callbacks,
			unclosed_events,
			post_data
		} = this;

		if (!api_info) throw new Error("No API found");
		if (!config) throw new Error("Could not resolve app config");

		const { fn_index, endpoint_info, dependency } = get_endpoint_info(
			api_info,
			endpoint,
			api_map,
			config
		);

		const resolved_data = map_data_to_params(data, api_info);

		// Only one transport is opened per submission. `websocket` stays
		// undefined until the "ws" branch creates it, so `cancel` must check it.
		let websocket: WebSocket | undefined;
		let stream: EventSource | null;
		const protocol = config.protocol ?? "ws";

		const _endpoint = typeof endpoint === "number" ? "/predict" : endpoint;
		let payload: Payload;
		let event_id: string | null = null;
		// Holds the terminal status once the job completes or is cancelled.
		let complete: Status | undefined | false = false;
		const listener_map: ListenerMap<EventType> = {};
		// NOTE(review): nothing in this function writes to last_status, so
		// handle_message always receives undefined as the previous stage.
		const last_status: Record<string, Status["stage"]> = {};
		const url_params =
			typeof window !== "undefined"
				? new URLSearchParams(window.location.search).toString()
				: "";

		/** Invoke every listener registered for `event.type`. */
		function fire_event<K extends EventType>(event: Event<K>): void {
			const narrowed_listener_map: ListenerMap<K> = listener_map;
			const listeners = narrowed_listener_map[event.type] || [];
			listeners?.forEach((l) => l(event));
		}

		/** Register `listener` for `eventType`; returns the handle for chaining. */
		function on<K extends EventType>(
			eventType: K,
			listener: EventListener<K>
		): SubmitReturn {
			const narrowed_listener_map: ListenerMap<K> = listener_map;
			const listeners = narrowed_listener_map[eventType] || [];
			narrowed_listener_map[eventType] = listeners;
			listeners?.push(listener);

			return { on, off, cancel, destroy };
		}

		/** Unregister `listener` for `eventType`; returns the handle for chaining. */
		function off<K extends EventType>(
			eventType: K,
			listener: EventListener<K>
		): SubmitReturn {
			const narrowed_listener_map: ListenerMap<K> = listener_map;
			let listeners = narrowed_listener_map[eventType] || [];
			listeners = listeners?.filter((l) => l !== listener);
			narrowed_listener_map[eventType] = listeners;
			return { on, off, cancel, destroy };
		}

		/**
		 * Mark the job complete, close the active transport, and ask the
		 * server to reset it via POST `<root>/reset`. Failure to reach
		 * `/reset` is logged as a warning rather than thrown.
		 */
		async function cancel(): Promise<void> {
			const _status: Status = {
				stage: "complete",
				queue: false,
				time: new Date()
			};
			complete = _status;
			fire_event({
				..._status,
				type: "status",
				endpoint: _endpoint,
				fn_index: fn_index
			});

			let cancel_request = {};
			if (protocol === "ws") {
				// cancel() may run before the socket has been created (the
				// payload is prepared asynchronously), so guard against undefined.
				const socket = websocket;
				if (socket && socket.readyState === 0) {
					// Still connecting: close as soon as the connection opens.
					socket.addEventListener("open", () => {
						socket.close();
					});
				} else {
					socket?.close();
				}
				cancel_request = { fn_index, session_hash };
			} else {
				stream?.close();
				cancel_request = { event_id };
			}

			try {
				if (!config) {
					throw new Error("Could not resolve app config");
				}

				await fetch(`${config.root}/reset`, {
					headers: { "Content-Type": "application/json" },
					method: "POST",
					body: JSON.stringify(cancel_request)
				});
			} catch (e) {
				console.warn(
					"The `/reset` endpoint could not be called. Subsequent endpoint results may be unreliable."
				);
			}
		}

		/** Remove every registered "data" and "status" listener. */
		function destroy(): void {
			for (const event_type in listener_map) {
				listener_map &&
					listener_map[event_type as "data" | "status"]?.forEach((fn) => {
						off(event_type as "data" | "status", fn);
					});
			}
		}

		this.handle_blob(config.root, resolved_data, endpoint_info)
			.then(async (_payload) => {
				payload = {
					data: _payload || [],
					event_data,
					fn_index,
					trigger_id
				};
				if (skip_queue(fn_index, config)) {
					fire_event({
						type: "status",
						endpoint: _endpoint,
						stage: "pending",
						queue: false,
						fn_index,
						time: new Date()
					});

					post_data(
						`${config.root}/run${
							_endpoint.startsWith("/") ? _endpoint : `/${_endpoint}`
						}${url_params ? "?" + url_params : ""}`,
						{
							...payload,
							session_hash
						}
					)
						.then(([output, status_code]: any) => {
							const data = output.data;
							if (status_code == 200) {
								fire_event({
									type: "data",
									endpoint: _endpoint,
									fn_index,
									data: data,
									time: new Date(),
									event_data,
									trigger_id
								});

								fire_event({
									type: "status",
									endpoint: _endpoint,
									fn_index,
									stage: "complete",
									eta: output.average_duration,
									queue: false,
									time: new Date()
								});
							} else {
								fire_event({
									type: "status",
									stage: "error",
									endpoint: _endpoint,
									fn_index,
									message: output.error,
									queue: false,
									time: new Date()
								});
							}
						})
						.catch((e) => {
							fire_event({
								type: "status",
								stage: "error",
								message: e.message,
								endpoint: _endpoint,
								fn_index,
								queue: false,
								time: new Date()
							});
						});
				} else if (protocol == "ws") {
					const { ws_protocol, host } = await process_endpoint(
						app_reference,
						hf_token
					);

					fire_event({
						type: "status",
						stage: "pending",
						queue: true,
						endpoint: _endpoint,
						fn_index,
						time: new Date()
					});

					const url = new URL(
						`${ws_protocol}://${resolve_root(
							host,
							config.path as string,
							true
						)}/queue/join${url_params ? "?" + url_params : ""}`
					);

					if (this.jwt) {
						url.searchParams.set("__sign", this.jwt);
					}

					// Keep a non-optional local for the handlers below; publish
					// it to `websocket` so cancel() can reach it.
					const socket = new WebSocket(url);
					websocket = socket;

					socket.onclose = (evt) => {
						if (!evt.wasClean) {
							fire_event({
								type: "status",
								stage: "error",
								broken: true,
								message: BROKEN_CONNECTION_MSG,
								queue: true,
								endpoint: _endpoint,
								fn_index,
								time: new Date()
							});
						}
					};

					socket.onmessage = function (event) {
						const _data = JSON.parse(event.data);
						const { type, status, data } = handle_message(
							_data,
							last_status[fn_index]
						);

						if (type === "update" && status && !complete) {
							fire_event({
								type: "status",
								endpoint: _endpoint,
								fn_index,
								time: new Date(),
								...status
							});
							if (status.stage === "error") {
								socket.close();
							}
						} else if (type === "hash") {
							socket.send(JSON.stringify({ fn_index, session_hash }));
							return;
						} else if (type === "data") {
							socket.send(JSON.stringify({ ...payload, session_hash }));
						} else if (type === "complete") {
							complete = status;
						} else if (type === "log") {
							fire_event({
								type: "log",
								log: data.log,
								level: data.level,
								endpoint: _endpoint,
								fn_index
							});
						} else if (type === "generating") {
							fire_event({
								type: "status",
								time: new Date(),
								...status,
								stage: status?.stage!,
								queue: true,
								endpoint: _endpoint,
								fn_index
							});
						}
						if (data) {
							fire_event({
								type: "data",
								time: new Date(),
								data: data.data,
								endpoint: _endpoint,
								fn_index,
								event_data,
								trigger_id
							});

							if (complete) {
								fire_event({
									type: "status",
									time: new Date(),
									...complete,
									stage: status?.stage!,
									queue: true,
									endpoint: _endpoint,
									fn_index
								});
								socket.close();
							}
						}
					};

					// Servers older than 3.6 expect the session hash as soon as
					// the socket opens. This must listen on the socket: a bare
					// addEventListener would attach to the global object, whose
					// "open" event never fires.
					if (semiver(config.version || "2.0.0", "3.6") < 0) {
						socket.addEventListener("open", () =>
							socket.send(JSON.stringify({ hash: session_hash }))
						);
					}
				} else if (protocol == "sse") {
					fire_event({
						type: "status",
						stage: "pending",
						queue: true,
						endpoint: _endpoint,
						fn_index,
						time: new Date()
					});
					const params = new URLSearchParams({
						fn_index: fn_index.toString(),
						session_hash: session_hash
					}).toString();
					const url = new URL(
						`${config.root}/queue/join?${
							url_params ? url_params + "&" : ""
						}${params}`
					);

					if (this.jwt) {
						url.searchParams.set("__sign", this.jwt);
					}

					stream = await this.stream(url);

					if (!stream) {
						// Surfaced to listeners by the .catch at the end of the chain.
						throw new Error("Cannot connect to SSE endpoint: " + url.toString());
					}

					stream.onmessage = async function (event: MessageEvent) {
						const _data = JSON.parse(event.data);
						const { type, status, data } = handle_message(
							_data,
							last_status[fn_index]
						);

						if (type === "update" && status && !complete) {
							fire_event({
								type: "status",
								endpoint: _endpoint,
								fn_index,
								time: new Date(),
								...status
							});
							if (status.stage === "error") {
								stream?.close();
							}
						} else if (type === "data") {
							event_id = _data.event_id as string;
							const [, post_status] = await post_data(
								`${config.root}/queue/data`,
								{
									...payload,
									session_hash,
									event_id
								}
							);
							if (post_status !== 200) {
								fire_event({
									type: "status",
									stage: "error",
									message: BROKEN_CONNECTION_MSG,
									queue: true,
									endpoint: _endpoint,
									fn_index,
									time: new Date()
								});
								stream?.close();
							}
						} else if (type === "complete") {
							complete = status;
						} else if (type === "log") {
							fire_event({
								type: "log",
								log: data.log,
								level: data.level,
								endpoint: _endpoint,
								fn_index
							});
						} else if (type === "generating") {
							fire_event({
								type: "status",
								time: new Date(),
								...status,
								stage: status?.stage!,
								queue: true,
								endpoint: _endpoint,
								fn_index
							});
						}
						if (data) {
							fire_event({
								type: "data",
								time: new Date(),
								data: data.data,
								endpoint: _endpoint,
								fn_index,
								event_data,
								trigger_id
							});

							if (complete) {
								fire_event({
									type: "status",
									time: new Date(),
									...complete,
									stage: status?.stage!,
									queue: true,
									endpoint: _endpoint,
									fn_index
								});
								stream?.close();
							}
						}
					};
				} else if (
					protocol == "sse_v1" ||
					protocol == "sse_v2" ||
					protocol == "sse_v2.1" ||
					protocol == "sse_v3"
				) {
					// These protocols share one event stream per client: the job is
					// enqueued with a POST, and messages for its event_id are
					// dispatched to the callback registered below.
					fire_event({
						type: "status",
						stage: "pending",
						queue: true,
						endpoint: _endpoint,
						fn_index,
						time: new Date()
					});
					let hostname = "";
					if (typeof window !== "undefined") {
						hostname = window?.location?.hostname;
					}

					const hfhubdev = "dev.spaces.huggingface.tech";
					const origin = hostname.includes(".dev.")
						? `https://moon-${hostname.split(".")[1]}.${hfhubdev}`
						: `https://huggingface.co`;
					// Guard `window` like the hostname lookup above: outside a
					// browser, reading window.parent would throw a ReferenceError.
					const zerogpu_auth_promise =
						dependency.zerogpu &&
						typeof window !== "undefined" &&
						window.parent != window &&
						config.space_id
							? post_message<Headers>("zerogpu-headers", origin)
							: Promise.resolve(null);
					const post_data_promise = zerogpu_auth_promise.then((headers) => {
						return post_data(
							`${config.root}/queue/join?${url_params}`,
							{
								...payload,
								session_hash
							},
							headers
						);
					});
					// Awaited so that failures (including open_stream) reach the
					// .catch at the end of the chain instead of going unhandled.
					await post_data_promise.then(async ([response, status]: any) => {
						if (status === 503) {
							fire_event({
								type: "status",
								stage: "error",
								message: QUEUE_FULL_MSG,
								queue: true,
								endpoint: _endpoint,
								fn_index,
								time: new Date()
							});
						} else if (status !== 200) {
							fire_event({
								type: "status",
								stage: "error",
								message: BROKEN_CONNECTION_MSG,
								queue: true,
								endpoint: _endpoint,
								fn_index,
								time: new Date()
							});
						} else {
							const job_event_id = response.event_id as string;
							event_id = job_event_id;
							const callback = async function (_data: object): Promise<void> {
								try {
									const { type, status, data } = handle_message(
										_data,
										last_status[fn_index]
									);

									if (type == "heartbeat") {
										return;
									}

									if (type === "update" && status && !complete) {
										fire_event({
											type: "status",
											endpoint: _endpoint,
											fn_index,
											time: new Date(),
											...status
										});
									} else if (type === "complete") {
										complete = status;
									} else if (type == "unexpected_error") {
										console.error("Unexpected error", status?.message);
										fire_event({
											type: "status",
											stage: "error",
											message:
												status?.message || "An Unexpected Error Occurred!",
											queue: true,
											endpoint: _endpoint,
											fn_index,
											time: new Date()
										});
									} else if (type === "log") {
										fire_event({
											type: "log",
											log: data.log,
											level: data.level,
											endpoint: _endpoint,
											fn_index
										});
										return;
									} else if (type === "generating") {
										fire_event({
											type: "status",
											time: new Date(),
											...status,
											stage: status?.stage!,
											queue: true,
											endpoint: _endpoint,
											fn_index
										});
										// Newer protocols stream diffs that must be
										// applied to the accumulated output.
										if (
											data &&
											["sse_v2", "sse_v2.1", "sse_v3"].includes(protocol)
										) {
											apply_diff_stream(pending_diff_streams, job_event_id, data);
										}
									}
									if (data) {
										fire_event({
											type: "data",
											time: new Date(),
											data: data.data,
											endpoint: _endpoint,
											fn_index,
											event_data,
											trigger_id
										});
										if (data.render_config) {
											fire_event({
												type: "render",
												data: data.render_config,
												endpoint: _endpoint,
												fn_index
											});
										}

										if (complete) {
											fire_event({
												type: "status",
												time: new Date(),
												...complete,
												stage: status?.stage!,
												queue: true,
												endpoint: _endpoint,
												fn_index
											});
										}
									}

									// Terminal stage: stop routing messages to this job.
									if (
										status?.stage === "complete" ||
										status?.stage === "error"
									) {
										if (event_callbacks[job_event_id]) {
											delete event_callbacks[job_event_id];
										}
										if (job_event_id in pending_diff_streams) {
											delete pending_diff_streams[job_event_id];
										}
									}
								} catch (e) {
									console.error("Unexpected client exception", e);
									fire_event({
										type: "status",
										stage: "error",
										message: "An Unexpected Error Occurred!",
										queue: true,
										endpoint: _endpoint,
										fn_index,
										time: new Date()
									});
									if (["sse_v2", "sse_v2.1"].includes(protocol)) {
										close_stream(stream_status, stream);
										stream_status.open = false;
									}
								}
							};

							// Replay messages that arrived on the shared stream before
							// this job's callback was registered.
							if (job_event_id in pending_stream_messages) {
								pending_stream_messages[job_event_id].forEach((msg) =>
									callback(msg)
								);
								delete pending_stream_messages[job_event_id];
							}

							event_callbacks[job_event_id] = callback;
							unclosed_events.add(job_event_id);
							if (!stream_status.open) {
								await this.open_stream();
							}
						}
					});
				}
			})
			.catch((e: unknown) => {
				// Without this, setup failures became unhandled rejections and
				// listeners never learned the job had failed.
				console.error("Submit function encountered an error:", e);
				fire_event({
					type: "status",
					stage: "error",
					message: e instanceof Error ? e.message : String(e),
					queue: false,
					endpoint: _endpoint,
					fn_index,
					time: new Date()
				});
			});

		return { on, off, cancel, destroy };
	} catch (error) {
		console.error("Submit function encountered an error:", error);
		throw error;
	}
}
|
|
|
/**
 * Resolve an endpoint reference to its fn_index, API metadata and dependency.
 *
 * @param api_info - API description with `named_endpoints` and `unnamed_endpoints`.
 * @param endpoint - Numeric fn_index, or an endpoint name with or without a
 *   leading `/` (e.g. `"/predict"` or `"predict"`).
 * @param api_map - Map from endpoint name (without the leading `/`) to fn_index.
 * @param config - App config whose `dependencies` are indexed by fn_index.
 * @returns The resolved fn_index, endpoint info and dependency.
 * @throws Error when a named endpoint is not present in `api_map`.
 */
function get_endpoint_info(
	api_info: ApiInfo<JsApiData>,
	endpoint: string | number,
	api_map: Record<string, number>,
	config: Config
): {
	fn_index: number;
	endpoint_info: EndpointInfo<JsApiData>;
	dependency: Dependency;
} {
	let fn_index: number;
	let endpoint_info: EndpointInfo<JsApiData>;
	let dependency: Dependency;

	if (typeof endpoint === "number") {
		fn_index = endpoint;
		endpoint_info = api_info.unnamed_endpoints[fn_index];
		dependency = config.dependencies[endpoint];
	} else {
		const trimmed_endpoint = endpoint.replace(/^\//, "");

		fn_index = api_map[trimmed_endpoint];
		// api_map accepts names without a leading slash, but named_endpoints
		// is keyed with one. Try the name as given first, then fall back to
		// the slash-prefixed form so "predict" and "/predict" resolve alike.
		endpoint_info =
			api_info.named_endpoints[endpoint.trim()] ??
			api_info.named_endpoints[`/${trimmed_endpoint}`];
		dependency = config.dependencies[fn_index];
	}

	if (typeof fn_index !== "number") {
		throw new Error(
			"There is no endpoint matching that name of fn_index matching that number."
		);
	}

	return { fn_index, endpoint_info, dependency };
}
|
|