"""Wrapper around Together AI's Embeddings API.""" | |
import logging | |
import os | |
import warnings | |
from typing import ( | |
Any, | |
Dict, | |
List, | |
Literal, | |
Mapping, | |
Optional, | |
Sequence, | |
Set, | |
Tuple, | |
Union, | |
) | |
import openai | |
from langchain_core.embeddings import Embeddings | |
from langchain_core.pydantic_v1 import ( | |
BaseModel, | |
Extra, | |
Field, | |
SecretStr, | |
root_validator, | |
) | |
from langchain_core.utils import ( | |
convert_to_secret_str, | |
get_from_dict_or_env, | |
get_pydantic_field_names, | |
) | |
logger = logging.getLogger(__name__) | |


class TogetherEmbeddings(BaseModel, Embeddings):
    """TogetherEmbeddings embedding model.

    To use, set the environment variable `TOGETHER_API_KEY` with your API key or
    pass it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain_together import TogetherEmbeddings

            model = TogetherEmbeddings()
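
        Embedding calls then return plain ``List[float]`` vectors; a minimal
        call sketch (input strings are illustrative):

        .. code-block:: python

            vector = model.embed_query("What is Together AI?")
            vectors = model.embed_documents(["first doc", "second doc"])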
""" | |

    client: Any = Field(default=None, exclude=True)  #: :meta private:
    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
    model: str = "togethercomputer/m2-bert-80M-8k-retrieval"
    """Embeddings model name to use,
    e.g. 'togethercomputer/m2-bert-80M-8k-retrieval'.
    """
    dimensions: Optional[int] = None
    """The number of dimensions the resulting output embeddings should have.

    Not yet supported.
    """
    together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
    """API Key for the Together API."""
    together_api_base: str = Field(
        default="https://api.together.ai/v1/", alias="base_url"
    )
    """Endpoint URL to use."""
    embedding_ctx_length: int = 4096
    """The maximum number of tokens to embed at once.

    Not yet supported.
    """
    allowed_special: Union[Literal["all"], Set[str]] = set()
    """Not yet supported."""
    disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
    """Not yet supported."""
    chunk_size: int = 1000
    """Maximum number of texts to embed in each batch.

    Not yet supported.
    """
    max_retries: int = 2
    """Maximum number of retries to make when generating."""
    request_timeout: Optional[Union[float, Tuple[float, float], Any]] = Field(
        default=None, alias="timeout"
    )
    """Timeout for requests to the Together embedding API. Can be float,
    httpx.Timeout or None."""
    show_progress_bar: bool = False
    """Whether to show a progress bar when embedding.

    Not yet supported.
    """
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for the `create` call that are not
    explicitly specified."""
    skip_empty: bool = False
    """Whether to skip empty strings when embedding or raise an error.
    Defaults to not skipping.

    Not yet supported."""
    default_headers: Union[Mapping[str, str], None] = None
    default_query: Union[Mapping[str, object], None] = None
    # Configure a custom httpx client. See the
    # [httpx documentation](https://www.python-httpx.org/api/#client) for details.
    http_client: Union[Any, None] = None
    """Optional httpx.Client. Only used for sync invocations. Must specify
    http_async_client as well if you'd like a custom client for async
    invocations.
    """
    http_async_client: Union[Any, None] = None
    """Optional httpx.AsyncClient. Only used for async invocations. Must specify
    http_client as well if you'd like a custom client for sync invocations."""

    class Config:
        extra = Extra.forbid
        allow_population_by_field_name = True

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                warnings.warn(
                    f"""WARNING! {field_name} is not a default parameter.
                    {field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)

        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values
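
    # For illustration: a call like ``TogetherEmbeddings(encoding_format="float")``
    # (``encoding_format`` is a hypothetical extra parameter, not a declared
    # field) emits a warning and is forwarded as
    # ``model_kwargs={"encoding_format": "float"}``, while passing a declared
    # field such as ``model`` inside ``model_kwargs`` raises a ValueError.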

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the api key and python package exist in the environment."""
        together_api_key = get_from_dict_or_env(
            values, "together_api_key", "TOGETHER_API_KEY"
        )
        values["together_api_key"] = (
            convert_to_secret_str(together_api_key) if together_api_key else None
        )
        values["together_api_base"] = values["together_api_base"] or os.getenv(
            "TOGETHER_API_BASE"
        )

        client_params = {
            "api_key": (
                values["together_api_key"].get_secret_value()
                if values["together_api_key"]
                else None
            ),
            "base_url": values["together_api_base"],
            "timeout": values["request_timeout"],
            "max_retries": values["max_retries"],
            "default_headers": values["default_headers"],
            "default_query": values["default_query"],
        }
        if not values.get("client"):
            sync_specific = (
                {"http_client": values["http_client"]} if values["http_client"] else {}
            )
            values["client"] = openai.OpenAI(
                **client_params, **sync_specific
            ).embeddings
        if not values.get("async_client"):
            async_specific = (
                {"http_client": values["http_async_client"]}
                if values["http_async_client"]
                else {}
            )
            values["async_client"] = openai.AsyncOpenAI(
                **client_params, **async_specific
            ).embeddings
        return values
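
    # Both clients share ``client_params``; only the transport differs. A
    # minimal sketch of wiring in custom httpx transports (values are
    # illustrative):
    #
    #   import httpx
    #   emb = TogetherEmbeddings(
    #       api_key="...",
    #       http_client=httpx.Client(timeout=10.0),
    #       http_async_client=httpx.AsyncClient(timeout=10.0),
    #   )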

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        params: Dict = {"model": self.model, **self.model_kwargs}
        if self.dimensions is not None:
            params["dimensions"] = self.dimensions
        return params
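
    # With the defaults this resolves to
    # {"model": "togethercomputer/m2-bert-80M-8k-retrieval"}; entries from
    # ``model_kwargs`` are merged in alongside, and ``dimensions`` is included
    # only when set.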

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts using the passage model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = []
        params = self._invocation_params
        for text in texts:
            # One request per text; the API may return a pydantic model, so
            # normalize the response to a dict before indexing into it.
            response = self.client.create(input=text, **params)
            if not isinstance(response, dict):
                response = response.model_dump()
            embeddings.extend([i["embedding"] for i in response["data"]])
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed query text using the query model.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        params = self._invocation_params
        response = self.client.create(input=text, **params)
        if not isinstance(response, dict):
            response = response.model_dump()
        return response["data"][0]["embedding"]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts using the passage model asynchronously.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = []
        params = self._invocation_params
        for text in texts:
            response = await self.async_client.create(input=text, **params)
            if not isinstance(response, dict):
                response = response.model_dump()
            embeddings.extend([i["embedding"] for i in response["data"]])
        return embeddings

    async def aembed_query(self, text: str) -> List[float]:
        """Embed query text using the query model asynchronously.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        params = self._invocation_params
        response = await self.async_client.create(input=text, **params)
        if not isinstance(response, dict):
            response = response.model_dump()
        return response["data"][0]["embedding"]
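

# A minimal end-to-end sketch (assumes TOGETHER_API_KEY is set in the
# environment and that the default model is available on your account):
if __name__ == "__main__":
    import asyncio

    embedder = TogetherEmbeddings()

    # Synchronous calls: one vector per input text.
    query_vec = embedder.embed_query("What is a large language model?")
    doc_vecs = embedder.embed_documents(["first document", "second document"])
    print(len(query_vec), len(doc_vecs))

    async def _demo() -> None:
        # Asynchronous variants go through the AsyncOpenAI-backed client.
        vec = await embedder.aembed_query("async query example")
        print(len(vec))

    asyncio.run(_demo())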