from typing import List, Dict, Union, Optional

import numpy as np
import tiktoken

from ..embedding_provider import EmbeddingProvider


class OpenAIEmbedding(EmbeddingProvider):
    """Embedding provider backed by the OpenAI embeddings API.

    Texts are token-truncated to the model's context limit before being
    sent, and documents are embedded in batches to stay within API
    request-size limits.
    """

    # Default number of documents sent per API request; also reported by
    # get_embedding_info() so the two never drift apart.
    DEFAULT_BATCH_SIZE = 100

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "text-embedding-3-small",
        max_tokens: int = 8191
    ) -> None:
        """Initialize the OpenAI embedding provider.

        Args:
            api_key (Optional[str]): OpenAI API key. If None, the openai
                client falls back to the OPENAI_API_KEY environment variable.
            model (str): Name of the embedding model. Defaults to
                "text-embedding-3-small". More info:
                https://platform.openai.com/docs/models#embeddings
            max_tokens (int): Maximum tokens per input text; longer texts
                are truncated. Defaults to 8191 (the limit shared by the
                current OpenAI embedding models).
        """
        # Imported lazily so merely importing this module does not require
        # the openai package to be installed.
        from openai import OpenAI

        self.client = OpenAI(api_key=api_key)
        self.model = model
        self.max_tokens = max_tokens
        try:
            self.tokenizer = tiktoken.encoding_for_model(model)
        except KeyError:
            # tiktoken raises KeyError for model names it does not know
            # (e.g. brand-new models). cl100k_base is the tokenizer used
            # by all current OpenAI embedding models, so it is a safe
            # fallback for truncation purposes.
            self.tokenizer = tiktoken.get_encoding("cl100k_base")

    def _truncate_text(self, text: str) -> str:
        """Truncate text to at most ``self.max_tokens`` tokens.

        Args:
            text (str): Input text.

        Returns:
            str: Text decoded from the first ``max_tokens`` tokens.
        """
        tokens = self.tokenizer.encode(text)
        truncated_tokens = tokens[:self.max_tokens]
        return self.tokenizer.decode(truncated_tokens)

    # Backward-compatible alias for the previous (misspelled) private name.
    _trancated_text = _truncate_text

    def embed_documents(
        self,
        documents: List[str],
        batch_size: int = DEFAULT_BATCH_SIZE
    ) -> np.ndarray:
        """Embed a list of documents.

        Args:
            documents (List[str]): Documents to embed; each is truncated
                to the model's token limit first.
            batch_size (int): Number of documents per API request.
                Defaults to 100.

        Returns:
            np.ndarray: Array of shape (len(documents), embedding_dim).
        """
        truncated_docs = [self._truncate_text(doc) for doc in documents]
        embeddings = []
        for i in range(0, len(truncated_docs), batch_size):
            batch = truncated_docs[i: i + batch_size]
            response = self.client.embeddings.create(
                input=batch,
                model=self.model
            )
            # response.data preserves input order, one item per input text.
            batch_embeddings = [
                embed.embedding for embed in response.data
            ]
            embeddings.extend(batch_embeddings)
        return np.array(embeddings)

    def embed_query(self, query: str) -> np.ndarray:
        """Embed a single query string.

        Args:
            query (str): Query text; truncated to the model's token limit.

        Returns:
            np.ndarray: 1-D embedding vector for the query.
        """
        truncated_query = self._truncate_text(query)
        response = self.client.embeddings.create(
            input=[truncated_query],
            model=self.model
        )
        return np.array(response.data[0].embedding)

    def get_embedding_info(self) -> Dict[str, Union[str, int]]:
        """Get information about the current embedding configuration.

        Returns:
            Dict[str, Union[str, int]]: Model name, token limit, and the
            default batch size used by :meth:`embed_documents`.
        """
        return {
            "model": self.model,
            "max_tokens": self.max_tokens,
            "batch_size": self.DEFAULT_BATCH_SIZE,
        }

    def list_available_models(self) -> List[str]:
        """List available OpenAI embedding models.

        Returns:
            List[str]: Available embedding model names.
        """
        return [
            "text-embedding-ada-002",   # Most common
            "text-embedding-3-small",   # Newer, more efficient
            "text-embedding-3-large"    # Highest quality
        ]

    def estimate_cost(self, num_documents: int) -> float:
        """Estimate embedding cost.

        Args:
            num_documents (int): Number of documents to embed.

        Returns:
            float: Estimated cost in USD, assuming ~100 tokens/document.
        """
        # Pricing as of 2024 (subject to change), expressed per token.
        pricing = {
            "text-embedding-ada-002": 0.0001 / 1000,   # $0.0001 per 1000 tokens
            "text-embedding-3-small": 0.00006 / 1000,
            "text-embedding-3-large": 0.00013 / 1000
        }
        # Rough heuristic: assume ~100 tokens per document.
        total_tokens = num_documents * 100
        # Unknown models fall back to ada-002 pricing.
        return total_tokens * pricing.get(self.model, pricing["text-embedding-ada-002"])