# Source: teddyllm, "Upload 20 files" (commit bd3532f, verified)
from typing import List, Dict, Union, Optional
import tiktoken
from ..embedding_provider import EmbeddingProvider
import numpy as np
class OpenAIEmbedding(EmbeddingProvider):
    """Embedding provider backed by the OpenAI embeddings API."""

    # Default number of documents sent per embeddings API request.
    DEFAULT_BATCH_SIZE = 100

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "text-embedding-3-small",
        max_tokens: int = 8191
    ) -> None:
        """Initialize OpenAI embedding provider.

        Args:
            api_key (str, optional): API key for OpenAI. When None, the
                OpenAI client falls back to the OPENAI_API_KEY environment
                variable.
            model (str, optional): Name of the embedding model.
                Defaults to "text-embedding-3-small".
                More info: https://platform.openai.com/docs/models#embeddings
            max_tokens (int, optional): Maximum number of tokens kept per
                input text; longer inputs are truncated. Defaults to 8191.
        """
        # Local import keeps `openai` an optional dependency of the package.
        from openai import OpenAI
        self.client = OpenAI(api_key=api_key)
        self.model = model
        self.max_tokens = max_tokens
        try:
            self.tokenizer = tiktoken.encoding_for_model(model)
        except KeyError:
            # Model name unknown to tiktoken (e.g. a newly released model):
            # fall back to the encoding used by current OpenAI embedding models.
            self.tokenizer = tiktoken.get_encoding("cl100k_base")

    def _truncate_text(self, text: str) -> str:
        """Truncate text to at most ``self.max_tokens`` tokens.

        Args:
            text (str): Input text.

        Returns:
            str: Text decoded from the first ``max_tokens`` tokens.
        """
        tokens = self.tokenizer.encode(text)
        return self.tokenizer.decode(tokens[:self.max_tokens])

    # Backward-compatible alias for the previous (misspelled) method name.
    _trancated_text = _truncate_text

    def embed_documents(
        self,
        documents: List[str],
        batch_size: int = DEFAULT_BATCH_SIZE
    ) -> np.ndarray:
        """Embed a list of documents.

        Args:
            documents (List[str]): Documents to embed. Each document is
                truncated to ``max_tokens`` tokens before being sent.
            batch_size (int, optional): Number of documents per API request.
                Defaults to ``DEFAULT_BATCH_SIZE``.

        Returns:
            np.ndarray: Embeddings, one row per document.
        """
        truncated_docs = [self._truncate_text(doc) for doc in documents]
        embeddings: List[List[float]] = []
        for start in range(0, len(truncated_docs), batch_size):
            response = self.client.embeddings.create(
                input=truncated_docs[start:start + batch_size],
                model=self.model,
            )
            # response.data preserves input order within the batch.
            embeddings.extend(item.embedding for item in response.data)
        return np.array(embeddings)

    def embed_query(self, query: str) -> np.ndarray:
        """Embed a single query string.

        Args:
            query (str): Query text; truncated to ``max_tokens`` tokens.

        Returns:
            np.ndarray: 1-D embedding vector for the query.
        """
        response = self.client.embeddings.create(
            input=[self._truncate_text(query)],
            model=self.model,
        )
        return np.array(response.data[0].embedding)

    def get_embedding_info(self) -> Dict[str, Union[str, int]]:
        """
        Get information about the current embedding configuration.

        Returns:
            Dict: Embedding configuration details (model name, token limit,
            and the default batch size used by ``embed_documents``).
        """
        return {
            "model": self.model,
            "max_tokens": self.max_tokens,
            "batch_size": self.DEFAULT_BATCH_SIZE,
        }

    def list_available_models(self) -> List[str]:
        """
        List available OpenAI embedding models.

        Returns:
            List[str]: Available embedding model names.
        """
        return [
            "text-embedding-ada-002",   # Most common
            "text-embedding-3-small",   # Newer, more efficient
            "text-embedding-3-large"    # Highest quality
        ]

    def estimate_cost(self, num_documents: int, avg_tokens_per_doc: int = 100) -> float:
        """
        Estimate embedding cost.

        Args:
            num_documents (int): Number of documents to embed.
            avg_tokens_per_doc (int, optional): Assumed average tokens per
                document. Defaults to 100 (preserves previous behavior).

        Returns:
            float: Estimated cost in USD. Unknown models are priced at the
            ada-002 rate.
        """
        # Pricing as of 2024 (subject to change), expressed per token.
        pricing = {
            "text-embedding-ada-002": 0.0001 / 1000,  # $0.0001 per 1000 tokens
            "text-embedding-3-small": 0.00006 / 1000,
            "text-embedding-3-large": 0.00013 / 1000
        }
        total_tokens = num_documents * avg_tokens_per_doc
        return total_tokens * pricing.get(self.model, pricing["text-embedding-ada-002"])