Spaces:
Runtime error
Runtime error
"""Router for the concept database.""" | |
from typing import Optional | |
import openai | |
from fastapi import APIRouter, HTTPException | |
from openai_function_call import OpenAISchema | |
from pydantic import BaseModel, Field | |
from .auth import get_user_access | |
from .concepts.concept import ( | |
DRAFT_MAIN, | |
Concept, | |
ConceptColumnInfo, | |
ConceptMetrics, | |
DraftId, | |
draft_examples, | |
) | |
from .concepts.db_concept import DISK_CONCEPT_DB, DISK_CONCEPT_MODEL_DB, ConceptInfo, ConceptUpdate | |
from .config import CONFIG | |
from .router_utils import RouteErrorHandler | |
from .schema import SignalInputType | |
# Routes attached to this router are built with RouteErrorHandler as their route class.
router = APIRouter(
    route_class=RouteErrorHandler,
)
def get_concepts() -> list[ConceptInfo]:
    """Return the info entries for all concepts listed by the concept database."""
    concept_infos = DISK_CONCEPT_DB.list()
    return concept_infos
def get_concept(namespace: str,
                concept_name: str,
                draft: Optional[DraftId] = DRAFT_MAIN) -> Concept:
    """Look up a single concept, restricted to the examples of one draft.

    Args:
      namespace: Namespace the concept is stored under.
      concept_name: Name of the concept inside that namespace.
      draft: Draft whose examples are returned. A falsy value falls back to `DRAFT_MAIN`.

    Returns:
      The concept, with its `data` replaced by the result of `draft_examples` for the draft.

    Raises:
      HTTPException: With status 404 when the database returns no concept.
    """
    found = DISK_CONCEPT_DB.get(namespace, concept_name)
    if found:
        selected_draft = draft if draft else DRAFT_MAIN
        # Expose only the examples that belong to the selected draft.
        found.data = draft_examples(found, selected_draft)
        return found

    raise HTTPException(
        status_code=404, detail=f'Concept "{namespace}/{concept_name}" was not found')
def get_concept_column_infos(namespace: str, concept_name: str) -> list[ConceptColumnInfo]:
    """List the dataset columns that this concept has been applied to.

    Args:
      namespace: Namespace the concept is stored under.
      concept_name: Name of the concept inside that namespace.

    Returns:
      The column infos reported by the concept model database for this concept.

    Raises:
      HTTPException: With status 404 when the database returns no concept.
    """
    if DISK_CONCEPT_DB.get(namespace, concept_name):
        return DISK_CONCEPT_MODEL_DB.get_column_infos(namespace, concept_name)

    not_found = f'Concept "{namespace}/{concept_name}" was not found'
    raise HTTPException(status_code=404, detail=not_found)
class CreateConceptOptions(BaseModel):
    """Options for creating a concept."""

    # Namespace under which the new concept is created.
    namespace: str

    # Name of the new concept.
    name: str

    # Input type (modality) of the concept.
    type: SignalInputType

    # Optional description; `create_concept` passes it through to the database unchanged.
    description: Optional[str] = None
def create_concept(options: CreateConceptOptions) -> Concept:
    """Create a new concept in the database.

    Args:
      options: Namespace, name, input type and optional description of the concept to create.

    Returns:
      The concept returned by `DISK_CONCEPT_DB.create`.
    """
    return DISK_CONCEPT_DB.create(options.namespace, options.name, options.type,
                                  options.description)
def edit_concept(namespace: str, concept_name: str, change: ConceptUpdate) -> Concept:
    """Apply an update to an existing concept and return the database's result.

    Args:
      namespace: Namespace the concept is stored under.
      concept_name: Name of the concept inside that namespace.
      change: The update to apply, forwarded as-is to `DISK_CONCEPT_DB.edit`.
    """
    updated = DISK_CONCEPT_DB.edit(namespace, concept_name, change)
    return updated
