// Create a custom request handler for the /classify route, which embeds the query
// text and returns matching images from the database.
// For more information, see https://nextjs.org/docs/app/building-your-application/routing/route-handlers
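//
// Example request (the query string is parsed by `parseInputs` below; the path and
// values are illustrative):
//   GET /classify?text=a%20cat%20on%20a%20sofa&threshold=0.2&limit=50
// On success the handler responds with a JSON array of matching image rows;
// on failure it responds with { error } and a 400 or 500 status.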

import { NextResponse } from 'next/server'
import ApplicationSingleton from '../app.js'

const parseInputs = (searchParams) => {
    const text = searchParams.get('text');
    if (!text) {
        return {
            error: 'Missing text parameter',
        };
    }
    const threshold = searchParams.get('threshold');
    const match_threshold = Number(threshold ?? 0.1);
    if (isNaN(match_threshold) || match_threshold < 0 || match_threshold > 1) {
        return {
            error: `Invalid threshold parameter "${threshold}" (should be a number between 0 and 1)`,
        };
    }

    const limit = searchParams.get('limit');
    const match_count = Number(limit ?? 25);
    if (isNaN(match_count) || !Number.isInteger(match_count) || match_count < 0 || match_count > 1000) {
        return {
            error: `Invalid limit parameter "${limit}" (should be an integer between 0 and 1000)`,
        };
    }

    return { text, match_threshold, match_count };
}

// TODO: add caching

export async function GET(request) {
    const parsedInputs = parseInputs(request.nextUrl.searchParams);
    if (parsedInputs.error) {
        return NextResponse.json({
            error: parsedInputs.error,
        }, { status: 400 });
    }

    // Valid inputs, so we can proceed
    const { text, match_threshold, match_count } = parsedInputs;

    // Get the tokenizer, model, and database singletons. When called for the first time,
    // this will load the models and cache them for future use.
    const [tokenizer, text_model, database] = await ApplicationSingleton.getInstance();

    // Run tokenization
    let text_inputs = tokenizer(text, { padding: true, truncation: true });

    // Compute embeddings
    const { text_embeds } = await text_model(text_inputs);
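    // `text_embeds` is a Tensor of shape [batch, dim]; convert it to a plain
    // JavaScript array and take the first (and only) row as the query vector.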
    const query_embedding = text_embeds.tolist()[0];

    // TODO add pagination?
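    // Call the `match_images` stored procedure. It is assumed to be a Postgres
    // function (e.g. defined in Supabase with pgvector) that returns rows whose
    // embedding similarity to the query exceeds `match_threshold`, limited to
    // `match_count` results.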
    let { data: images, error } = await database
        .rpc('match_images', {
            query_embedding,
            match_threshold,
            match_count,
        });
    if (error) {
        console.warn('Error fetching images', error);
        return NextResponse.json({
            error: 'An error occurred while fetching images',
        }, { status: 500 });
    }

    return NextResponse.json(images);
}