Spaces:
Runtime error
Runtime error
File size: 3,199 Bytes
bfc0ec6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
"""Router for the signal registry."""
import math
from typing import Annotated, Any, Optional
from fastapi import APIRouter, Depends
from pydantic import BaseModel, validator
from .auth import UserInfo, get_session_user
from .router_utils import RouteErrorHandler, server_compute_concept
from .schema import Field, SignalInputType
from .signal import SIGNAL_REGISTRY, Signal, TextEmbeddingSignal, resolve_signal
from .signals.concept_scorer import ConceptSignal
# Router for every endpoint in this module; routes use RouteErrorHandler as their route class.
router = APIRouter(route_class=RouteErrorHandler)

# Display order for embedding signals returned by the `/embeddings` endpoint:
# earlier names sort first, and names not listed here sort after all of them.
EMBEDDING_SORT_PRIORITIES = [
    'gte-small',
    'gte-base',
    'openai',
    'sbert',
]
class SignalInfo(BaseModel):
    """Public description of one registered signal class."""

    # The signal's registry name.
    name: str
    # The kind of input the signal consumes.
    input_type: SignalInputType
    # The dict returned by the signal class's `schema()` method.
    json_schema: dict[str, Any]
@router.get('/', response_model_exclude_none=True)
def get_signals() -> list[SignalInfo]:
    """Return a SignalInfo for every registered signal that is not a text embedding."""
    infos: list[SignalInfo] = []
    for signal_cls in SIGNAL_REGISTRY.values():
        if issubclass(signal_cls, TextEmbeddingSignal):
            # Embedding signals are served by the `/embeddings` endpoint instead.
            continue
        infos.append(
            SignalInfo(
                name=signal_cls.name,
                input_type=signal_cls.input_type,
                json_schema=signal_cls.schema(),
            ))
    return infos
@router.get('/embeddings', response_model_exclude_none=True)
def get_embeddings() -> list[SignalInfo]:
    """Return a SignalInfo for every registered text-embedding signal.

    Embeddings named in EMBEDDING_SORT_PRIORITIES come first, in that list's
    order. All other embeddings follow in registry order, because `sorted` is
    stable and they share the same rank.
    """

    def _rank(info: SignalInfo) -> float:
        # Unlisted embeddings get infinity so they land after every prioritized one.
        if info.name in EMBEDDING_SORT_PRIORITIES:
            return EMBEDDING_SORT_PRIORITIES.index(info.name)
        return math.inf

    embeddings = (
        SignalInfo(
            name=signal_cls.name,
            input_type=signal_cls.input_type,
            json_schema=signal_cls.schema(),
        )
        for signal_cls in SIGNAL_REGISTRY.values()
        if issubclass(signal_cls, TextEmbeddingSignal))
    return sorted(embeddings, key=_rank)
class SignalComputeOptions(BaseModel):
    """Request body for the standalone `/compute` endpoint."""

    # The signal to run; `parse_signal` converts the incoming value before field validation.
    signal: Signal
    # The inputs to compute.
    inputs: list[str]

    @validator('signal', pre=True)
    def parse_signal(cls, signal: dict) -> Signal:
        """Convert the raw `signal` payload into its concrete Signal subclass instance."""
        resolved = resolve_signal(signal)
        return resolved
class SignalComputeResponse(BaseModel):
    """Response body for the standalone `/compute` endpoint."""

    # The computed results; individual entries may be None.
    items: list[Optional[Any]]
@router.post('/compute', response_model_exclude_none=True)
def compute(
        options: SignalComputeOptions,
        user: Annotated[Optional[UserInfo], Depends(get_session_user)]) -> SignalComputeResponse:
    """Compute the requested signal over `options.inputs`.

    Concept signals are delegated to `server_compute_concept`, which also receives
    the session user. Any other signal has `setup()` called and its `compute()`
    output collected into a list.
    """
    signal = options.signal
    if not isinstance(signal, ConceptSignal):
        signal.setup()
        items = list(signal.compute(options.inputs))
    else:
        items = server_compute_concept(signal, options.inputs, user)
    return SignalComputeResponse(items=items)
class SignalSchemaOptions(BaseModel):
    """Request body for the `/schema` endpoint."""

    # The signal to describe; `parse_signal` converts the incoming value before field validation.
    signal: Signal

    @validator('signal', pre=True)
    def parse_signal(cls, signal: dict) -> Signal:
        """Convert the raw `signal` payload into its concrete Signal subclass instance."""
        resolved = resolve_signal(signal)
        return resolved
class SignalSchemaResponse(BaseModel):
    """Response body for the `/schema` endpoint."""

    # The value returned by the signal's `fields()` method.
    fields: Field
@router.post('/schema', response_model_exclude_none=True)
def schema(options: SignalSchemaOptions) -> SignalSchemaResponse:
    """Return the `Field` schema reported by the requested signal's `fields()` method."""
    return SignalSchemaResponse(fields=options.signal.fields())
|