Spaces:
Runtime error
Runtime error
File size: 5,925 Bytes
e4f9cbe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
"""Router for the concept database."""
from typing import Optional
import openai
from fastapi import APIRouter, HTTPException
from openai_function_call import OpenAISchema
from pydantic import BaseModel, Field
from .concepts.concept import DRAFT_MAIN, Concept, ConceptModel, DraftId, draft_examples
from .concepts.db_concept import DISK_CONCEPT_DB, DISK_CONCEPT_MODEL_DB, ConceptInfo, ConceptUpdate
from .config import CONFIG
from .router_utils import RouteErrorHandler
from .schema import SignalInputType
# Router that every concept endpoint in this module is registered on. All routes use
# RouteErrorHandler as their route class.
router: APIRouter = APIRouter(route_class=RouteErrorHandler)
@router.get('/', response_model_exclude_none=True)
def get_concepts() -> list[ConceptInfo]:
    """Return the info of every concept stored in the on-disk concept database."""
    concept_infos = DISK_CONCEPT_DB.list()
    return concept_infos
@router.get('/{namespace}/{concept_name}', response_model_exclude_none=True)
def get_concept(namespace: str,
                concept_name: str,
                draft: Optional[DraftId] = DRAFT_MAIN) -> Concept:
    """Fetch one concept, keeping only the examples visible in the requested draft.

    Args:
      namespace: Namespace the concept lives in.
      concept_name: Name of the concept inside the namespace.
      draft: Draft whose examples are returned. A falsy value (e.g. an explicit null)
        falls back to DRAFT_MAIN.

    Raises:
      HTTPException: 404 when the concept does not exist.
    """
    selected_draft = draft or DRAFT_MAIN

    concept = DISK_CONCEPT_DB.get(namespace, concept_name)
    if not concept:
        raise HTTPException(
            status_code=404, detail=f'Concept "{namespace}/{concept_name}" was not found')

    # Replace the concept's data with just the selected draft's examples before returning.
    concept.data = draft_examples(concept, selected_draft)
    return concept
class CreateConceptOptions(BaseModel):
    """Options for creating a concept."""
    # Namespace of the concept (the first half of the "namespace/name" concept id used by
    # the endpoints in this router).
    namespace: str
    # Name of the concept within its namespace.
    name: str
    # Input type (modality) of the concept; forwarded as-is to DISK_CONCEPT_DB.create.
    type: SignalInputType
@router.post('/create', response_model_exclude_none=True)
def create_concept(options: CreateConceptOptions) -> Concept:
    """Create a concept in the database.

    Args:
      options: Namespace, name and input type (modality) of the concept to create.

    Returns:
      The concept returned by the concept database after creation.
    """
    return DISK_CONCEPT_DB.create(options.namespace, options.name, options.type)
@router.post('/{namespace}/{concept_name}', response_model_exclude_none=True)
def edit_concept(namespace: str, concept_name: str, change: ConceptUpdate) -> Concept:
    """Apply `change` to the concept `namespace/concept_name`.

    Returns:
      The concept returned by the concept database after the edit.
    """
    edited_concept = DISK_CONCEPT_DB.edit(namespace, concept_name, change)
    return edited_concept
@router.delete('/{namespace}/{concept_name}')
def delete_concept(namespace: str, concept_name: str) -> None:
    """Deletes the concept from the database.

    Also removes the concept models stored for this concept via
    DISK_CONCEPT_MODEL_DB.remove_all.

    Args:
      namespace: Namespace the concept lives in.
      concept_name: Name of the concept inside the namespace.
    """
    DISK_CONCEPT_DB.remove(namespace, concept_name)
    # Delete concept models from all datasets that are using this concept.
    DISK_CONCEPT_MODEL_DB.remove_all(namespace, concept_name)
class MergeConceptDraftOptions(BaseModel):
    """Merge a draft into main."""
    # Id of the draft whose changes are merged into the main draft.
    draft: DraftId
@router.post('/{namespace}/{concept_name}/merge_draft', response_model_exclude_none=True)
def merge_concept_draft(namespace: str, concept_name: str,
                        options: MergeConceptDraftOptions) -> Concept:
    """Fold the draft named in `options` into the concept's main draft.

    Returns:
      The concept returned by the concept database after the merge.
    """
    draft_to_merge = options.draft
    return DISK_CONCEPT_DB.merge_draft(namespace, concept_name, draft_to_merge)
class ScoreExample(BaseModel):
    """Example to score along a specific concept."""
    # Both fields default to None explicitly. Under pydantic v1 a bare `Optional[...]`
    # annotation already implied a None default, but under pydantic v2 it makes the field
    # required. The explicit default keeps the field optional on both versions.

    # Text of the example. The score endpoint scores a missing text as ''.
    text: Optional[str] = None
    # Image bytes of the example. Not used by the score endpoint yet (see its TODO).
    img: Optional[bytes] = None
class ScoreBody(BaseModel):
    """Request body for the score endpoint."""
    # Examples to score against the concept.
    examples: list[ScoreExample]
    # Draft of the concept to score against. Typed as DraftId, like the other draft
    # parameters in this router, so the request schema is consistent. Defaults to the
    # main draft.
    draft: DraftId = DRAFT_MAIN