def delete_concept(namespace: str, concept_name: str) -> None:
    """Remove a concept, together with the concept models built from it.

    Args:
      namespace: Namespace the concept is stored under.
      concept_name: Name of the concept inside that namespace.

    Raises:
      HTTPException: With status 401 when the current user's access does not include
        `concept.delete_any_concept`.
    """
    may_delete = get_user_access().concept.delete_any_concept
    if not may_delete:
        # NOTE(review): 401 is kept for compatibility; 403 may be the more accurate status for a
        # permission failure -- confirm with API consumers before changing.
        raise HTTPException(401, 'User does not have access to delete this concept.')

    DISK_CONCEPT_DB.remove(namespace, concept_name)
    # Concept models in every dataset that uses this concept are removed as well.
    DISK_CONCEPT_MODEL_DB.remove_all(namespace, concept_name)
class MergeConceptDraftOptions(BaseModel):
    """Merge a draft into main."""

    # Identifier of the draft to merge; `merge_concept_draft` forwards it to the database.
    draft: DraftId
def merge_concept_draft(namespace: str, concept_name: str,
                        options: MergeConceptDraftOptions) -> Concept:
    """Merge one draft of a concept into main and return the resulting concept.

    Args:
      namespace: Namespace the concept is stored under.
      concept_name: Name of the concept inside that namespace.
      options: Carries the id of the draft to merge.
    """
    draft_id = options.draft
    merged = DISK_CONCEPT_DB.merge_draft(namespace, concept_name, draft_id)
    return merged
class ScoreExample(BaseModel):
    """Example to score along a specific concept."""

    # Both fields default to None explicitly, like the other optional fields in this module.
    # Without a default, `Optional[...]` is only implicitly optional in pydantic v1 and becomes a
    # required field in pydantic v2, which would reject text-only examples.

    # Text of the example; `score` substitutes an empty string when it is missing.
    text: Optional[str] = None

    # Raw image bytes; not consumed by `score` yet (images are still a TODO there).
    img: Optional[bytes] = None
class ScoreBody(BaseModel):
    """Request body for the score endpoint."""

    # Examples to be scored.
    examples: list[ScoreExample]

    # Draft passed to the model when scoring; defaults to the main draft.
    draft: str = DRAFT_MAIN
class ScoreResponse(BaseModel):
    """Response body for the score endpoint."""

    # Output of the concept model's `score` call for the submitted examples.
    scores: list[dict]

    # Value returned by the model database's `sync` call made while handling the request.
    model_synced: bool
class ConceptModelInfo(BaseModel):
    """Information about a concept model."""

    # Namespace and name of the concept the model belongs to.
    namespace: str
    concept_name: str

    # Name of the embedding the model is associated with.
    embedding_name: str

    # Version number copied from the model.
    version: int

    # Column info copied from the model, when it has one.
    column_info: Optional[ConceptColumnInfo] = None

    # Metrics obtained from the model's `get_metrics`, when available.
    metrics: Optional[ConceptMetrics] = None
def get_concept_models(namespace: str, concept_name: str) -> list[ConceptModelInfo]:
    """List info for every concept model stored for a concept.

    Args:
      namespace: Namespace the concept is stored under.
      concept_name: Name of the concept inside that namespace.

    Returns:
      One `ConceptModelInfo` per model returned by the concept model database.

    Raises:
      HTTPException: With status 404 when the database returns no concept.
    """
    concept = DISK_CONCEPT_DB.get(namespace, concept_name)
    if not concept:
        raise HTTPException(
            status_code=404, detail=f'Concept "{namespace}/{concept_name}" was not found')

    concept_models = DISK_CONCEPT_MODEL_DB.get_models(namespace, concept_name)

    # Sync every model first, then collect the infos in a second pass.
    for concept_model in concept_models:
        DISK_CONCEPT_MODEL_DB.sync(concept_model)

    infos = []
    for concept_model in concept_models:
        info = ConceptModelInfo(
            namespace=concept_model.namespace,
            concept_name=concept_model.concept_name,
            embedding_name=concept_model.embedding_name,
            version=concept_model.version,
            column_info=concept_model.column_info,
            metrics=concept_model.get_metrics(concept))
        infos.append(info)
    return infos
