Spaces:
Sleeping
Sleeping
import { DEFAULT_MODELS, ServiceProvider } from "../constant"; | |
import { LLMModel } from "../client/api"; | |
/**
 * Hands out stable ordering numbers for custom models/providers.
 *
 * The counter starts at -1000 so custom entries are placed in front
 * (refer to constant.ts). Each id is numbered once; asking again for the
 * same id returns the number it was first given.
 */
const CustomSeq = {
  val: -1000,
  cache: new Map<string, number>(),
  next: (id: string) => {
    const known = CustomSeq.cache.get(id);
    if (known !== undefined) {
      return known;
    }
    const assigned = CustomSeq.val++;
    CustomSeq.cache.set(id, assigned);
    return assigned;
  },
};
/**
 * Build a provider descriptor for a provider named only in the custom-models
 * config. Its id is the lower-cased name and its ordering number comes from
 * CustomSeq, keyed by the name as written.
 */
const customProvider = (providerName: string) => {
  const id = providerName.toLowerCase();
  return {
    id,
    providerName,
    providerType: "custom",
    sorted: CustomSeq.next(providerName),
  };
};
/**
 * Order model entries by provider, then by model.
 *
 * When both entries carry a provider, their providers' `sorted` values are
 * compared first, with ties broken by each entry's own `sorted`. If either
 * entry lacks a provider, only the entries' own `sorted` values are compared.
 * Sorts the given array in place and returns that same array.
 */
const sortModelTable = (models: ReturnType<typeof collectModels>) =>
  models.sort((left, right) => {
    if (!left.provider || !right.provider) {
      return left.sorted - right.sorted;
    }
    const providerOrder = left.provider.sorted - right.provider.sorted;
    return providerOrder !== 0 ? providerOrder : left.sorted - right.sorted;
  });
/**
 * Split a `<model>@<provider>` string into its model and provider parts.
 *
 * The split happens at the LAST `@`, so model names may themselves contain
 * `@`, e.g. `gpt-4@OpenAi` -> `["gpt-4", "OpenAi"]` and
 * `claude-3-5-sonnet@20240620@Google` -> `["claude-3-5-sonnet@20240620", "Google"]`.
 *
 * @param modelWithProvider - Model name, optionally followed by `@<provider>`.
 * @returns `[model, provider]`; `provider` is `undefined` when there is no `@`.
 */
export function getModelProvider(modelWithProvider: string): [string, string?] {
  // lastIndexOf is exact for every input; the previous lookahead regex
  // (/@(?!.*@)/) also split at earlier `@`s whenever a line terminator
  // followed them, because `.` does not match line terminators.
  const separatorIndex = modelWithProvider.lastIndexOf("@");
  if (separatorIndex === -1) {
    return [modelWithProvider, undefined];
  }
  return [
    modelWithProvider.slice(0, separatorIndex),
    modelWithProvider.slice(separatorIndex + 1),
  ];
}
/**
 * Split one comma-separated entry of the custom-models config into parts.
 *
 * Entry syntax is `[+|-]<name>[=<displayName>]`. A leading `-` disables the
 * model; `+` or no prefix enables it. Surrounding whitespace is ignored.
 * Only the first `=` separates the name from the display name, so a display
 * name may itself contain `=`.
 */
function parseCustomModelEntry(entry: string): {
  available: boolean;
  name: string;
  displayName?: string;
} {
  const config = entry.trim();
  const available = !config.startsWith("-");
  const nameConfig =
    config.startsWith("+") || config.startsWith("-") ? config.slice(1) : config;
  const eqIndex = nameConfig.indexOf("=");
  if (eqIndex === -1) {
    return { available, name: nameConfig };
  }
  return {
    available,
    name: nameConfig.slice(0, eqIndex),
    displayName: nameConfig.slice(eqIndex + 1),
  };
}

