// ANIMA-Web-LLM / src/index.ts
// Committed by Severian: "Update src/index.ts" (6f06af9)
import {
ChatInterface,
ChatModule,
ChatRestModule,
ChatWorkerClient,
} from "@mlc-ai/web-llm";
/**
 * Fetch a DOM element by its id, refusing to continue if it is missing.
 *
 * @param id The `id` attribute to look up via `document.getElementById`.
 * @returns The matching element.
 * @throws Error with the message "Cannot find element <id>" when the lookup
 *   yields `null` or `undefined`.
 */
function getElementAndCheck(id: string): HTMLElement {
  const found = document.getElementById(id);
  if (found != null) {
    return found;
  }
  throw Error("Cannot find element " + id);
}
// Identifier of the single bundled model; it is used both as the list entry's
// `local_id` and as the lookup key in `model_lib_map`.
const ANIMA_MODEL_ID = "ANIMA-Phi-Neptune-Mistral-7B-q4f32_1";

// Root of the Hugging Face repository that hosts the model artifacts.
const ANIMA_REPO_BASE_URL = `https://huggingface.co/hrishioa/wasm-${ANIMA_MODEL_ID}/resolve/main/`;

/**
 * Application configuration: the selectable models, the WebGPU wasm library
 * for each model id, and whether generation runs in a web worker.
 */
const appConfig = {
  model_list: [
    {
      model_url: ANIMA_REPO_BASE_URL + "params/",
      local_id: ANIMA_MODEL_ID,
    },
  ],
  model_lib_map: {
    [ANIMA_MODEL_ID]: `${ANIMA_REPO_BASE_URL}${ANIMA_MODEL_ID}-webgpu.wasm`,
  },
  use_web_worker: true,
};
/**
 * Controller that wires the chat page's DOM elements to two chat backends:
 * `chat` (the in-browser model, loaded with `reload`) and `localChat` (used
 * when the "Local Server" option is selected).
 *
 * Every user action is turned into a task and appended to a promise chain so
 * that backend calls never overlap.
 */
class ChatUI {
  /** Value (and label) of the selector option that routes to `localChat`. */
  private static readonly LOCAL_SERVER_OPTION = "Local Server";

  private readonly uiChat: HTMLElement;
  private readonly uiChatInput: HTMLInputElement;
  private readonly uiChatInfoLabel: HTMLLabelElement;
  private readonly chat: ChatInterface;
  private readonly localChat: ChatInterface;
  private readonly config = appConfig;
  private selectedModel: string;
  private chatLoaded = false;
  private requestInProgress = false;
  // We use a request chain to ensure that
  // all requests sent to chat are sequentialized.
  private chatRequestChain: Promise<void> = Promise.resolve();

  /**
   * Look up the page elements, register the event handlers and populate the
   * model selector.
   *
   * @param chat Backend used for every entry of `appConfig.model_list`.
   * @param localChat Backend used when "Local Server" is selected.
   * @throws Error when one of the required elements is missing from the page
   *   (raised by `getElementAndCheck`).
   */
  constructor(chat: ChatInterface, localChat: ChatInterface) {
    this.chat = chat;
    this.localChat = localChat;

    // get the elements
    this.uiChat = getElementAndCheck("chatui-chat");
    this.uiChatInput = getElementAndCheck("chatui-input") as HTMLInputElement;
    this.uiChatInfoLabel = getElementAndCheck(
      "chatui-info-label"
    ) as HTMLLabelElement;

    // register event handlers
    getElementAndCheck("chatui-reset-btn").onclick = () => {
      this.onReset();
    };
    getElementAndCheck("chatui-send-btn").onclick = () => {
      this.onGenerate();
    };
    // TODO: find other alternative triggers
    this.uiChatInput.onkeypress = (event) => {
      // `event.key` replaces the deprecated `keyCode === 13` check.
      if (event.key === "Enter") {
        this.onGenerate();
      }
    };

    const modelSelector = getElementAndCheck(
      "chatui-select"
    ) as HTMLSelectElement;
    this.config.model_list.forEach((item, index) => {
      const opt = document.createElement("option");
      opt.value = item.local_id;
      // textContent: the id is a label, never markup.
      opt.textContent = item.local_id;
      opt.selected = index === 0;
      modelSelector.appendChild(opt);
    });

    // Append local server option to the model selector
    const localServerOpt = document.createElement("option");
    localServerOpt.value = ChatUI.LOCAL_SERVER_OPTION;
    localServerOpt.textContent = ChatUI.LOCAL_SERVER_OPTION;
    modelSelector.append(localServerOpt);

    this.selectedModel = modelSelector.value;
    modelSelector.onchange = () => {
      this.onSelectChange(modelSelector);
    };
  }