def get_concept_model(namespace: str, concept_name: str, embedding_name: str) -> ConceptModelInfo:
    """Get info about the concept model for one embedding, creating the model if needed.

    Args:
      namespace: Namespace the concept is stored under.
      concept_name: Name of the concept inside that namespace.
      embedding_name: Name of the embedding whose model is requested.

    Returns:
      A `ConceptModelInfo` built from the model after it has been synced.

    Raises:
      HTTPException: With status 404 when the database returns no concept.
    """
    concept = DISK_CONCEPT_DB.get(namespace, concept_name)
    if not concept:
        raise HTTPException(
            status_code=404, detail=f'Concept "{namespace}/{concept_name}" was not found')

    model = DISK_CONCEPT_MODEL_DB.get(namespace, concept_name, embedding_name)
    # Explicit None check, matching the same lookup in `score`.
    if model is None:
        model = DISK_CONCEPT_MODEL_DB.create(namespace, concept_name, embedding_name)

    # The sync result is not part of the response, so it is deliberately not stored.
    DISK_CONCEPT_MODEL_DB.sync(model)

    return ConceptModelInfo(
        namespace=model.namespace,
        concept_name=model.concept_name,
        embedding_name=model.embedding_name,
        version=model.version,
        column_info=model.column_info,
        metrics=model.get_metrics(concept))
class MetricsBody(BaseModel):
    """Request body for the compute_metrics endpoint."""

    # Optional column info supplied with the request; omitted means None.
    column_info: Optional[ConceptColumnInfo] = None
def score(namespace: str, concept_name: str, embedding_name: str, body: ScoreBody) -> ScoreResponse:
    """Score the submitted examples with the concept model for one embedding.

    Args:
      namespace: Namespace the concept is stored under.
      concept_name: Name of the concept inside that namespace.
      embedding_name: Name of the embedding whose model does the scoring.
      body: The examples to score and the draft to score them against.

    Returns:
      A `ScoreResponse` with the model's scores and the result of the sync call.

    Raises:
      HTTPException: With status 404 when the database returns no concept.
    """
    if not DISK_CONCEPT_DB.get(namespace, concept_name):
        raise HTTPException(
            status_code=404, detail=f'Concept "{namespace}/{concept_name}" was not found')

    concept_model = DISK_CONCEPT_MODEL_DB.get(namespace, concept_name, embedding_name)
    if concept_model is None:
        concept_model = DISK_CONCEPT_MODEL_DB.create(namespace, concept_name, embedding_name)
    synced = DISK_CONCEPT_MODEL_DB.sync(concept_model)

    # TODO(smilkov): Support images.
    # Missing or empty text is scored as the empty string.
    texts = []
    for example in body.examples:
        texts.append(example.text if example.text else '')

    scores = concept_model.score(body.draft, texts)
    return ScoreResponse(scores=scores, model_synced=synced)
class Examples(OpenAISchema):
    """Generated text examples."""

    # The class docstring and the field description are part of the function schema that
    # `generate_examples` sends to the model, so treat their wording as runtime text.
    examples: list[str] = Field(
        ...,
        description='List of generated examples',
    )
def generate_examples(description: str) -> list[str]:
    """Generate positive examples for a given concept using an LLM model.

    Args:
      description: Free-text description of the concept; it is interpolated into the user prompt.

    Returns:
      The `examples` list parsed from the model's `Examples` function call.
    """
    openai.api_key = CONFIG['OPENAI_API_KEY']
    completion = openai.ChatCompletion.create(
        model='gpt-3.5-turbo-0613',
        functions=[Examples.openai_schema],
        # Force the reply to go through the `Examples` function. Left on automatic selection, the
        # model may answer in plain text, which `Examples.from_response` cannot parse.
        function_call={'name': Examples.openai_schema['name']},
        messages=[
            {
                'role': 'system',
                'content': 'You must call the `Examples` function with the generated examples',
            },
            {
                'role': 'user',
                'content': f'Write 5 diverse, unnumbered, and concise examples of "{description}"',
            },
        ],
    )
    result = Examples.from_response(completion)
    return result.examples