/**
 * Build the model lookup table from the built-in models plus the server's
 * custom-models config string.
 *
 * Keys have the form `<modelName>@<providerId>`. `customModels` is a
 * comma-separated list processed left to right:
 *  - `all`, `+all` or `-all` enables/disables every model collected so far;
 *  - `[+|-]<model>[@<provider>][=<displayName>]` sets availability on every
 *    matching entry (all providers when `@<provider>` is omitted) and, if a
 *    display name is given, relabels them. If nothing matches, a new custom
 *    entry is created.
 *
 * For ByteDance, the value after `=` becomes the entry's `name` and the
 * model name becomes its display name (the pair is swapped).
 *
 * @param models - Built-in models used to seed the table.
 * @param customModels - Comma-separated custom-model config string.
 * @returns The model table keyed by `<modelName>@<providerId>`.
 */
export function collectModelTable(
  models: readonly LLMModel[],
  customModels: string,
) {
  const modelTable: Record<
    string,
    {
      available: boolean;
      name: string;
      displayName: string;
      sorted: number;
      provider?: LLMModel["provider"]; // Marked as optional
      isDefault?: boolean;
    }
  > = {};

  // default models, keyed as `<modelName>@<providerId>`
  models.forEach((m) => {
    modelTable[`${m.name}@${m?.provider?.id}`] = {
      ...m,
      displayName: m.name, // `provider` is copied over by the spread if it exists
    };
  });

  // server custom models
  for (const rawEntry of customModels.split(",")) {
    const { available, name, displayName } = parseCustomModelEntry(rawEntry);
    // skip empty entries: ",,", a trailing comma, a bare "+"/"-", or "=x"
    if (!name) continue;

    // enable or disable all models
    if (name === "all") {
      for (const model of Object.values(modelTable)) {
        model.available = available;
      }
      continue;
    }

    // 1. find matching models by name (and provider, if given) and update them
    const [customModelName, customProviderName] = getModelProvider(name);
    let count = 0;
    for (const fullName in modelTable) {
      const [modelName, providerName] = getModelProvider(fullName);
      if (
        customModelName !== modelName ||
        (customProviderName !== undefined &&
          customProviderName !== providerName)
      ) {
        continue;
      }
      count += 1;
      const target = modelTable[fullName];
      target.available = available;
      if (!displayName) continue;
      if (providerName === "bytedance") {
        // swap name and displayName for bytedance. Only swap when a value was
        // configured; otherwise `name` would be overwritten with undefined.
        target.name = displayName;
        target.displayName = modelName;
      } else {
        target.displayName = displayName;
      }
    }

    // 2. if no model matched, create a new model with the given availability
    if (count === 0) {
      const provider = customProvider(customProviderName || customModelName);
      let newName = customModelName;
      let newDisplayName = displayName;
      // swap name and displayName for ByteDance
      if (newDisplayName && provider.providerName == "ByteDance") {
        [newName, newDisplayName] = [newDisplayName, newName];
      }
      const key = `${newName}@${provider.id}`;
      modelTable[key] = {
        name: newName,
        displayName: newDisplayName || newName,
        available,
        provider,
        sorted: CustomSeq.next(key),
      };
    }
  }
  return modelTable;
}
/**
 * Build the model table (see collectModelTable) and flag one entry as the
 * default model.
 *
 * If `defaultModel` contains `@`, it is treated as an exact
 * `<model>@<providerId>` key and flagged if present, whether or not it is
 * available. Otherwise the first available entry, in key order, whose model
 * part equals `defaultModel` is flagged. An empty `defaultModel` flags nothing.
 */
export function collectModelTableWithDefaultModel(
  models: readonly LLMModel[],
  customModels: string,
  defaultModel: string,
) {
  const modelTable = collectModelTable(models, customModels);
  if (!defaultModel) {
    return modelTable;
  }

  if (defaultModel.includes("@")) {
    if (defaultModel in modelTable) {
      modelTable[defaultModel].isDefault = true;
    }
    return modelTable;
  }

  const chosenKey = Object.keys(modelTable).find(
    (key) =>
      modelTable[key].available && getModelProvider(key)[0] === defaultModel,
  );
  if (chosenKey !== undefined) {
    modelTable[chosenKey].isDefault = true;
  }
  return modelTable;
}
/**
 * Generate the full model list: every model-table entry, ordered by
 * provider and then by model (see sortModelTable).
 */
