| "use client"; |
|
|
| import { useState, useMemo, useEffect } from "react"; |
| import PropTypes from "prop-types"; |
| import Modal from "./Modal"; |
| import ProviderIcon from "./ProviderIcon"; |
| import { getModelsByProviderId } from "@/shared/constants/models"; |
| import { OAUTH_PROVIDERS, APIKEY_PROVIDERS, FREE_PROVIDERS, FREE_TIER_PROVIDERS, AI_PROVIDERS, isOpenAICompatibleProvider, isAnthropicCompatibleProvider, getProviderAlias } from "@/shared/constants/providers"; |
|
|
| |
// Display order for provider groups: OAuth first, then free, free-tier and
// API-key providers. The modal sorts groups by position in this list and
// pushes IDs that are not listed to the end.
const PROVIDER_ORDER = [
  OAUTH_PROVIDERS,
  FREE_PROVIDERS,
  FREE_TIER_PROVIDERS,
  APIKEY_PROVIDERS,
].flatMap((providerMap) => Object.keys(providerMap));
|
|
| |
// IDs of free providers flagged `noAuth`; the modal lists these in addition to
// the providers that have an active connection.
const NO_AUTH_PROVIDER_IDS = Object.entries(FREE_PROVIDERS)
  .filter(([, provider]) => provider.noAuth)
  .map(([providerId]) => providerId);
