|
import { create } from "zustand"; |
|
import { persist } from "zustand/middleware"; |
|
|
|
import { type ChatCompletionResponseMessage } from "openai"; |
|
import { |
|
ControllerPool, |
|
requestChatStream, |
|
requestWithPrompt, |
|
} from "../requests"; |
|
import { isMobileScreen, trimTopic } from "../utils"; |
|
|
|
import Locale from "../locales"; |
|
import { showToast } from "../components/ui-lib"; |
|
import { DEFAULT_CONFIG, ModelConfig, ModelType, useAppConfig } from "./config"; |
|
import { createEmptyMask, Mask } from "./mask"; |
|
import { StoreKey } from "../constant"; |
|
|
|
/**
 * A chat message as kept in a session: the OpenAI response-message shape
 * extended with client-side bookkeeping fields.
 */
export type Message = ChatCompletionResponseMessage & {
  /** Numeric identifier; `createMessage` defaults it to `Date.now()`. */
  id?: number;

  /** Display timestamp; `createMessage` fills it with `new Date().toLocaleString()`. */
  date: string;

  /** Model associated with the message; `onUserInput` sets it on the assistant reply. */
  model?: ModelType;

  /** `true` on the assistant placeholder while its reply is still streaming in. */
  streaming?: boolean;

  /** Set when the request failed; flagged messages are skipped by `getMessagesWithMemory`. */
  isError?: boolean;
};
|
|
|
/**
 * Builds a `Message`, filling in every field the caller does not supply.
 *
 * Defaults: `id` is the current epoch milliseconds, `date` is the current
 * locale-formatted time, `role` is `"user"` and `content` is the empty
 * string. Any key present in `override` takes precedence over the default.
 *
 * @param override - Fields that should replace the defaults.
 * @returns A new message object.
 */
export function createMessage(override: Partial<Message>): Message {
  const defaults: Message = {
    id: Date.now(),
    date: new Date().toLocaleString(),
    role: "user",
    content: "",
  };

  return { ...defaults, ...override };
}
|
|
|
/** The message roles this module exposes, in the order system, user, assistant. */
export const ROLES: Array<Message["role"]> = ["system", "user", "assistant"];
|
|
|
/**
 * Running counters attached to a session.
 *
 * All three start at 0 in `createEmptySession`; within this file only
 * `charCount` is ever incremented (by `updateStat`).
 */
export interface ChatStat {
  /** Token counter; initialised to 0 and not updated in this file. */
  tokenCount: number;
  /** Word counter; initialised to 0 and not updated in this file. */
  wordCount: number;
  /** Total characters of the messages passed to `updateStat`. */
  charCount: number;
}
|
|
|
/** One conversation: its messages plus the metadata needed to continue it. */
export interface ChatSession {
  /**
   * Session identifier. `createEmptySession` seeds it with
   * `Date.now() + Math.random()`; `newSession` overwrites it with the
   * store's `globalId` counter.
   */
  id: number;

  /** Title of the conversation; starts as `DEFAULT_TOPIC`. */
  topic: string;

  /** Summary text written by `summarizeSession`; empty when there is none. */
  memoryPrompt: string;
  /** Messages of the conversation, oldest first. */
  messages: Message[];
  /** Usage counters for this session. */
  stat: ChatStat;
  /** Epoch milliseconds of the last `onNewMessage` call (or of creation). */
  lastUpdate: number;
  /** Length of `messages` at the time the last memory summary completed. */
  lastSummarizeIndex: number;

  /** Mask supplying this session's `modelConfig` and `context` messages. */
  mask: Mask;
}
|
|
|
/**
 * Topic assigned to freshly created sessions, taken from the active locale.
 * `summarizeSession` only auto-generates a topic while a session still has it.
 */
export const DEFAULT_TOPIC = Locale.Store.DefaultTopic;

