import asyncio
import base64
from typing import Optional, Union

import numpy as np
import tiktoken
from fastapi import APIRouter, Depends
from openai import AsyncOpenAI
from openai.types.create_embedding_response import Usage
from sentence_transformers import SentenceTransformer

from api.config import SETTINGS
from api.models import EMBEDDED_MODEL
from api.utils.protocol import EmbeddingCreateParams, Embedding, CreateEmbeddingResponse
from api.utils.request import check_api_key

embedding_router = APIRouter()
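

# Dependency that yields the configured embedding backend: either a local
# SentenceTransformer model or an AsyncOpenAI client pointing at a TEI server.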
def get_embedding_engine():
    yield EMBEDDED_MODEL


@embedding_router.post("/embeddings", dependencies=[Depends(check_api_key)])
@embedding_router.post("/engines/{model_name}/embeddings", dependencies=[Depends(check_api_key)])
async def create_embeddings(
request: EmbeddingCreateParams,
    model_name: Optional[str] = None,
client: Union[SentenceTransformer, AsyncOpenAI] = Depends(get_embedding_engine),
):
"""Creates embeddings for the text"""
if request.model is None:
request.model = model_name
request.input = request.input
if isinstance(request.input, str):
request.input = [request.input]
elif isinstance(request.input, list):
if isinstance(request.input[0], int):
decoding = tiktoken.model.encoding_for_model(request.model)
request.input = [decoding.decode(request.input)]
elif isinstance(request.input[0], list):
decoding = tiktoken.model.encoding_for_model(request.model)
request.input = [decoding.decode(text) for text in request.input]
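
    # Accumulated per-item embeddings and a running token count across all batches.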
data, total_tokens = [], 0
    # Support for TEI (https://github.com/huggingface/text-embeddings-inference):
    # an AsyncOpenAI client proxies requests to the remote TEI server, splitting each
    # global batch into at most `max_concurrent_requests` concurrent sub-requests.
    if isinstance(client, AsyncOpenAI):
        global_batch_size = SETTINGS.max_concurrent_requests * SETTINGS.max_client_batch_size
for i in range(0, len(request.input), global_batch_size):
tasks = []
texts = request.input[i: i + global_batch_size]
for j in range(0, len(texts), SETTINGS.max_client_batch_size):
tasks.append(
client.embeddings.create(
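                        # Each text is truncated to 510 characters, presumably to keep
                        # inputs within the backend model's maximum sequence length.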
input=[text[:510] for text in texts[j: j + SETTINGS.max_client_batch_size]],
model=request.model,
)
)
res = await asyncio.gather(*tasks)
vecs = np.asarray([e.embedding for r in res for e in r.data])
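            # If the configured embedding size exceeds the model's output dimension,
            # right-pad each vector with zeros up to SETTINGS.embedding_size.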
bs, dim = vecs.shape
if SETTINGS.embedding_size > dim:
zeros = np.zeros((bs, SETTINGS.embedding_size - dim))
vecs = np.c_[vecs, zeros]
            if request.encoding_format == "base64":
                # Encode as float32 bytes, the layout OpenAI-compatible clients expect
                # when decoding base64 embeddings.
                vecs = [base64.b64encode(v.astype(np.float32).tobytes()).decode("utf-8") for v in vecs]
else:
vecs = vecs.tolist()
data.extend(
Embedding(
                    index=i + j,  # `i` is already the absolute offset of this global batch
object="embedding",
embedding=embed
)
for j, embed in enumerate(vecs)
)
total_tokens += sum(r.usage.total_tokens for r in res)
    else:
        # Local backend: encode with the SentenceTransformer model in batches of 1024.
        batches = [request.input[i: i + 1024] for i in range(0, len(request.input), 1024)]
        for num_batch, batch in enumerate(batches):
            # Approximate token usage by character count; the local backend does not
            # report token statistics.
            token_num = sum(len(text) for text in batch)
            vecs = client.encode(batch, normalize_embeddings=True)
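            # Zero-pad to SETTINGS.embedding_size, as in the TEI branch above.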
bs, dim = vecs.shape
if SETTINGS.embedding_size > dim:
zeros = np.zeros((bs, SETTINGS.embedding_size - dim))
vecs = np.c_[vecs, zeros]
            if request.encoding_format == "base64":
                # float32 bytes again, for consistency with OpenAI-compatible clients.
                vecs = [base64.b64encode(v.astype(np.float32).tobytes()).decode("utf-8") for v in vecs]
else:
vecs = vecs.tolist()
data.extend(
Embedding(
index=num_batch * 1024 + i,
object="embedding",
embedding=embedding,
)
for i, embedding in enumerate(vecs)
)
total_tokens += token_num
return CreateEmbeddingResponse(
data=data,
model=request.model,
object="list",
usage=Usage(prompt_tokens=total_tokens, total_tokens=total_tokens),
)
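

# A minimal usage sketch (hypothetical base URL, API key, and model name), assuming
# the router is mounted under an OpenAI-compatible prefix such as /v1:
#
#   from openai import OpenAI
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="sk-...")
#   resp = client.embeddings.create(model="bge-large-zh-v1.5", input=["hello", "world"])
#   print(len(resp.data), len(resp.data[0].embedding))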