|
|
| export default function ModelSelectModal({ |
| isOpen, |
| onClose, |
| onSelect, |
| onDeselect, |
| selectedModel, |
| activeProviders = [], |
| title = "Select Model", |
| modelAliases = {}, |
| kindFilter = null, |
| addedModelValues = [], |
| closeOnSelect = true, |
| }) { |
| |
| const filteredActiveProviders = useMemo(() => { |
| if (!kindFilter) return activeProviders; |
| return activeProviders.filter((p) => { |
| const info = AI_PROVIDERS[p.provider]; |
| const kinds = info?.serviceKinds || ["llm"]; |
| return kinds.includes(kindFilter); |
| }); |
| }, [activeProviders, kindFilter]); |
| const [searchQuery, setSearchQuery] = useState(""); |
| const [combos, setCombos] = useState([]); |
| const [providerNodes, setProviderNodes] = useState([]); |
| const [customModels, setCustomModels] = useState([]); |
| const [disabledModels, setDisabledModels] = useState({}); |
|
|
| const fetchCombos = async () => { |
| try { |
| const res = await fetch("/api/combos"); |
| if (!res.ok) throw new Error(`Failed to fetch combos: ${res.status}`); |
| const data = await res.json(); |
| setCombos(data.combos || []); |
| } catch (error) { |
| console.error("Error fetching combos:", error); |
| setCombos([]); |
| } |
| }; |
|
|
| useEffect(() => { |
| if (isOpen) fetchCombos(); |
| }, [isOpen]); |
|
|
| const fetchProviderNodes = async () => { |
| try { |
| const res = await fetch("/api/provider-nodes"); |
| if (!res.ok) throw new Error(`Failed to fetch provider nodes: ${res.status}`); |
| const data = await res.json(); |
| setProviderNodes(data.nodes || []); |
| } catch (error) { |
| console.error("Error fetching provider nodes:", error); |
| setProviderNodes([]); |
| } |
| }; |
|
|
| useEffect(() => { |
| if (isOpen) fetchProviderNodes(); |
| }, [isOpen]); |
|
|
| const fetchCustomModels = async () => { |
| try { |
| const res = await fetch("/api/models/custom"); |
| if (!res.ok) throw new Error(`Failed to fetch custom models: ${res.status}`); |
| const data = await res.json(); |
| setCustomModels(data.models || []); |
| } catch (error) { |
| console.error("Error fetching custom models:", error); |
| setCustomModels([]); |
| } |
| }; |
|
|
| useEffect(() => { |
| if (isOpen) fetchCustomModels(); |
| }, [isOpen]); |
|
|
| const fetchDisabledModels = async () => { |
| try { |
| const res = await fetch("/api/models/disabled"); |
| if (!res.ok) throw new Error(`Failed to fetch disabled models: ${res.status}`); |
| const data = await res.json(); |
| setDisabledModels(data.disabled || {}); |
| } catch (error) { |
| console.error("Error fetching disabled models:", error); |
| setDisabledModels({}); |
| } |
| }; |
|
|
| useEffect(() => { |
| if (isOpen) fetchDisabledModels(); |
| }, [isOpen]); |
|
|
| const allProviders = useMemo(() => ({ ...OAUTH_PROVIDERS, ...FREE_PROVIDERS, ...FREE_TIER_PROVIDERS, ...APIKEY_PROVIDERS }), []); |
|
|
| |
| const groupedModels = useMemo(() => { |
| const groups = {}; |
|
|
| |
| const PROVIDER_AS_MODEL_KINDS = new Set(["webSearch", "webFetch"]); |
| |
| const TYPED_KINDS = new Set(["image", "tts", "stt", "embedding", "imageToText"]); |
| |
| const ALLOW_PROVIDER_FALLBACK_KINDS = new Set(["tts", "image", "webFetch"]); |
|
|
| |
| const filterByKind = (models) => { |
| |
| if (!kindFilter) return models.filter((m) => m.isPlaceholder || !m.type || m.type === "llm"); |
| if (!TYPED_KINDS.has(kindFilter)) return models; |
| return models.filter((m) => m.isPlaceholder || m.type === kindFilter); |
| }; |
|
|
| |
| const activeConnectionIds = filteredActiveProviders.map(p => p.provider); |
|
|
| |
| const noAuthIds = kindFilter |
| ? NO_AUTH_PROVIDER_IDS.filter((id) => (AI_PROVIDERS[id]?.serviceKinds || ["llm"]).includes(kindFilter)) |
| : NO_AUTH_PROVIDER_IDS; |
|
|
| |
| const providerIdsToShow = new Set([ |
| ...activeConnectionIds, |
| ...noAuthIds, |
| ]); |
|
|
| |
| const sortedProviderIds = [...providerIdsToShow].sort((a, b) => { |
| const indexA = PROVIDER_ORDER.indexOf(a); |
| const indexB = PROVIDER_ORDER.indexOf(b); |
| return (indexA === -1 ? 999 : indexA) - (indexB === -1 ? 999 : indexB); |
| }); |
|
|
| sortedProviderIds.forEach((providerId) => { |
| const alias = getProviderAlias(providerId); |
| const providerInfo = allProviders[providerId] || { name: providerId, color: "#666" }; |
| const isCustomProvider = isOpenAICompatibleProvider(providerId) || isAnthropicCompatibleProvider(providerId); |
|
|
| |
| if (kindFilter && PROVIDER_AS_MODEL_KINDS.has(kindFilter)) { |
| groups[providerId] = { |
| name: providerInfo.name, |
| alias, |
| color: providerInfo.color, |
| models: [{ id: providerId, name: providerInfo.name, value: providerId }], |
| }; |
| return; |
| } |
|
|
| if (providerInfo.passthroughModels) { |
| const aliasModels = Object.entries(modelAliases) |
| .filter(([, fullModel]) => fullModel.startsWith(`${alias}/`)) |
| .map(([aliasName, fullModel]) => ({ |
| id: fullModel.replace(`${alias}/`, ""), |
| name: aliasName, |
| value: fullModel, |
| })); |
|
|
| |
| let combined = aliasModels; |
| if (kindFilter && TYPED_KINDS.has(kindFilter)) { |
| combined = getModelsByProviderId(providerId) |
| .filter((m) => m.type === kindFilter) |
| .map((m) => ({ id: m.id, name: m.name, value: `${alias}/${m.id}`, type: m.type })); |
| |
| if (combined.length === 0 && ALLOW_PROVIDER_FALLBACK_KINDS.has(kindFilter)) { |
| const supports = (providerInfo.serviceKinds || ["llm"]).includes(kindFilter); |
| if (supports) combined = [{ id: providerId, name: providerInfo.name, value: alias }]; |
| } |
| } |
|
|
| if (combined.length > 0) { |
| |
| const matchedNode = providerNodes.find(node => node.id === providerId); |
| const displayName = matchedNode?.name || providerInfo.name; |
|
|
| groups[providerId] = { |
| name: displayName, |
| alias: alias, |
| color: providerInfo.color, |
| models: combined, |
| }; |
| } |
| } else if (isCustomProvider) { |
| |
| if (kindFilter && TYPED_KINDS.has(kindFilter)) return; |
| |
| const connection = activeProviders.find(p => p.provider === providerId); |
| const matchedNode = providerNodes.find(node => node.id === providerId); |
| const displayName = matchedNode?.name || connection?.name || providerInfo.name; |
| const nodePrefix = connection?.providerSpecificData?.prefix || matchedNode?.prefix || providerId; |
|
|
| |
| |
| const nodeModels = Object.entries(modelAliases) |
| .filter(([, fullModel]) => fullModel.startsWith(`${providerId}/`)) |
| .map(([aliasName, fullModel]) => ({ |
| id: fullModel.replace(`${providerId}/`, ""), |
| name: aliasName, |
| value: `${nodePrefix}/${fullModel.replace(`${providerId}/`, "")}`, |
| })); |
|
|
| |
| |
| const modelsToShow = nodeModels.length > 0 ? nodeModels : [{ |
| id: `__placeholder__${providerId}`, |
| name: `${nodePrefix}/model-id`, |
| value: `${nodePrefix}/model-id`, |
| isPlaceholder: true, |
| }]; |
|
|
| groups[providerId] = { |
| name: displayName, |
| alias: nodePrefix, |
| color: providerInfo.color, |
| models: modelsToShow, |
| isCustom: true, |
| hasModels: nodeModels.length > 0, |
| }; |
| } else { |
| const hardcodedModels = getModelsByProviderId(providerId); |
| const hardcodedIds = new Set(hardcodedModels.map((m) => m.id)); |
|
|
| |
| |
| const hasHardcoded = hardcodedModels.length > 0; |
| const customAliasModels = Object.entries(modelAliases) |
| .filter(([aliasName, fullModel]) => |
| fullModel.startsWith(`${alias}/`) && |
| (hasHardcoded ? aliasName === fullModel.replace(`${alias}/`, "") : true) && |
| !hardcodedIds.has(fullModel.replace(`${alias}/`, "")) |
| ) |
| .map(([aliasName, fullModel]) => { |
| const modelId = fullModel.replace(`${alias}/`, ""); |
| return { id: modelId, name: aliasName, value: fullModel, isCustom: true }; |
| }); |
|
|
| |
| const customAliasIds = new Set(customAliasModels.map((m) => m.id)); |
| const customRegisteredModels = customModels |
| .filter((m) => m.providerAlias === alias && !hardcodedIds.has(m.id) && !customAliasIds.has(m.id)) |
| .map((m) => ({ id: m.id, name: m.name || m.id, value: `${alias}/${m.id}`, isCustom: true })); |
|
|
| const merged = [ |
| ...hardcodedModels.map((m) => ({ id: m.id, name: m.name, value: `${alias}/${m.id}`, type: m.type })), |
| ...customAliasModels, |
| ...customRegisteredModels, |
| ]; |
| |
| const seen = new Set(); |
| let allModels = filterByKind(merged.filter((m) => { |
| if (seen.has(m.value)) return false; |
| seen.add(m.value); |
| return true; |
| })); |
|
|
| |
| |
| if (allModels.length === 0 && kindFilter && ALLOW_PROVIDER_FALLBACK_KINDS.has(kindFilter)) { |
| const supports = (providerInfo.serviceKinds || ["llm"]).includes(kindFilter); |
| if (supports) { |
| allModels = [{ id: providerId, name: providerInfo.name, value: alias }]; |
| } |
| } |
|
|
| if (allModels.length > 0) { |
| groups[providerId] = { |
| name: providerInfo.name, |
| alias: alias, |
| color: providerInfo.color, |
| models: allModels, |
| }; |
| } |
| } |
| }); |
|
|
| |
| Object.entries(groups).forEach(([providerId, group]) => { |
| const aliasKey = getProviderAlias(providerId); |
| const disabledIds = new Set([ |
| ...(disabledModels[aliasKey] || []), |
| ...(disabledModels[providerId] || []), |
| ]); |
| if (disabledIds.size === 0) return; |
| group.models = group.models.filter((m) => !disabledIds.has(m.id)); |
| if (group.models.length === 0) delete groups[providerId]; |
| }); |
|
|
| return groups; |
| }, [filteredActiveProviders, modelAliases, allProviders, providerNodes, customModels, disabledModels, kindFilter, activeProviders]); |
|
|
| |
| const filteredCombos = useMemo(() => { |
| if (kindFilter) return []; |
| if (!searchQuery.trim()) return combos; |
| const query = searchQuery.toLowerCase(); |
| return combos.filter(c => c.name.toLowerCase().includes(query)); |
| }, [combos, searchQuery, kindFilter]); |
|
|
| |
| const sortModels = (models) => { |
| const added = models.filter(m => addedModelValues.includes(m.value)).sort((a, b) => a.name.localeCompare(b.name)); |
| const rest = models.filter(m => !addedModelValues.includes(m.value)).sort((a, b) => a.name.localeCompare(b.name)); |
| return [...added, ...rest]; |
| }; |
|
|
| |
| const filteredGroups = useMemo(() => { |
| const query = searchQuery.trim().toLowerCase(); |
|
|
| const filtered = {}; |
| Object.entries(groupedModels).forEach(([providerId, group]) => { |
| let models = group.models; |
| if (query) { |
| const providerNameMatches = group.name.toLowerCase().includes(query); |
| models = models.filter( |
| (m) => |
| m.name.toLowerCase().includes(query) || |
| m.id.toLowerCase().includes(query) |
| ); |
| if (models.length === 0 && !providerNameMatches) return; |
| } |
| filtered[providerId] = { |
| ...group, |
| models: sortModels(models), |
| }; |
| }); |
|
|
| return filtered; |
| }, [groupedModels, searchQuery, addedModelValues]); |
|
|
| const handleSelect = (model) => { |
| const value = model?.value || model?.name || model; |
| const isAdded = addedModelValues.includes(value); |
|
|
| if (isAdded && onDeselect) { |
| onDeselect(model); |
| } else { |
| onSelect(model); |
| } |
|
|
| if (closeOnSelect) { |
| onClose(); |
| setSearchQuery(""); |
| } |
| }; |
|
|
| return ( |
| <Modal |
| isOpen={isOpen} |
| onClose={() => { |
| onClose(); |
| setSearchQuery(""); |
| }} |
| title={title} |
| size="md" |
| className="p-4!" |
| footer={null} |
| > |
| {/* Info bar */} |
| <div className="flex items-center gap-2 mb-3 px-2.5 py-2 bg-primary/8 border border-primary/20 rounded-lg text-xs text-text-muted"> |
| <span className="material-symbols-outlined text-primary shrink-0" style={{ fontSize: "14px" }}>info</span> |
| <span>Click to add, click again to remove. Changes are saved automatically.</span> |
| </div> |
| |
| {/* Search - compact */} |
| <div className="mb-3"> |
| <div className="relative"> |
| <span className="material-symbols-outlined absolute left-2.5 top-1/2 -translate-y-1/2 text-text-muted text-[16px]"> |
| search |
| </span> |
| <input |
| type="text" |
| placeholder="Search..." |
| value={searchQuery} |
| onChange={(e) => setSearchQuery(e.target.value)} |
| className="w-full pl-8 pr-3 py-1.5 bg-surface border border-border rounded text-xs focus:outline-none focus:ring-1 focus:ring-primary/50" |
| /> |
| </div> |
| </div> |
| |
| {/* Models grouped by provider - compact */} |
| <div className="max-h-[400px] overflow-y-auto space-y-3"> |
| {/* Combos section - always first */} |
| {filteredCombos.length > 0 && ( |
| <div> |
| <div className="flex items-center gap-1.5 mb-1.5 sticky top-0 bg-surface py-0.5"> |
| <span className="material-symbols-outlined text-primary text-[14px]">layers</span> |
| <span className="text-xs font-medium text-primary">Combos</span> |
| <span className="text-[10px] text-text-muted">({filteredCombos.length})</span> |
| </div> |
| <div className="flex flex-wrap gap-1.5"> |
| {filteredCombos.map((combo) => { |
| const isSelected = selectedModel === combo.name; |
| return ( |
| <button |
| key={combo.id} |
| onClick={() => handleSelect({ id: combo.name, name: combo.name, value: combo.name })} |
| className={` |
| px-2 py-1 rounded-xl text-xs font-medium transition-all border hover:cursor-pointer flex items-center gap-1 |
| ${isSelected |
| ? "bg-primary text-white border-primary" |
| : addedModelValues.includes(combo.name) |
| ? "bg-primary border-primary text-white hover:bg-primary-hover" |
| : "bg-surface border-border text-text-main hover:border-primary/50 hover:bg-primary/5" |
| } |
| `} |
| > |
| {addedModelValues.includes(combo.name) && ( |
| <span className="material-symbols-outlined leading-none" style={{ fontSize: "10px" }}>check</span> |
| )} |
| {combo.name} |
| </button> |
| ); |
| })} |
| </div> |
| </div> |
| )} |
| |
| {/* Provider models */} |
| {Object.entries(filteredGroups).map(([providerId, group]) => ( |
| <div key={providerId}> |
| {/* Provider header */} |
| <div className="flex items-center gap-1.5 mb-1.5 sticky top-0 bg-surface py-0.5"> |
| <ProviderIcon |
| src={`/providers/${providerId}.png`} |
| alt={group.name} |
| size={14} |
| fallbackText={(group.name || providerId).slice(0, 2).toUpperCase()} |
| fallbackColor={group.color} |
| /> |
| <span className="text-xs font-medium text-primary"> |
| {group.name} |
| </span> |
| <span className="text-[10px] text-text-muted"> |
| ({group.models.length}) |
| </span> |
| </div> |
| |
| <div className="flex flex-wrap gap-1.5"> |
| {group.models.map((model) => { |
| const isSelected = selectedModel === model.value; |
| const isPlaceholder = model.isPlaceholder; |
| return ( |
| <button |
| key={model.value} |
| onClick={() => handleSelect(model)} |
| title={isPlaceholder ? "Select to pre-fill, then edit model ID in the input" : undefined} |
| className={` |
| px-2 py-1 rounded-xl text-xs font-medium transition-all border hover:cursor-pointer |
| ${isPlaceholder |
| ? "border-dashed border-border text-text-muted hover:border-primary/50 hover:text-primary bg-surface italic" |
| : isSelected |
| ? "bg-primary text-white border-primary" |
| : addedModelValues.includes(model.value) |
| ? "bg-primary border-primary text-white hover:bg-primary-hover" |
| : "bg-surface border-border text-text-main hover:border-primary/50 hover:bg-primary/5" |
| } |
| `} |
| > |
| <span className="flex items-center gap-1"> |
| {addedModelValues.includes(model.value) && !isPlaceholder && ( |
| <span className="material-symbols-outlined leading-none" style={{ fontSize: "10px" }}>check</span> |
| )} |
| {isPlaceholder ? ( |
| <> |
| <span className="material-symbols-outlined text-[11px]">edit</span> |
| {model.name} |
| </> |
| ) : model.isCustom ? ( |
| <> |
| {model.name} |
| <span className="text-[9px] opacity-60 font-normal">custom</span> |
| </> |
| ) : ( |
| model.name |
| )} |
| </span> |
| </button> |
| ); |
| })} |
| </div> |
| </div> |
| ))} |
|
|
| {Object.keys(filteredGroups).length === 0 && filteredCombos.length === 0 && ( |
| <div className="text-center py-4 text-text-muted"> |
| <span className="material-symbols-outlined text-2xl mb-1 block"> |
| search_off |
| </span> |
| <p className="text-xs">No models found</p> |
| </div> |
| )} |
| </div> |
| </Modal> |
| ); |
| } |
|
|
// Minimal shape required of each active provider connection: only `provider`
// is validated here; additional fields are allowed.
const activeProviderShape = PropTypes.shape({
  provider: PropTypes.string.isRequired,
});

// Runtime prop validation for ModelSelectModal. Only `isOpen`, `onClose` and
// `onSelect` are required.
ModelSelectModal.propTypes = {
  isOpen: PropTypes.bool.isRequired,
  onClose: PropTypes.func.isRequired,
  onSelect: PropTypes.func.isRequired,
  onDeselect: PropTypes.func,
  selectedModel: PropTypes.string,
  activeProviders: PropTypes.arrayOf(activeProviderShape),
  title: PropTypes.string,
  modelAliases: PropTypes.object,
  kindFilter: PropTypes.string,
  addedModelValues: PropTypes.arrayOf(PropTypes.string),
  closeOnSelect: PropTypes.bool,
};
|
|
|
|