  /**
   * Push a task to the execution queue.
   *
   * A task that rejects is logged and reported in the chat window; the
   * rejection is then swallowed so the chain stays usable. Without this, a
   * single failure would leave the chain rejected and every later task would
   * be skipped.
   *
   * @param task The task to be executed after all previously queued tasks.
   */
  private pushTask(task: () => Promise<void>): void {
    this.chatRequestChain = this.chatRequestChain
      .then(task)
      .catch((err: unknown) => {
        console.error(err);
        this.appendMessage(
          "error",
          "Request error, " + ChatUI.describeError(err)
        );
      });
  }

  /** Render a caught value as text without assuming it is an `Error`. */
  private static describeError(err: unknown): string {
    return err instanceof Error ? err.toString() : String(err);
  }

  // Event handlers
  // Every event handler pushes its work onto the task queue, so tasks are
  // executed sequentially. A handler that must pre-empt a running generation
  // first calls chat.interruptGenerate(), which makes the running task stop
  // early so the newly queued one can start.

  /** Queue a generation for the current input unless a request is running. */
  private onGenerate(): void {
    if (this.requestInProgress) {
      return;
    }
    this.pushTask(() => this.asyncGenerate());
  }

  /**
   * Switch to the model currently chosen in `modelSelector`: interrupt any
   * running request, then queue reset + unload + init for the new model.
   */
  private onSelectChange(modelSelector: HTMLSelectElement): void {
    if (this.requestInProgress) {
      // interrupt previous generation if any
      this.chat.interruptGenerate();
    }
    // try reset after previous requests finish
    this.pushTask(async () => {
      await this.chat.resetChat();
      this.resetChatHistory();
      await this.unloadChat();
      this.selectedModel = modelSelector.value;
      await this.asyncInitChat();
    });
  }

  /** Interrupt any running request, then queue a conversation reset. */
  private onReset(): void {
    if (this.requestInProgress) {
      // interrupt previous generation if any
      this.chat.interruptGenerate();
    }
    // try reset after previous requests finish
    this.pushTask(async () => {
      await this.chat.resetChat();
      this.resetChatHistory();
    });
  }

  // Internal helper functions

  /**
   * Append a message bubble to the chat window and scroll to the bottom.
   *
   * The bubble is built from DOM nodes and the text is assigned through
   * `textContent`, so prompts or error strings containing `<`, `>` or `&`
   * are displayed literally instead of being parsed as HTML.
   *
   * @param kind Message kind; becomes the `${kind}-msg` CSS class. The kind
   *   "init" additionally prefixes the text with "[System Initalize] ".
   * @param text Plain text to display.
   */
  private appendMessage(kind: string, text: string): void {
    const body = kind === "init" ? "[System Initalize] " + text : text;

    const msgText = document.createElement("div");
    msgText.className = "msg-text";
    msgText.textContent = body;

    const bubble = document.createElement("div");
    bubble.className = "msg-bubble";
    bubble.append(msgText);

    const msg = document.createElement("div");
    msg.className = `msg ${kind}-msg`;
    msg.append(bubble);

    this.uiChat.append(msg);
    this.uiChat.scrollTo(0, this.uiChat.scrollHeight);
  }

  /**
   * Replace the text of the most recent message of the given kind, rendering
   * each "\n"-separated line as its own `<div>`.
   *
   * @param kind Message kind whose last bubble is updated ("init" text gets
   *   the same prefix as in `appendMessage`).
   * @param text New plain text for the bubble.
   * @throws Error when no message of that kind exists, or when the message
   *   does not contain exactly one "msg-text" element.
   */
  private updateLastMessage(kind: string, text: string): void {
    const body = kind === "init" ? "[System Initalize] " + text : text;

    const matches = this.uiChat.getElementsByClassName(`msg ${kind}-msg`);
    const lastMsg =
      matches.length > 0 ? matches.item(matches.length - 1) : null;
    if (lastMsg === null) throw Error(`${kind} message do not exist`);

    const msgTexts = lastMsg.getElementsByClassName("msg-text");
    const target = msgTexts.item(0);
    if (msgTexts.length !== 1 || target === null) {
      throw Error("Expect msg-text");
    }
    // Nothing to do when the bubble already shows exactly this text.
    if (target.innerHTML === body) return;

    const lines = body.split("\n").map((line) => {
      const item = document.createElement("div");
      item.textContent = line;
      return item;
    });
    target.textContent = "";
    target.append(...lines);
    this.uiChat.scrollTo(0, this.uiChat.scrollHeight);
  }

