File size: 2,185 Bytes
c1f12bf
 
 
 
ed9e9d0
 
 
 
 
 
 
 
 
 
c1f12bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed9e9d0
 
 
 
c1f12bf
 
 
ed9e9d0
c1f12bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed9e9d0
c1f12bf
 
 
 
 
 
 
ed9e9d0
c1f12bf
ed9e9d0
c1f12bf
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import { GradioApiInfo, GradioEndpoint, SupportedFields } from '../types'
import { identifyField } from './identifyField'
import { getDefaultFields } from './getDefaultFields'
import { getAdaptationScore } from './getAdaptationScore'

/**
 * Find the main entrypoint (main entry endpoint) of a Gradio API.
 *
 * Every named and unnamed endpoint is scored as follows:
 * - each of its parameters is mapped to supported fields with `identifyField`;
 * - those fields are merged on top of `getDefaultFields()`;
 * - the merged result is rated with `getAdaptationScore`.
 *
 * The highest-scoring endpoint is returned. On a tie, the endpoint listed
 * first wins. Named endpoints are listed before unnamed ones.
 *
 * @param gradioApiInfo - API description holding the `named_endpoints` and
 *   `unnamed_endpoints` maps
 * @returns the best endpoint, with its per-parameter `fields` and its `score`,
 *   or `undefined` when the API exposes no endpoint at all
 */
export function findMainGradioEndpoint({
  gradioApiInfo,
}: {
  gradioApiInfo: GradioApiInfo
}): GradioEndpoint | undefined {
  // Named endpoints come first so they win score ties against unnamed ones.
  const endpoints: GradioEndpoint[] = [
    ...Object.entries(gradioApiInfo.named_endpoints).map(
      ([name, endpoint]) => ({
        isNamed: true,
        name,
        endpoint,
        fields: {},
        score: 0,
      })
    ),
    ...Object.entries(gradioApiInfo.unnamed_endpoints).map(
      ([name, endpoint]) => ({
        // these come from `unnamed_endpoints`, so they must not be flagged as named
        isNamed: false,
        name,
        endpoint,
        fields: {},
        score: 0,
      })
    ),
  ]

  // Generally the main entry point will be called "/run", "/call", "/predict"
  // etc. and accept the inputs we usually expect (a text prompt, an image...),
  // so each endpoint is rated by how well its parameters map to known fields.
  let best: GradioEndpoint | undefined
  for (const { isNamed, name, endpoint } of endpoints) {
    console.log(`found endpoint: ${name}`)

    // const isContinuous = !!endpoint.type?.continuous
    // const isGenerator = !!endpoint.type?.generator
    // const canCancel = !!endpoint.type?.cancel

    const gradioFields: Record<string, Partial<SupportedFields>> = {}
    // Copy the defaults so merging parameter fields never mutates the object
    // returned by getDefaultFields().
    const allGradioFields = { ...getDefaultFields() }
    for (const gradioParameter of endpoint.parameters) {
      const gradioParameterField = identifyField(
        gradioParameter.parameter_name,
        gradioParameter.parameter_default
      )
      gradioFields[gradioParameter.parameter_name] = gradioParameterField
      // later parameters override earlier ones, as with the previous spread chain
      Object.assign(allGradioFields, gradioParameterField)
    }

    const score = getAdaptationScore(allGradioFields)
    console.log(`allGradioFields:`, allGradioFields)
    console.log(`score:`, score)

    const candidate: GradioEndpoint = {
      isNamed,
      name,
      endpoint,
      fields: gradioFields,
      score,
    }

    // Strict ">" keeps the earliest endpoint on ties, matching a stable
    // descending sort followed by taking the first element.
    if (best === undefined || candidate.score > best.score) {
      best = candidate
    }
  }

  return best
}