import asyncio
import base64
from typing import Optional, Union

import numpy as np
import tiktoken
from fastapi import APIRouter, Depends
from openai import AsyncOpenAI
from openai.types.create_embedding_response import Usage
from sentence_transformers import SentenceTransformer

from api.config import SETTINGS
from api.models import EMBEDDED_MODEL
from api.utils.protocol import EmbeddingCreateParams, Embedding, CreateEmbeddingResponse
from api.utils.request import check_api_key

embedding_router = APIRouter()


def get_embedding_engine():
    # FastAPI dependency that yields the globally loaded embedding backend,
    # either a local SentenceTransformer or an OpenAI-compatible async client.
    yield EMBEDDED_MODEL


@embedding_router.post("/embeddings", dependencies=[Depends(check_api_key)])
@embedding_router.post("/engines/{model_name}/embeddings", dependencies=[Depends(check_api_key)])
async def create_embeddings(
    request: EmbeddingCreateParams,
    model_name: Optional[str] = None,
    client: Union[SentenceTransformer, AsyncOpenAI] = Depends(get_embedding_engine),
):
    """Creates embeddings for the given input text."""
    if request.model is None:
        request.model = model_name

    # Normalize the input to a list of strings: accept a single string,
    # a list of strings, a list of token ids, or a list of token-id lists.
    if isinstance(request.input, str):
        request.input = [request.input]
    elif isinstance(request.input, list):
        if isinstance(request.input[0], int):
            decoding = tiktoken.model.encoding_for_model(request.model)
            request.input = [decoding.decode(request.input)]
        elif isinstance(request.input[0], list):
            decoding = tiktoken.model.encoding_for_model(request.model)
            request.input = [decoding.decode(text) for text in request.input]

    data, total_tokens = [], 0

    # Remote backend: proxy to a text-embeddings-inference (tei) server through
    # the OpenAI-compatible client, see https://github.com/huggingface/text-embeddings-inference
    if isinstance(client, AsyncOpenAI):
        global_batch_size = SETTINGS.max_concurrent_requests * SETTINGS.max_client_batch_size
        for i in range(0, len(request.input), global_batch_size):
            tasks = []
            texts = request.input[i: i + global_batch_size]
            # Split the global batch into per-request chunks and send them concurrently.
            for j in range(0, len(texts), SETTINGS.max_client_batch_size):
                tasks.append(
                    client.embeddings.create(
                        # Truncate each text to stay within the tei sequence length limit.
                        input=[text[:510] for text in texts[j: j + SETTINGS.max_client_batch_size]],
                        model=request.model,
                    )
                )
            res = await asyncio.gather(*tasks)

            vecs = np.asarray([e.embedding for r in res for e in r.data])
            bs, dim = vecs.shape
            # Zero-pad vectors if the configured embedding size exceeds the model dimension.
            if SETTINGS.embedding_size > dim:
                zeros = np.zeros((bs, SETTINGS.embedding_size - dim))
                vecs = np.c_[vecs, zeros]

            # Optionally return the raw float buffer as base64, mirroring the OpenAI API.
            if request.encoding_format == "base64":
                vecs = [base64.b64encode(v.tobytes()).decode("utf-8") for v in vecs]
            else:
                vecs = vecs.tolist()

            data.extend(
                Embedding(
                    index=i + j,  # global position of this embedding within the full request
                    object="embedding",
                    embedding=embed,
                )
                for j, embed in enumerate(vecs)
            )
            total_tokens += sum(r.usage.total_tokens for r in res)
    else:
        # Local backend: encode with sentence-transformers in batches of 1024 texts.
        batches = [request.input[i: i + 1024] for i in range(0, len(request.input), 1024)]
        for num_batch, batch in enumerate(batches):
            # Approximate token usage by the character count of the batch.
            token_num = sum(len(i) for i in batch)
            vecs = client.encode(batch, normalize_embeddings=True)
            bs, dim = vecs.shape
            if SETTINGS.embedding_size > dim:
                zeros = np.zeros((bs, SETTINGS.embedding_size - dim))
                vecs = np.c_[vecs, zeros]

            if request.encoding_format == "base64":
                vecs = [base64.b64encode(v.tobytes()).decode("utf-8") for v in vecs]
            else:
                vecs = vecs.tolist()

            data.extend(
                Embedding(
                    index=num_batch * 1024 + i,
                    object="embedding",
                    embedding=embedding,
                )
                for i, embedding in enumerate(vecs)
            )
            total_tokens += token_num

    return CreateEmbeddingResponse(
        data=data,
        model=request.model,
        object="list",
        usage=Usage(prompt_tokens=total_tokens, total_tokens=total_tokens),
    )
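

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): calling this endpoint with the OpenAI
# client. The base URL, API key, and model name below are assumptions for the
# example and are not defined anywhere in this module; adjust them to match
# how the router is actually mounted and which embedding model is loaded.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from openai import OpenAI

    # Assumes the router is mounted under "/v1" and the server listens locally.
    example_client = OpenAI(base_url="http://localhost:8000/v1", api_key="sk-xxx")
    response = example_client.embeddings.create(
        input=["hello world", "how are you"],
        model="text-embedding-model",  # hypothetical model name
    )
    print(len(response.data), len(response.data[0].embedding))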