import asyncio
import base64
from typing import Optional, Union

import numpy as np
import tiktoken
from fastapi import APIRouter, Depends
from openai import AsyncOpenAI
from openai.types.create_embedding_response import Usage
from sentence_transformers import SentenceTransformer

from api.config import SETTINGS
from api.models import EMBEDDED_MODEL
from api.utils.protocol import EmbeddingCreateParams, Embedding, CreateEmbeddingResponse
from api.utils.request import check_api_key


embedding_router = APIRouter()
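

# Dependency that yields the preloaded embedding backend: either a local
# SentenceTransformer model or an AsyncOpenAI client pointed at a remote
# embedding server (see the Union type on the handler below).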
def get_embedding_engine():
    yield EMBEDDED_MODEL


@embedding_router.post("/embeddings", dependencies=[Depends(check_api_key)])
@embedding_router.post("/{model_name}/embeddings", dependencies=[Depends(check_api_key)])
async def create_embeddings(
    request: EmbeddingCreateParams,
    model_name: Optional[str] = None,
    client: Union[SentenceTransformer, AsyncOpenAI] = Depends(get_embedding_engine),
):
    """Create embeddings for the given input text(s)."""
    if request.model is None:
        request.model = model_name

    # Normalize the input to a list of strings: wrap a bare string, and decode
    # token ids (a list of ints, or a list of lists of ints) back to text with
    # the model's tiktoken encoding.
    if isinstance(request.input, str):
        request.input = [request.input]
    elif isinstance(request.input, list):
        if isinstance(request.input[0], int):
            encoding = tiktoken.model.encoding_for_model(request.model)
            request.input = [encoding.decode(request.input)]
        elif isinstance(request.input[0], list):
            encoding = tiktoken.model.encoding_for_model(request.model)
            request.input = [encoding.decode(text) for text in request.input]
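    # Round-trip reference (assuming a model name tiktoken recognizes, e.g.
    # "text-embedding-ada-002"):
    #   enc = tiktoken.encoding_for_model("text-embedding-ada-002")
    #   enc.decode([15339, 1917])  # -> "hello world"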

    data, total_tokens = [], 0

    # Support for TEI: https://github.com/huggingface/text-embeddings-inference
    if isinstance(client, AsyncOpenAI):
        global_batch_size = SETTINGS.max_concurrent_requests * SETTINGS.max_client_batch_size
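        # Fan-out sketch (values illustrative, not from the config): with
        # max_concurrent_requests=4 and max_client_batch_size=32, each outer
        # step takes up to 128 texts and issues 4 concurrent requests of up to
        # 32 texts each.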
        for i in range(0, len(request.input), global_batch_size):
            tasks = []
            texts = request.input[i: i + global_batch_size]
            for j in range(0, len(texts), SETTINGS.max_client_batch_size):
                tasks.append(
                    client.embeddings.create(
                        # Crude length guard: truncate each text to 510
                        # characters to stay under the backend's sequence limit.
                        input=[text[:510] for text in texts[j: j + SETTINGS.max_client_batch_size]],
                        model=request.model,
                    )
                )
            res = await asyncio.gather(*tasks)

            vecs = np.asarray([e.embedding for r in res for e in r.data])
            bs, dim = vecs.shape
            if SETTINGS.embedding_size > dim:
                # Right-pad with zeros so every backend reports the same
                # dimension (keep the dtype so base64 payloads stay consistent).
                zeros = np.zeros((bs, SETTINGS.embedding_size - dim), dtype=vecs.dtype)
                vecs = np.c_[vecs, zeros]
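            # Zero-padding changes neither vector norms nor dot products, so
            # cosine similarities computed on the padded vectors are unchanged.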
            if request.encoding_format == "base64":
                vecs = [base64.b64encode(v.tobytes()).decode("utf-8") for v in vecs]
            else:
                vecs = vecs.tolist()
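            # A client can recover a base64 payload with, e.g.,
            # np.frombuffer(base64.b64decode(s), dtype=...); the dtype must
            # match what was serialized above.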

            data.extend(
                Embedding(
                    index=i + j,  # global position within request.input
                    object="embedding",
                    embedding=embed,
                )
                for j, embed in enumerate(vecs)
            )
            total_tokens += sum(r.usage.total_tokens for r in res)
    else:
        # Local SentenceTransformer path: encode in batches of 1024 texts.
        batches = [request.input[i: i + 1024] for i in range(0, len(request.input), 1024)]
        for num_batch, batch in enumerate(batches):
            # Approximate the token count by the character count of the batch.
            token_num = sum(len(text) for text in batch)
            vecs = client.encode(batch, normalize_embeddings=True)
            bs, dim = vecs.shape
            if SETTINGS.embedding_size > dim:
                zeros = np.zeros((bs, SETTINGS.embedding_size - dim), dtype=vecs.dtype)
                vecs = np.c_[vecs, zeros]

            if request.encoding_format == "base64":
                vecs = [base64.b64encode(v.tobytes()).decode("utf-8") for v in vecs]
            else:
                vecs = vecs.tolist()

            data.extend(
                Embedding(
                    index=num_batch * 1024 + i,  # global position within request.input
                    object="embedding",
                    embedding=embedding,
                )
                for i, embedding in enumerate(vecs)
            )
            total_tokens += token_num

    return CreateEmbeddingResponse(
        data=data,
        model=request.model,
        object="list",
        usage=Usage(prompt_tokens=total_tokens, total_tokens=total_tokens),
    )
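

# Example request (a sketch: host, port, and model name are placeholders, the
# router is assumed to be mounted at the application root, and check_api_key
# is assumed to expect a Bearer token):
#
#   curl http://localhost:8000/embeddings \
#       -H "Authorization: Bearer $API_KEY" \
#       -H "Content-Type: application/json" \
#       -d '{"model": "bge-large-zh-v1.5", "input": ["hello world"]}'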