export function collectModels(
  models: readonly LLMModel[],
  customModels: string,
) {
  const entries = Object.values(collectModelTable(models, customModels));
  const ordered: typeof entries = sortModelTable(entries);
  return ordered;
}
/**
 * Same as collectModels, but with the configured default model flagged via
 * `isDefault` (see collectModelTableWithDefaultModel).
 */
export function collectModelsWithDefaultModel(
  models: readonly LLMModel[],
  customModels: string,
  defaultModel: string,
) {
  const table = collectModelTableWithDefaultModel(
    models,
    customModels,
    defaultModel,
  );
  const entries = Object.values(table);
  const ordered: typeof entries = sortModelTable(entries);
  return ordered;
}
/**
 * Check whether `modelName` has been explicitly disabled for `providerName`
 * in the server's custom-models config.
 *
 * NOTE: despite its name, this returns `true` when the model entry exists and
 * is marked unavailable, and `false` otherwise (including when the model is
 * unknown). Callers rely on this inverted meaning, so it is kept. See
 * isModelNotavailableInServer for the multi-provider check.
 *
 * @param customModels - Comma-separated custom-model config string.
 * @param modelName - Model name to look up.
 * @param providerName - Provider name; matched case-insensitively.
 * @returns `true` only if the model entry exists and is disabled.
 */
export function isModelAvailableInServer(
  customModels: string,
  modelName: string,
  providerName: string,
) {
  // Table keys are `<model>@<providerId>`. Custom provider ids are
  // lower-cased (customProvider), and isModelNotavailableInServer also
  // lower-cases before the lookup. Without this, a mixed-case provider
  // name never matches and the check silently returns false.
  const fullName = `${modelName}@${providerName.toLowerCase()}`;
  const modelTable = collectModelTable(DEFAULT_MODELS, customModels);
  return modelTable[fullName]?.available === false;
}
/**
 * Decide whether a model name belongs to the GPT-4 family.
 *
 * Names starting with `gpt-4`, `chatgpt-4o` or `o1` qualify, except those
 * starting with `gpt-4o-mini`. Matching is case-sensitive, so pass a
 * lower-cased name.
 *
 * @param modelName - The model name to classify.
 * @returns True for GPT-4-family models, excluding gpt-4o-mini variants.
 */
export function isGPT4Model(modelName: string): boolean {
  if (modelName.startsWith("gpt-4o-mini")) {
    return false;
  }
  const familyPrefixes = ["gpt-4", "chatgpt-4o", "o1"];
  return familyPrefixes.some((prefix) => modelName.startsWith(prefix));
}
/**
 * Check whether a model is unavailable on every one of the given providers.
 *
 * @param customModels - Comma-separated custom-model config string.
 * @param modelName - The model name to check.
 * @param providerNames - A provider name or a list of them; each is tried in turn.
 * @returns `true` if the model is not available on any listed provider, or if
 *   `DISABLE_GPT4` is `"1"` and the model is GPT-4 family. Returns `false` as
 *   soon as one provider has the model enabled.
 */
export function isModelNotavailableInServer(
  customModels: string,
  modelName: string,
  providerNames: string | string[],
): boolean {
  // Check DISABLE_GPT4 environment variable
  if (
    process.env.DISABLE_GPT4 === "1" &&
    isGPT4Model(modelName.toLowerCase())
  ) {
    return true;
  }

  const modelTable = collectModelTable(DEFAULT_MODELS, customModels);
  const providerNamesArray = Array.isArray(providerNames)
    ? providerNames
    : [providerNames];

  for (const providerName of providerNamesArray) {
    if (providerName === ServiceProvider.ByteDance) {
      // For ByteDance, match on the entry's `name` field, which
      // collectModelTable may have swapped with the configured value, rather
      // than on the `<model>@<provider>` key. Do not return early when it is
      // unavailable, so any remaining providers still get checked.
      const match = Object.values(modelTable).find(
        (v) => v.name === modelName,
      );
      if (match?.available) return false;
      continue;
    }
    const fullName = `${modelName}@${providerName.toLowerCase()}`;
    if (modelTable[fullName]?.available === true) return false;
  }
  return true;
}