class ScoreResponse(BaseModel):
    """Response body for the score endpoint."""
    # Scores returned by the concept model for the request's examples (presumably one per
    # example, in request order; confirm against ConceptModel.score).
    scores: list[float]
    # Value returned by DISK_CONCEPT_MODEL_DB.sync(...), which the score endpoint calls
    # before scoring.
    model_synced: bool
class ConceptModelResponse(BaseModel):
    """Response body for the get_concept_model endpoint."""
    # Concept model for the requested embedding. The endpoint creates it if it did not exist.
    model: ConceptModel
    # Result of DISK_CONCEPT_MODEL_DB.sync(...) when the caller asked to sync, otherwise
    # the result of DISK_CONCEPT_MODEL_DB.in_sync(...).
    model_synced: bool
@router.get('/{namespace}/{concept_name}/{embedding_name}')
def get_concept_model(namespace: str,
                      concept_name: str,
                      embedding_name: str,
                      sync_model: bool = False) -> ConceptModelResponse:
    """Return the concept model for `embedding_name`, creating it when it is missing.

    Args:
      namespace: Namespace the concept lives in.
      concept_name: Name of the concept inside the namespace.
      embedding_name: Embedding the concept model is built on.
      sync_model: When true, sync the model and report the result. Otherwise only
        report whether the model is already in sync.

    Raises:
      HTTPException: 404 when the concept does not exist.
    """
    if not DISK_CONCEPT_DB.get(namespace, concept_name):
        raise HTTPException(
            status_code=404, detail=f'Concept "{namespace}/{concept_name}" was not found')

    # Fall back to creating the model whenever the lookup yields nothing.
    model = (DISK_CONCEPT_MODEL_DB.get(namespace, concept_name, embedding_name) or
             DISK_CONCEPT_MODEL_DB.create(namespace, concept_name, embedding_name))

    synced = (DISK_CONCEPT_MODEL_DB.sync(model, column_info=None)
              if sync_model else DISK_CONCEPT_MODEL_DB.in_sync(model))
    return ConceptModelResponse(model=model, model_synced=synced)
@router.post('/{namespace}/{concept_name}/{embedding_name}/score',
             response_model_exclude_none=True)
def score(namespace: str, concept_name: str, embedding_name: str, body: ScoreBody) -> ScoreResponse:
    """Score the request's examples with the concept model for `embedding_name`.

    The model is created if it does not exist and is always synced before scoring.
    Only example text is scored. An example without text is scored as ''.

    Raises:
      HTTPException: 404 when the concept does not exist.
    """
    if not DISK_CONCEPT_DB.get(namespace, concept_name):
        raise HTTPException(
            status_code=404, detail=f'Concept "{namespace}/{concept_name}" was not found')

    existing_model = DISK_CONCEPT_MODEL_DB.get(namespace, concept_name, embedding_name)
    model = (existing_model if existing_model is not None else
             DISK_CONCEPT_MODEL_DB.create(namespace, concept_name, embedding_name))
    synced = DISK_CONCEPT_MODEL_DB.sync(model, column_info=None)

    # TODO(smilkov): Support images.
    example_texts = [ex.text or '' for ex in body.examples]
    example_scores = model.score(body.draft, example_texts)
    return ScoreResponse(scores=example_scores, model_synced=synced)
class Examples(OpenAISchema):
    """Generated text examples."""
    # generate_examples passes this class to OpenAI as a function schema
    # (`Examples.openai_schema`) and parses the reply with `Examples.from_response`.
    # NOTE(review): OpenAISchema presumably builds the function description from this
    # class's docstring and field descriptions, so editing them changes the LLM prompt.
    # Confirm before changing.
    examples: list[str] = Field(..., description='List of generated examples')
@router.get('/generate_examples')
def generate_examples(description: str) -> list[str]:
    """Generate positive examples for a given concept using an LLM model.

    Asks the OpenAI chat model to call the `Examples` function with five diverse
    sentences that demonstrate `description`, then returns those sentences.

    Args:
      description: Free-text description of the concept. It is interpolated verbatim
        into the user prompt, so treat it as untrusted LLM input.

    Returns:
      The generated example sentences.

    Raises:
      HTTPException: 500 when `OPENAI_API_KEY` is missing or empty in the server config.
    """
    try:
        api_key = CONFIG['OPENAI_API_KEY']
    except KeyError:
        api_key = None
    if not api_key:
        # Fail with an actionable message instead of an opaque KeyError or OpenAI auth error.
        raise HTTPException(
            status_code=500,
            detail='`OPENAI_API_KEY` is not set in the server config; cannot generate examples.')

    # NOTE(review): this sets the key on the global `openai` module, as before. Other code
    # may rely on it having been set here.
    openai.api_key = api_key
    completion = openai.ChatCompletion.create(
        model='gpt-3.5-turbo-0613',
        functions=[Examples.openai_schema],
        messages=[
            {
                'role': 'system',
                'content': 'You must call the `Examples` function with the generated sentences.',
            },
            {
                'role': 'user',
                'content': f'Give me 5 diverse examples of sentences that demonstrate "{description}"',
            },
        ],
    )
    result = Examples.from_response(completion)
    return result.examples
|