  /** Remove every message bubble from the chat window and clear the label. */
  private resetChatHistory(): void {
    const clearTags = ["left", "right", "init", "error"];
    for (const tag of clearTags) {
      // The collection is live, so copy it before removing its members.
      const matches = [...this.uiChat.getElementsByClassName(`msg ${tag}-msg`)];
      for (const item of matches) {
        item.remove();
      }
    }
    this.uiChatInfoLabel.textContent = "";
  }

  /**
   * Make sure the selected backend is ready; a no-op once `chatLoaded` is
   * set. For a model other than "Local Server" this calls `chat.reload` and
   * streams the progress text into an "init" message.
   *
   * A reload failure is shown as an "error" message, logged, and followed by
   * an awaited `unloadChat()`; `chatLoaded` then stays `false`.
   */
  private async asyncInitChat(): Promise<void> {
    if (this.chatLoaded) return;
    this.requestInProgress = true;
    this.appendMessage("init", "");
    this.chat.setInitProgressCallback((report: { text: string }) => {
      this.updateLastMessage("init", report.text);
    });
    try {
      if (this.selectedModel !== ChatUI.LOCAL_SERVER_OPTION) {
        await this.chat.reload(this.selectedModel, undefined, this.config);
      }
      this.chatLoaded = true;
    } catch (err: unknown) {
      this.appendMessage("error", "Init error, " + ChatUI.describeError(err));
      console.error(err);
      // Awaited so that a failing unload is not left as a floating promise.
      await this.unloadChat();
    } finally {
      this.requestInProgress = false;
    }
  }

  /** Unload the in-browser model and mark the chat as not loaded. */
  private async unloadChat(): Promise<void> {
    await this.chat.unload();
    this.chatLoaded = false;
  }

  /**
   * Run generate: initialise the backend if needed, send the current input
   * as the prompt, stream the reply into a "left" message and finally show
   * the backend's runtime statistics in the info label.
   *
   * Returns without sending anything when initialisation failed (the error
   * is already on screen and the prompt is kept in the input box) or when
   * the input is empty.
   */
  private async asyncGenerate(): Promise<void> {
    await this.asyncInitChat();
    if (!this.chatLoaded) {
      return;
    }
    const prompt = this.uiChatInput.value;
    if (prompt === "") {
      return;
    }

    this.requestInProgress = true;
    this.appendMessage("right", prompt);
    this.uiChatInput.value = "";
    this.uiChatInput.setAttribute("placeholder", "Generating...");
    this.appendMessage("left", "");

    const callbackUpdateResponse = (_step: number, msg: string): void => {
      // An empty partial message interrupts the generation.
      if (msg.length === 0) {
        this.chat.interruptGenerate();
        return;
      }
      this.updateLastMessage("left", msg);
    };

    try {
      const activeChat =
        this.selectedModel === ChatUI.LOCAL_SERVER_OPTION
          ? this.localChat
          : this.chat;
      await activeChat.generate(prompt, callbackUpdateResponse);
      // textContent: the statistics are plain text, not markup.
      this.uiChatInfoLabel.textContent = await activeChat.runtimeStatsText();
    } catch (err: unknown) {
      this.appendMessage(
        "error",
        "Generate error, " + ChatUI.describeError(err)
      );
      console.error(err);
      await this.unloadChat();
    } finally {
      // Restore the idle state even if unloading failed above.
      this.uiChatInput.setAttribute("placeholder", "Enter your message...");
      this.requestInProgress = false;
    }
  }
}
// Application entry point: build the two chat backends and hand them to the UI.
const useWebWorker = appConfig.use_web_worker;

// In-browser backend: a worker-backed client when the configuration asks for
// one, otherwise a ChatModule created directly on this thread.
const chat: ChatInterface = useWebWorker
  ? new ChatWorkerClient(
      new Worker(new URL("./worker.ts", import.meta.url), { type: "module" })
    )
  : new ChatModule();

// The REST backend is created the same way in both configurations.
const localChat: ChatInterface = new ChatRestModule();

new ChatUI(chat, localChat);