Spaces:
Running
Running
// Create a custom request handler for the /classify route. | |
// For more information, see https://nextjs.org/docs/app/building-your-application/routing/router-handlers | |
import { NextResponse } from 'next/server' | |
import ApplicationSingleton from '../app.js' | |
const parseInputs = (searchParams) => { | |
const text = searchParams.get('text'); | |
if (!text) { | |
return { | |
error: 'Missing text parameter', | |
}; | |
} | |
const threshold = searchParams.get('threshold'); | |
const match_threshold = Number(threshold ?? 0.1); | |
if (isNaN(match_threshold) || match_threshold < 0 || match_threshold > 1) { | |
return { | |
error: `Invalid threshold parameter "${threshold}" (should be a number between 0 and 1)`, | |
}; | |
} | |
const limit = searchParams.get('limit'); | |
const match_count = Number(limit ?? 25); | |
if (isNaN(match_count) || !Number.isInteger(match_count) || match_count < 0 || match_count > 1000) { | |
return { | |
error: `Invalid limit parameter "${limit}" (should be an integer between 0 and 1000)`, | |
}; | |
} | |
return { text, match_threshold, match_count } | |
} | |
// TODO: add caching | |
export async function GET(request) { | |
const parsedInputs = parseInputs(request.nextUrl.searchParams); | |
if (parsedInputs.error) { | |
return NextResponse.json({ | |
error: parsedInputs.error, | |
}, { status: 400 }); | |
} | |
// Valid inputs, so we can proceed | |
const { text, match_threshold, match_count } = parsedInputs; | |
// Get the tokenizer, model, and database singletons. When called for the first time, | |
// this will load the models and cache them for future use. | |
const [tokenizer, text_model, database] = await ApplicationSingleton.getInstance(); | |
// Run tokenization | |
let text_inputs = tokenizer(text, { padding: true, truncation: true }); | |
// Compute embeddings | |
const { text_embeds } = await text_model(text_inputs); | |
const query_embedding = text_embeds.tolist()[0]; | |
// TODO add pagination? | |
let { data: images, error } = await database | |
.rpc('match_images', { | |
query_embedding, | |
match_threshold, | |
match_count, | |
}); | |
if (error) { | |
console.warn('Error fetching images', error); | |
return NextResponse.json({ | |
error: 'An error occurred while fetching images', | |
}, { status: 500 }); | |
} | |
return NextResponse.json(images); | |
} | |