/** Assistant message carrying the locale's `BotHello` text. */
export const BOT_HELLO: Message = createMessage({
  content: Locale.Store.BotHello,
  role: "assistant",
});
|
|
|
/**
 * Creates a blank session: default topic, no messages, no memory prompt,
 * zeroed counters and an empty mask.
 *
 * The id combines the current time with a random fraction; `lastUpdate` is
 * the current time.
 */
function createEmptySession(): ChatSession {
  const zeroedStat: ChatStat = {
    tokenCount: 0,
    wordCount: 0,
    charCount: 0,
  };

  const blank: ChatSession = {
    id: Date.now() + Math.random(),
    topic: DEFAULT_TOPIC,
    memoryPrompt: "",
    messages: [],
    stat: zeroedStat,
    lastUpdate: Date.now(),
    lastSummarizeIndex: 0,
    mask: createEmptyMask(),
  };

  return blank;
}
|
|
|
/** State and actions exposed by `useChatStore`. */
interface ChatStore {
  // ---- state ----

  /** All sessions; `newSession` inserts new ones at the front. */
  sessions: ChatSession[];
  /** Index into `sessions` of the session currently selected. */
  currentSessionIndex: number;
  /** Counter incremented by `newSession` and used as the new session's id. */
  globalId: number;

  // ---- session management ----

  /** Replaces every session with a single empty one and selects it. */
  clearSessions: () => void;
  /** Moves the session at `from` to position `to`, keeping the same session selected. */
  moveSession: (from: number, to: number) => void;
  /** Selects the session at `index`. */
  selectSession: (index: number) => void;
  /** Prepends a new session (optionally based on `mask`) and selects it. */
  newSession: (mask?: Mask) => void;
  /** Removes the session at `index` and shows a toast that can undo the removal. */
  deleteSession: (index: number) => void;
  /** Returns the selected session, clamping an out-of-range index first. */
  currentSession: () => ChatSession;

  // ---- messaging ----

  /** Post-processing for a completed message: timestamp, stats, summarisation. */
  onNewMessage: (message: Message) => void;
  /** Appends the user's input plus an assistant placeholder and streams the reply. */
  onUserInput: (content: string) => Promise<void>;
  /** Generates a topic and/or a memory summary for the current session when due. */
  summarizeSession: () => void;
  /** Adds `message`'s content length to the current session's `charCount`. */
  updateStat: (message: Message) => void;
  /** Applies `updater` to the current session in place, then publishes `sessions`. */
  updateCurrentSession: (updater: (session: ChatSession) => void) => void;
  /** Applies `updater` to one message (or `undefined` if not found), then publishes `sessions`. */
  updateMessage: (
    sessionIndex: number,
    messageIndex: number,
    updater: (message?: Message) => void,
  ) => void;
  /** Clears the current session's messages and memory prompt. */
  resetSession: () => void;
  /** Builds the message list to send: mask context, optional memory prompt, recent history. */
  getMessagesWithMemory: () => Message[];
  /** Returns the current session's memory prompt wrapped as a system message. */
  getMemoryPrompt: () => Message;

  // ---- maintenance ----

