|
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; |
|
import { createSelector } from '@reduxjs/toolkit'; |
|
import { useAppSelector } from 'app/store/storeHooks'; |
|
import type { GroupBase } from 'chakra-react-select'; |
|
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; |
|
import type { ModelIdentifierField } from 'features/nodes/types/common'; |
|
import { selectSystemShouldEnableModelDescriptions } from 'features/system/store/systemSlice'; |
|
import { groupBy, reduce } from 'lodash-es'; |
|
import { useCallback, useMemo } from 'react'; |
|
import { useTranslation } from 'react-i18next'; |
|
import type { AnyModelConfig } from 'services/api/types'; |
|
|
|
/**
 * Arguments accepted by {@link useGroupedModelCombobox}.
 */
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
  /** Model configs offered as options. */
  modelConfigs: T[];
  /** Receives the chosen model config, or `null` when the selection is cleared or cannot be resolved. */
  onChange: (value: T | null) => void;
  /** Currently selected model; matched against option values by its `key`. */
  selectedModel?: ModelIdentifierField | null;
  /** Optional predicate; models for which it returns `true` are emitted with `isDisabled: true`. */
  getIsDisabled?: (model: T) => boolean;
  /** When truthy, the placeholder shows the loading message. */
  isLoading?: boolean;
  /** Group by base model and model type instead of base model only. Treated as `false` when omitted. */
  groupByType?: boolean;
};
|
|
|
/**
 * Values returned by {@link useGroupedModelCombobox}.
 */
type UseGroupedModelComboboxReturn = {
  /** Option groups, one per base model (or per base model + model type). */
  options: GroupBase<ComboboxOption>[];
  /** The option matching the selected model, if any. */
  value: ComboboxOption | undefined | null;
  /** Change handler that maps the chosen option back to its model config. */
  onChange: ComboboxOnChange;
  /** Placeholder text reflecting the loading / empty / default state. */
  placeholder: string;
  /** Produces the message displayed when no option matches. */
  noOptionsMessage: () => string;
};
|
|
|
/** Group key for a model: its base, upper-cased (e.g. `SDXL`). */
const groupByBaseFunc = <T extends AnyModelConfig>(model: T): string => {
  const { base } = model;
  return base.toUpperCase();
};
|
/**
 * Group key for a model combining base and type: `BASE / TYPE`, upper-cased,
 * with every underscore in the type rendered as a space.
 */
const groupByBaseAndTypeFunc = <T extends AnyModelConfig>(model: T): string => {
  const readableType = model.type.split('_').join(' ');
  return `${model.base} / ${readableType}`.toUpperCase();
};
|
|
|
/** Selects the base of the model stored in the params slice, or `'sdxl'` when no model is set. */
const selectBaseWithSDXLFallback = createSelector(selectParamsSlice, (params) => {
  const currentBase = params.model?.base;
  return currentBase ?? 'sdxl';
});
|
|
|
/**
 * Moves option groups whose base segment (the label text before the first `/`) contains
 * `base` to the front, keeping the original relative order inside both partitions.
 *
 * Sorts `groups` in place and returns the same array.
 *
 * @param groups - Option groups labelled like `SDXL` or `SDXL / MAIN`.
 * @param base - Compared against the lower-cased label, so it is expected to be lower-case (e.g. `sdxl`).
 * @returns The same `groups` array, reordered.
 */
const sortGroupsByBaseFirst = (groups: GroupBase<ComboboxOption>[], base: string): GroupBase<ComboboxOption>[] => {
  const matchesBase = (group: GroupBase<ComboboxOption>): boolean =>
    group.label?.split('/')[0]?.toLowerCase().includes(base) ?? false;
  // Use both arguments and return 0 for ties so this is a valid comparator. Array#sort is
  // stable, so the result is a deterministic partition that preserves order within each side.
  return groups.sort((a, b) => Number(matchesBase(b)) - Number(matchesBase(a)));
};

/**
 * Builds the props for a grouped model combobox.
 *
 * Model configs are grouped by base model, or by base model and type when `groupByType` is set.
 * Groups whose base matches the params-slice model's base (falling back to `sdxl`) are listed first.
 *
 * @param arg - See {@link UseGroupedModelComboboxArg}.
 * @returns The grouped options, the option for `selectedModel` (or `null`), a change handler that
 *   resolves option values back to model configs, a placeholder and a no-options message.
 */
export const useGroupedModelCombobox = <T extends AnyModelConfig>(
  arg: UseGroupedModelComboboxArg<T>
): UseGroupedModelComboboxReturn => {
  const { t } = useTranslation();
  const base = useAppSelector(selectBaseWithSDXLFallback);
  const shouldShowModelDescriptions = useAppSelector(selectSystemShouldEnableModelDescriptions);
  const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading, groupByType = false } = arg;

  const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
    // Defensive guard: `modelConfigs` is typed as required, but an absent value yields no options.
    if (!modelConfigs) {
      return [];
    }
    const groupedModels = groupBy(modelConfigs, groupByType ? groupByBaseAndTypeFunc : groupByBaseFunc);
    const _options = reduce(
      groupedModels,
      (acc, val, label) => {
        acc.push({
          label,
          options: val.map((model) => ({
            label: model.name,
            value: model.key,
            // Descriptions only when enabled in system settings; an empty description becomes undefined.
            description: (shouldShowModelDescriptions && model.description) || undefined,
            isDisabled: getIsDisabled ? getIsDisabled(model) : false,
          })),
        });
        return acc;
      },
      [] as GroupBase<ComboboxOption>[]
    );
    return sortGroupsByBaseFirst(_options, base);
  }, [modelConfigs, groupByType, getIsDisabled, base, shouldShowModelDescriptions]);

  const value = useMemo(() => {
    if (!selectedModel) {
      return null;
    }
    return options.flatMap((o) => o.options).find((m) => m.value === selectedModel.key) ?? null;
  }, [options, selectedModel]);

  const _onChange = useCallback<ComboboxOnChange>(
    (v) => {
      if (!v) {
        onChange(null);
        return;
      }
      // A value with no matching config is reported as a cleared selection.
      const model = modelConfigs.find((m) => m.key === v.value);
      onChange(model ?? null);
    },
    [modelConfigs, onChange]
  );

  const placeholder = useMemo(() => {
    if (isLoading) {
      return t('common.loading');
    }
    if (options.length === 0) {
      return t('models.noModelsAvailable');
    }
    return t('models.selectModel');
  }, [isLoading, options, t]);

  const noOptionsMessage = useCallback(() => t('models.noMatchingModels'), [t]);

  return { options, value, onChange: _onChange, placeholder, noOptionsMessage };
};
|
|