import { GradioApiInfo, SupportedFields } from '../types' import { identifyField } from './identifyField' import { getDefaultFields } from './getDefaultFields' import { findMainGradioEndpoint } from './findMainGradioEndpoint' import { base64DataUriToBlob } from '@/lib/utils/base64DataUriToBlob' /** * This function try to adapt arbitrary inputs to strict gradio inputs * * @param param0 * @returns */ export function adaptAnyInputsToGradioInputs({ inputs, gradioApiInfo, }: { inputs: Record gradioApiInfo: GradioApiInfo }): { endpoint: string inputArray: Array inputMap: Record } { const mainGradioEndpoint = findMainGradioEndpoint({ gradioApiInfo }) if (!mainGradioEndpoint) { throw new Error(`couldn't find a suitable endpoint`) } // input fields passed by the parent calling function let inputFields: Record> = {} let allInputFields = getDefaultFields() for (const [key, value] of Object.entries(inputs)) { const inputField = identifyField(key, value) inputFields[key] = inputField allInputFields = { ...allInputFields, ...inputField } } console.log( `input fields passed by the parent calling function:`, inputFields ) // the gradio input array // apparently the new JS client also supports dictionaries, yay! let inputArray: Array = [] let inputMap: Record< string, string | number | boolean | Blob | undefined | null > = {} for (const parameter of mainGradioEndpoint.endpoint.parameters) { let gradioInputValue: any = undefined const fields = mainGradioEndpoint.fields[parameter.parameter_name] // TODO: rewrite this in a better way maybe // until then, please don't blame me if you forget to update those! if (fields.hasPositiveTextPrompt) { gradioInputValue = allInputFields.inputPositiveTextPrompt } if (fields.hasNegativeTextPrompt) { gradioInputValue = allInputFields.inputNegativeTextPrompt } if (fields.hasInputImage) { gradioInputValue = allInputFields.inputImage } if (fields.hasInputAudio) { gradioInputValue = allInputFields.inputAudio } if (fields.hasInputWidth) { gradioInputValue = allInputFields.inputWidth } if (fields.hasInputHeight) { gradioInputValue = allInputFields.inputHeight } if (fields.hasInputSteps) { gradioInputValue = allInputFields.inputSteps } if (fields.hasInputGuidance) { gradioInputValue = allInputFields.inputGuidance } if (fields.hasInputSeed) { gradioInputValue = allInputFields.inputSeed } //console.log("parameter:", parameter) const valueSeemsToBeABase64Uri = typeof gradioInputValue === 'string' && gradioInputValue.startsWith('data:') const fieldSeemsToBeTextBased = parameter.type === 'string' || parameter.component === 'Textbox' // || parameter.parameter_name.includes("base64") // the magic doesn't end here: we need to convert base64 inputs to buffers, // unless gradio expects it to be a text if (valueSeemsToBeABase64Uri && !fieldSeemsToBeTextBased) { gradioInputValue = base64DataUriToBlob(gradioInputValue) } // old, low-level way inputArray.push(gradioInputValue) // new, high-level way inputMap[parameter.parameter_name] = gradioInputValue } console.log( `inputArray:`, inputArray.map((x) => (typeof x === 'string' ? x.slice(0, 255) : x)) ) console.log(`await client.predict("${mainGradioEndpoint.name}", `, inputMap) return { endpoint: mainGradioEndpoint.name, inputArray, inputMap, } }