  /** Clears `localStorage` and reloads the page. */
  clearAllData: () => void;
}
|
|
|
/** Returns the combined `content` length, in characters, of all given messages. */
function countMessages(msgs: Message[]) {
  let totalLength = 0;
  for (const { content } of msgs) {
    totalLength += content.length;
  }
  return totalLength;
}
|
|
|
export const useChatStore = create<ChatStore>()( |
|
persist( |
|
(set, get) => ({ |
|
sessions: [createEmptySession()], |
|
currentSessionIndex: 0, |
|
globalId: 0, |
|
|
|
/** Discards every session, leaving a single empty session selected at index 0. */
clearSessions() {
  set(() => {
    const freshSessions = [createEmptySession()];
    return { sessions: freshSessions, currentSessionIndex: 0 };
  });
},
|
|
|
/**
 * Makes the session at `index` the current one.
 *
 * The index is stored as given; `currentSession` clamps out-of-range values
 * when the session is read.
 */
selectSession(index: number) {
  set(() => ({ currentSessionIndex: index }));
},
|
|
|
/**
 * Moves the session at position `from` to position `to` and adjusts
 * `currentSessionIndex` so that the same session stays selected.
 */
moveSession(from: number, to: number) {
  set((state) => {
    const selected = state.currentSessionIndex;

    // Reorder on a copy so the previous array is left untouched.
    const reordered = state.sessions.slice();
    const moving = reordered[from];
    reordered.splice(from, 1);
    reordered.splice(to, 0, moving);

    // Work out where the selected session ended up.
    let adjusted = selected;
    if (selected === from) {
      // The selected session is the one that moved.
      adjusted = to;
    } else if (selected > from && selected <= to) {
      // Moved forward past the selection: selection shifts left.
      adjusted = selected - 1;
    } else if (selected < from && selected >= to) {
      // Moved backward past the selection: selection shifts right.
      adjusted = selected + 1;
    }

    return {
      currentSessionIndex: adjusted,
      sessions: reordered,
    };
  });
},
|
|
|
/**
 * Creates a session, puts it at the front of the list and selects it.
 *
 * The session id comes from the `globalId` counter, which is incremented
 * first. When `mask` is given, the session receives a shallow copy of it and
 * takes the mask's name as its topic.
 */
newSession(mask) {
  const created = createEmptySession();

  // Bump the counter, then read it back as this session's id.
  set((state) => ({ globalId: state.globalId + 1 }));
  created.id = get().globalId;

  if (mask) {
    created.mask = { ...mask };
    created.topic = mask.name;
  }

  set((state) => ({
    currentSessionIndex: 0,
    sessions: [created, ...state.sessions],
  }));
},
|
|
|
/**
 * Removes the session at `index` and shows a toast offering to undo it.
 *
 * Does nothing when `index` does not address a session. If the only
 * remaining session is deleted, a fresh empty one takes its place. The
 * selection is kept on the same session where possible: deleting a session
 * positioned before the current one shifts the current index left by one.
 * Clicking the toast action restores the session list and selection as they
 * were before the deletion.
 *
 * @param index - Position of the session to delete; negative values count
 *   from the end, as with `Array.prototype.at`.
 */
deleteSession(index) {
  const { sessions: previousSessions, currentSessionIndex: previousIndex } =
    get();

  if (!previousSessions.at(index)) return;

  const deletingLastSession = previousSessions.length === 1;

  // `at` and `splice` both accept negative indices; normalise so the
  // position can be compared with the current index below.
  const removedAt = index < 0 ? previousSessions.length + index : index;

  const sessions = previousSessions.slice();
  sessions.splice(removedAt, 1);

  // Every session after the removed one moves left by one slot, so the
  // current index must follow when the removed session preceded it;
  // otherwise the selection would silently jump to the next session.
  let nextIndex = Math.min(
    previousIndex - Number(removedAt < previousIndex),
    sessions.length - 1,
  );

  if (deletingLastSession) {
    nextIndex = 0;
    sessions.push(createEmptySession());
  }

  // Snapshot taken before the change, used by the toast's undo action.
  const restoreState = {
    currentSessionIndex: previousIndex,
    sessions: previousSessions.slice(),
  };

  set(() => ({
    currentSessionIndex: nextIndex,
    sessions,
  }));

  showToast(
    Locale.Home.DeleteToast,
    {
      text: Locale.Home.Revert,
      onClick() {
        set(() => restoreState);
      },
    },
    5000,
  );
},
|
|
|
/**
 * Returns the currently selected session.
 *
 * If the stored index lies outside `sessions`, it is clamped into range and
 * the corrected value is written back to the store before the lookup.
 */
currentSession() {
  const { sessions, currentSessionIndex } = get();
  let index = currentSessionIndex;

  const outOfRange = index < 0 || index >= sessions.length;
  if (outOfRange) {
    const notNegative = Math.max(0, index);
    index = Math.min(sessions.length - 1, notNegative);
    set(() => ({ currentSessionIndex: index }));
  }

  return sessions[index];
},
|
|
|
/**
 * Bookkeeping run once a message is complete: stamps the current session's
 * `lastUpdate`, adds the message to the stats, and triggers summarisation.
 */
onNewMessage(message) {
  const touch = (session: ChatSession) => {
    session.lastUpdate = Date.now();
  };

  get().updateCurrentSession(touch);
  get().updateStat(message);
  get().summarizeSession();
},
|
|
|
/**
 * Handles text submitted by the user.
 *
 * Appends the user's message and a streaming assistant placeholder to the
 * current session, then starts a streamed chat request whose payload is the
 * output of `getMessagesWithMemory` followed by the new user message. The
 * placeholder object is mutated in place as chunks arrive; an empty `set` is
 * issued after each mutation to notify subscribers.
 *
 * @param content - The text entered by the user.
 */
async onUserInput(content) {
  const modelConfig = get().currentSession().mask.modelConfig;

  const userMessage: Message = createMessage({ role: "user", content });
  const botMessage: Message = createMessage({
    role: "assistant",
    streaming: true,
    id: userMessage.id! + 1,
    model: modelConfig.model,
  });

  // Build the request payload before the new messages enter the session.
  const sendMessages = get().getMessagesWithMemory().concat(userMessage);
  const sessionIndex = get().currentSessionIndex;
  // Slot the bot message will occupy once both messages are appended.
  const messageIndex = get().currentSession().messages.length + 1;

  get().updateCurrentSession((session) => {
    session.messages.push(userMessage, botMessage);
  });

  // Key under which this request's controller is registered in the pool.
  const controllerKey = () => botMessage.id ?? messageIndex;
  // Empty update: only signals subscribers that the message was mutated.
  const notify = () => set(() => ({}));

  console.log("[User Input] ", sendMessages);
  requestChatStream(sendMessages, {
    onMessage(content, done) {
      if (!done) {
        botMessage.content = content;
        notify();
        return;
      }

      botMessage.streaming = false;
      botMessage.content = content;
      get().onNewMessage(botMessage);
      ControllerPool.remove(sessionIndex, controllerKey());
    },
    onError(error, statusCode) {
      const isAborted = error.message.includes("aborted");
      const failed = !isAborted;

      if (statusCode === 401) {
        botMessage.content = Locale.Error.Unauthorized;
      } else if (failed) {
        botMessage.content += "\n\n" + Locale.Store.Error;
      }

      botMessage.streaming = false;
      // A user-initiated abort is not flagged as an error.
      userMessage.isError = failed;
      botMessage.isError = failed;

      notify();
      ControllerPool.remove(sessionIndex, controllerKey());
    },
    onController(controller) {
      ControllerPool.addController(sessionIndex, controllerKey(), controller);
    },
    modelConfig: { ...modelConfig },
  });
},
|
|
|
/**
 * Wraps the current session's memory prompt in a system message.
 *
 * The content is the locale's history template applied to the memory prompt,
 * or an empty string when the session has no memory prompt.
 */
getMemoryPrompt() {
  const { memoryPrompt } = get().currentSession();

  let content = "";
  if (memoryPrompt.length > 0) {
    content = Locale.Store.Prompt.History(memoryPrompt);
  }

  return { role: "system", content, date: "" } as Message;
},
|
|
|
/**
 * Assembles the messages that accompany the next request.
 *
 * The result is, in order: a copy of the mask's context messages, the memory
 * prompt (only when `sendMemory` is on and a memory prompt exists), and the
 * most recent non-error messages in chronological order. History is walked
 * from newest to oldest and stops at the later of the `historyMessageCount`
 * window start and `lastSummarizeIndex`, or as soon as the collected content
 * length reaches `compressMessageLengthThreshold`.
 */
getMessagesWithMemory() {
  const session = get().currentSession();
  const { sendMemory, historyMessageCount, compressMessageLengthThreshold } =
    session.mask.modelConfig;

  const usable = session.messages.filter((msg) => !msg.isError);
  const total = usable.length;

  // Copy so that appending the memory prompt does not alter the mask.
  const context = session.mask.context.slice();

  const hasMemory =
    Boolean(session.memoryPrompt) && session.memoryPrompt.length > 0;
  if (sendMemory && hasMemory) {
    context.push(get().getMemoryPrompt());
  }

  // Lower bound: start of the short-term window or the summary cursor,
  // whichever is later.
  const oldestIndex = Math.max(
    0,
    total - historyMessageCount,
    session.lastSummarizeIndex,
  );

  // Collect newest-first until the bound or the character budget is hit.
  const newestFirst: Message[] = [];
  let collectedChars = 0;
  let cursor = total - 1;
  while (
    cursor >= oldestIndex &&
    collectedChars < compressMessageLengthThreshold
  ) {
    const candidate = usable[cursor];
    cursor -= 1;
    if (!candidate || candidate.isError) continue;
    collectedChars += candidate.content.length;
    newestFirst.push(candidate);
  }

  // Restore chronological order behind the context messages.
  return context.concat(newestFirst.reverse());
},
|
|
|
/**
 * Runs `updater` on a single message and then publishes `sessions`.
 *
 * The updater receives `undefined` when either index does not resolve to an
 * existing session or message.
 */
updateMessage(
  sessionIndex: number,
  messageIndex: number,
  updater: (message?: Message) => void,
) {
  const { sessions } = get();
  const target = sessions.at(sessionIndex)?.messages?.at(messageIndex);

  updater(target);
  set(() => ({ sessions }));
},
|
|
|
/**
 * Empties the current session: removes all messages, clears the memory
 * prompt and rewinds the summary cursor.
 */
resetSession() {
  get().updateCurrentSession((session) => {
    session.messages = [];
    session.memoryPrompt = "";
    // `lastSummarizeIndex` is a position within `messages`. Left at its old
    // value it would point past the now-empty list; `getMessagesWithMemory`
    // uses it as a lower bound and `summarizeSession` slices from it, so new
    // messages would be excluded until the list outgrew the stale index.
    session.lastSummarizeIndex = 0;
  });
},
|
|
|
/**
 * Runs the background summarisation steps for the current session.
 *
 * 1. Topic: while the session still has `DEFAULT_TOPIC` and its messages
 *    total at least `SUMMARIZE_MIN_LEN` characters, requests a title and
 *    stores the trimmed result (or `DEFAULT_TOPIC` if the result is empty).
 * 2. Memory: when the characters accumulated since `lastSummarizeIndex`
 *    exceed `compressMessageLengthThreshold` and `sendMemory` is enabled,
 *    streams a summary into `memoryPrompt`; once the stream is done,
 *    `lastSummarizeIndex` advances to the message count captured before the
 *    request was sent.
 */
summarizeSession() {
  const session = get().currentSession();

  // Only title sessions that are still untitled and have enough text.
  const SUMMARIZE_MIN_LEN = 50;
  if (
    session.topic === DEFAULT_TOPIC &&
    countMessages(session.messages) >= SUMMARIZE_MIN_LEN
  ) {
    requestWithPrompt(session.messages, Locale.Store.Prompt.Topic, {
      model: "gpt-3.5-turbo",
    }).then((res) => {
      // NOTE(review): this writes to whichever session is current when the
      // reply arrives, which may differ from `session` if the user switched.
      get().updateCurrentSession(
        (session) =>
          (session.topic = res ? trimTopic(res) : DEFAULT_TOPIC),
      );
    });
  }

  const modelConfig = session.mask.modelConfig;
  let toBeSummarizedMsgs = session.messages.slice(
    session.lastSummarizeIndex,
  );

  const historyMsgLength = countMessages(toBeSummarizedMsgs);

  // Parentheses are required: `>` binds tighter than `??`, so without them
  // the comparison result is what gets coalesced and the 4000 fallback can
  // never take effect.
  if (historyMsgLength > (modelConfig?.max_tokens ?? 4000)) {
    // Too much text: keep only the most recent `historyMessageCount` messages.
    const n = toBeSummarizedMsgs.length;
    toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
      Math.max(0, n - modelConfig.historyMessageCount),
    );
  }

  // Lead with the existing memory so the new summary builds on it.
  toBeSummarizedMsgs.unshift(get().getMemoryPrompt());

  // Captured now: messages that arrive during the request are not covered
  // by the summary being produced.
  const lastSummarizeIndex = session.messages.length;

  console.log(
    "[Chat History] ",
    toBeSummarizedMsgs,
    historyMsgLength,
    modelConfig.compressMessageLengthThreshold,
  );

  if (
    historyMsgLength > modelConfig.compressMessageLengthThreshold &&
    session.mask.modelConfig.sendMemory
  ) {
    requestChatStream(
      toBeSummarizedMsgs.concat({
        role: "system",
        content: Locale.Store.Prompt.Summarize,
        date: "",
      }),
      {
        overrideModel: "gpt-3.5-turbo",
        onMessage(message, done) {
          // The session object is mutated directly; no `set` is issued here.
          session.memoryPrompt = message;
          if (done) {
            console.log("[Memory] ", session.memoryPrompt);
            session.lastSummarizeIndex = lastSummarizeIndex;
          }
        },
        onError(error) {
          console.error("[Summarize] ", error);
        },
      },
    );
  }
},
|
|
|
/** Adds the length of `message.content` to the current session's `charCount`. */
updateStat(message) {
  const addCharacters = (session: ChatSession) => {
    session.stat.charCount = session.stat.charCount + message.content.length;
  };

  get().updateCurrentSession(addCharacters);
},
|
|
|
/**
 * Applies `updater` to the session at `currentSessionIndex`, mutating it in
 * place, then calls `set` with the same `sessions` array to publish the change.
 */
updateCurrentSession(updater) {
  const { sessions, currentSessionIndex } = get();

  updater(sessions[currentSessionIndex]);
  set(() => ({ sessions }));
},
|
|
|
/**
 * Clears all of `localStorage` (not only this store's entry) and then
 * reloads the page.
 */
clearAllData() {
  localStorage.clear();
  location.reload();
},
|
}), |
|
{ |
|
name: StoreKey.Chat, |
|
version: 2, |
|
/**
 * Upgrades persisted state to the current store version.
 *
 * State is deep-copied through JSON first. For versions below 2, `globalId`
 * is reset to 0 and every stored session is rebuilt from an empty session
 * that keeps only the old topic and messages, with `sendMemory` switched on
 * and fixed history/compression settings.
 */
migrate(persistedState, version) {
  const state = persistedState as any;
  const newState = JSON.parse(JSON.stringify(state)) as ChatStore;

  // Nothing to migrate unless the stored version is older than 2.
  if (!(version < 2)) {
    return newState;
  }

  newState.globalId = 0;
  newState.sessions = [];

  for (const legacy of state.sessions) {
    const rebuilt = createEmptySession();
    rebuilt.topic = legacy.topic;
    rebuilt.messages = [...legacy.messages];

    const config = rebuilt.mask.modelConfig;
    config.sendMemory = true;
    config.historyMessageCount = 4;
    config.compressMessageLengthThreshold = 1000;

    newState.sessions.push(rebuilt);
  }

  return newState;
},
|
}, |
|
), |
|
); |
|
|