"""OpenAI-backed implementation of the EmbeddingProvider interface."""
from typing import List, Dict, Union, Optional | |
import tiktoken | |
from ..embedding_provider import EmbeddingProvider | |
import numpy as np | |
class OpenAIEmbedding(EmbeddingProvider):
    """Embedding provider backed by the OpenAI embeddings API.

    Inputs longer than ``max_tokens`` are silently truncated (using the
    model's tiktoken encoding) before being sent to the API.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "text-embedding-3-small",
        max_tokens: int = 8191
    ) -> None:
        """Initialize the OpenAI embedding provider.

        Args:
            api_key (str, optional): API key for OpenAI. When None, the
                ``openai`` client resolves the key from its environment.
            model (str, optional): Name of the embedding model.
                Defaults to "text-embedding-3-small".
                More info: https://platform.openai.com/docs/models#embeddings
            max_tokens (int, optional): Maximum number of tokens per
                input text; longer texts are truncated. Defaults to 8191.
        """
        # Imported lazily so this module can be imported even when the
        # optional ``openai`` dependency is not installed.
        from openai import OpenAI
        self.client = OpenAI(api_key=api_key)
        self.model = model
        self.max_tokens = max_tokens
        self.tokenizer = tiktoken.encoding_for_model(model)

    def _truncate_text(self, text: str) -> str:
        """Truncate *text* to at most ``self.max_tokens`` tokens.

        Args:
            text (str): Input text.

        Returns:
            str: Text decoded back from the first ``max_tokens`` tokens.
        """
        tokens = self.tokenizer.encode(text)
        truncated_tokens = tokens[:self.max_tokens]
        return self.tokenizer.decode(truncated_tokens)

    # Backward-compatible alias for the previous (misspelled) method name.
    _trancated_text = _truncate_text

    def embed_documents(
        self,
        documents: List[str],
        batch_size: int = 100
    ) -> np.ndarray:
        """Embed a list of documents in batches.

        Args:
            documents (List[str]): Documents to embed; each is truncated
                to ``max_tokens`` tokens first.
            batch_size (int, optional): Number of documents sent per API
                request. Defaults to 100.

        Returns:
            np.ndarray: Array of shape (len(documents), embedding_dim).
        """
        truncated_docs = [self._truncate_text(doc) for doc in documents]
        embeddings: List[List[float]] = []
        for start in range(0, len(truncated_docs), batch_size):
            response = self.client.embeddings.create(
                input=truncated_docs[start:start + batch_size],
                model=self.model
            )
            # response.data preserves input order, so a plain extend
            # keeps embeddings aligned with ``documents``.
            embeddings.extend(item.embedding for item in response.data)
        return np.array(embeddings)

    def embed_query(self, query: str) -> np.ndarray:
        """Embed a single query string.

        Args:
            query (str): Query text; truncated to ``max_tokens`` tokens.

        Returns:
            np.ndarray: 1-D embedding vector for the query.
        """
        truncated_query = self._truncate_text(query)
        response = self.client.embeddings.create(
            input=[truncated_query],
            model=self.model
        )
        return np.array(response.data[0].embedding)

    def get_embedding_info(self) -> Dict[str, Union[str, int]]:
        """Get information about the current embedding configuration.

        Returns:
            Dict[str, Union[str, int]]: Embedding configuration details
            (model name, max token limit, and default batch size).
        """
        return {
            "model": self.model,
            "max_tokens": self.max_tokens,
            "batch_size": 100,  # Default batch size of embed_documents
        }

    def list_available_models(self) -> List[str]:
        """List available OpenAI embedding models.

        Returns:
            List[str]: Available embedding model names.
        """
        return [
            "text-embedding-ada-002",   # Most common
            "text-embedding-3-small",   # Newer, more efficient
            "text-embedding-3-large"    # Highest quality
        ]

    def estimate_cost(
        self,
        num_documents: int,
        avg_tokens_per_doc: int = 100
    ) -> float:
        """Estimate embedding cost.

        Args:
            num_documents (int): Number of documents to embed.
            avg_tokens_per_doc (int, optional): Assumed average token
                count per document. Defaults to 100 (the original
                hard-coded assumption).

        Returns:
            float: Estimated cost in USD. Unknown models fall back to
            the "text-embedding-ada-002" rate.
        """
        # Pricing as of 2024 (subject to change).
        # NOTE(review): verify these per-token rates against the current
        # OpenAI pricing page before relying on the estimate.
        pricing = {
            "text-embedding-ada-002": 0.0001 / 1000,  # $0.0001 per 1000 tokens
            "text-embedding-3-small": 0.00006 / 1000,
            "text-embedding-3-large": 0.00013 / 1000
        }
        total_tokens = num_documents * avg_tokens_per_doc
        return total_tokens * pricing.get(self.model, pricing["text-embedding-ada-002"])