File size: 3,199 Bytes
bfc0ec6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""Router for the signal registry."""

import math
from typing import Annotated, Any, Optional

from fastapi import APIRouter, Depends
from pydantic import BaseModel, validator

from .auth import UserInfo, get_session_user
from .router_utils import RouteErrorHandler, server_compute_concept
from .schema import Field, SignalInputType
from .signal import SIGNAL_REGISTRY, Signal, TextEmbeddingSignal, resolve_signal
from .signals.concept_scorer import ConceptSignal

# Every route registered below uses RouteErrorHandler as its route class.
router = APIRouter(route_class=RouteErrorHandler)

# Display order for embeddings listed by `get_embeddings`: earlier names sort first,
# and embeddings whose name is not in this list sort after all listed ones.
EMBEDDING_SORT_PRIORITIES = ['gte-small', 'gte-base', 'openai', 'sbert']


class SignalInfo(BaseModel):
  """Information about a signal."""
  # The signal's registered name (taken from the signal class's `name`).
  name: str
  # The kind of input the signal accepts.
  input_type: SignalInputType
  # The signal class's pydantic JSON schema (the routes fill this from `schema()`).
  json_schema: dict[str, Any]


@router.get('/', response_model_exclude_none=True)
def get_signals() -> list[SignalInfo]:
  """Return info for every registered signal that is not a text embedding.

  Text embeddings are excluded here; the `/embeddings` route lists them.
  """
  infos: list[SignalInfo] = []
  for signal_cls in SIGNAL_REGISTRY.values():
    if issubclass(signal_cls, TextEmbeddingSignal):
      # Embeddings have their own listing endpoint.
      continue
    infos.append(
      SignalInfo(
        name=signal_cls.name,
        input_type=signal_cls.input_type,
        json_schema=signal_cls.schema(),
      ))
  return infos


@router.get('/embeddings', response_model_exclude_none=True)
def get_embeddings() -> list[SignalInfo]:
  """Return info for every registered text embedding, most preferred first.

  Embeddings named in EMBEDDING_SORT_PRIORITIES come first, in that list's order.
  All others follow in registry order, because Python's sort is stable.
  """

  def _rank(info: SignalInfo) -> float:
    # Unlisted embeddings get +inf so they sink below every prioritized one.
    try:
      return EMBEDDING_SORT_PRIORITIES.index(info.name)
    except ValueError:
      return math.inf

  infos = [
    SignalInfo(name=embed_cls.name, input_type=embed_cls.input_type, json_schema=embed_cls.schema())
    for embed_cls in SIGNAL_REGISTRY.values()
    if issubclass(embed_cls, TextEmbeddingSignal)
  ]
  infos.sort(key=_rank)
  return infos


class SignalComputeOptions(BaseModel):
  """The request for the standalone compute signal endpoint."""
  # Resolved to a concrete Signal subclass instance by `parse_signal` below.
  signal: Signal
  # The inputs to compute.
  inputs: list[str]

  @validator('signal', pre=True)
  def parse_signal(cls, signal: dict) -> Signal:
    """Parse a signal to its specific subclass instance.

    Runs before pydantic's own validation of the field (`pre=True`) and delegates the
    conversion to `resolve_signal`. The incoming value is presumably the raw request
    JSON for the signal; confirm the exact shape against `resolve_signal`.
    """
    return resolve_signal(signal)


class SignalComputeResponse(BaseModel):
  """The response for the standalone compute signal endpoint."""
  # The computed results (entries may be None). Presumably one entry per request
  # input, in order; confirm against Signal.compute / server_compute_concept.
  items: list[Optional[Any]]


@router.post('/compute', response_model_exclude_none=True)
def compute(
    options: SignalComputeOptions,
    user: Annotated[Optional[UserInfo], Depends(get_session_user)]) -> SignalComputeResponse:
  """Run the requested signal over the request's inputs and return the results.

  Concept signals are delegated to `server_compute_concept` together with the session
  user (which may be None). Any other signal is set up and computed directly in this
  handler.
  """
  requested = options.signal
  texts = options.inputs
  if not isinstance(requested, ConceptSignal):
    requested.setup()
    return SignalComputeResponse(items=list(requested.compute(texts)))
  return SignalComputeResponse(items=server_compute_concept(requested, texts, user))


class SignalSchemaOptions(BaseModel):
  """The request for the signal schema endpoint."""
  # Resolved to a concrete Signal subclass instance by `parse_signal` below.
  signal: Signal

  @validator('signal', pre=True)
  def parse_signal(cls, signal: dict) -> Signal:
    """Parse a signal to its specific subclass instance.

    Runs before pydantic's own validation of the field (`pre=True`) and delegates the
    conversion to `resolve_signal`. The incoming value is presumably the raw request
    JSON for the signal; confirm the exact shape against `resolve_signal`.
    """
    return resolve_signal(signal)


class SignalSchemaResponse(BaseModel):
  """The response for the signal schema endpoint."""
  # The value returned by the requested signal's `fields()` method.
  fields: Field


@router.post('/schema', response_model_exclude_none=True)
def schema(options: SignalSchemaOptions) -> SignalSchemaResponse:
  """Return the `Field` schema reported by the requested signal's `fields()`."""
  return SignalSchemaResponse(fields=options.signal.fields())