File size: 7,385 Bytes
f1e6b80 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 |
import functools
from importlib import util
from typing import Any, List, Optional, Tuple, Union
from langchain_core._api import beta
from langchain_core.embeddings import Embeddings
from langchain_core.runnables import Runnable
# Maps each supported provider key to the importable module name of the
# integration package that implements it. Insertion order matters: it is the
# order in which providers are listed in error messages.
_SUPPORTED_PROVIDERS = {
    "azure_openai": "langchain_openai",
    "bedrock": "langchain_aws",
    "cohere": "langchain_cohere",
    "google_vertexai": "langchain_google_vertexai",
    "huggingface": "langchain_huggingface",
    "mistralai": "langchain_mistralai",
    "openai": "langchain_openai",
}
def _get_provider_list() -> str:
    """Render the supported providers as a newline-separated bullet list.

    Each line pairs a provider key with its package's module name, with the
    underscores of the module name shown as hyphens.

    Returns:
        One ``" - provider: package"`` line per entry of
        ``_SUPPORTED_PROVIDERS``, joined with newlines.
    """
    bullet_lines = [
        f" - {provider}: {package.replace('_', '-')}"
        for provider, package in _SUPPORTED_PROVIDERS.items()
    ]
    return "\n".join(bullet_lines)
def _parse_model_string(model_name: str) -> Tuple[str, str]:
    """Split a ``provider:model-name`` string into its two components.

    Only the first colon separates the provider from the model, so model
    names that themselves contain colons are kept intact. The provider is
    lower-cased and both parts are stripped of surrounding whitespace.

    Args:
        model_name: A model string in the format ``'provider:model-name'``.

    Returns:
        A tuple of ``(provider, model_name)``.

    .. code-block:: python

        _parse_model_string("openai:text-embedding-3-small")
        # Returns: ("openai", "text-embedding-3-small")

        _parse_model_string("bedrock:amazon.titan-embed-text-v1")
        # Returns: ("bedrock", "amazon.titan-embed-text-v1")

    Raises:
        ValueError: If the string has no colon, the provider is not one of
            the supported providers, or the model part is empty.
    """
    raw_provider, separator, raw_model = model_name.partition(":")

    # An empty separator means the string contained no colon at all.
    if not separator:
        raise ValueError(
            f"Invalid model format '{model_name}'.\n"
            "Model name must be in format 'provider:model-name'\n"
            "Example valid model strings:\n"
            " - openai:text-embedding-3-small\n"
            " - bedrock:amazon.titan-embed-text-v1\n"
            " - cohere:embed-english-v3.0\n"
            f"Supported providers: {_SUPPORTED_PROVIDERS}"
        )

    provider = raw_provider.lower().strip()
    model = raw_model.strip()

    if provider not in _SUPPORTED_PROVIDERS:
        raise ValueError(
            f"Provider '{provider}' is not supported.\n"
            "Supported providers and their required packages:\n"
            f"{_get_provider_list()}"
        )
    if not model:
        raise ValueError("Model name cannot be empty")
    return provider, model
def _infer_model_and_provider(
    model: str, *, provider: Optional[str] = None
) -> Tuple[str, str]:
    """Resolve the provider and bare model name for an embeddings model.

    When ``provider`` is omitted and ``model`` contains a colon, the pair is
    parsed from the ``'provider:model-name'`` string. Otherwise ``model`` is
    used as given and ``provider`` must be supplied explicitly.

    Args:
        model: Either a ``'provider:model-name'`` string or a bare model name.
        provider: Optional explicit provider key. It is lower-cased and
            stripped of surrounding whitespace, matching how the provider
            part of a ``'provider:model-name'`` string is normalized.

    Returns:
        A tuple of ``(provider, model_name)``.

    Raises:
        ValueError: If ``model`` is blank, no provider can be determined, or
            the provider is not one of the supported providers.
    """
    if not model.strip():
        raise ValueError("Model name cannot be empty")

    if provider is None and ":" in model:
        provider, model_name = _parse_model_string(model)
    else:
        # With an explicit provider the model is taken verbatim, so model
        # names that contain colons are not split.
        model_name = model
        if provider is not None:
            # Normalize so e.g. provider="OpenAI" behaves like "OpenAI:model".
            provider = provider.strip().lower()

    if not provider:
        providers = _SUPPORTED_PROVIDERS
        raise ValueError(
            "Must specify either:\n"
            "1. A model string in format 'provider:model-name'\n"
            " Example: 'openai:text-embedding-3-small'\n"
            "2. Or explicitly set provider from: "
            f"{providers}"
        )

    if provider not in _SUPPORTED_PROVIDERS:
        raise ValueError(
            f"Provider '{provider}' is not supported.\n"
            f"Supported providers and their required packages:\n"
            f"{_get_provider_list()}"
        )
    return provider, model_name
@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS))
def _check_pkg(pkg: str) -> None:
    """Verify that the package ``pkg`` can be located for import.

    Results are memoized with ``functools.lru_cache``, sized to the number of
    supported providers.

    Args:
        pkg: Importable module name of the package to look up.

    Raises:
        ImportError: If ``importlib.util.find_spec`` finds no spec for ``pkg``.
    """
    spec = util.find_spec(pkg)
    if spec:
        return
    raise ImportError(
        f"Could not import {pkg} python package. "
        f"Please install it with `pip install {pkg}`"
    )
@beta()
def init_embeddings(
    model: str,
    *,
    provider: Optional[str] = None,
    **kwargs: Any,
) -> Union[Embeddings, Runnable[Any, List[float]]]:
    """Initialize an embeddings model from a model name and optional provider.

    **Note:** Must have the integration package corresponding to the model provider
    installed.

    Args:
        model: Name of the model to use. Can be either:

            - A model string like "openai:text-embedding-3-small"
            - Just the model name if provider is specified
        provider: Optional explicit provider name. If not specified,
            will attempt to parse from the model string. Supported providers
            and their required packages:

            - azure_openai: langchain-openai
            - bedrock: langchain-aws
            - cohere: langchain-cohere
            - google_vertexai: langchain-google-vertexai
            - huggingface: langchain-huggingface
            - mistralai: langchain-mistralai
            - openai: langchain-openai
        **kwargs: Additional model-specific parameters passed to the embedding model.
            These vary by provider, see the provider-specific documentation for details.

    Returns:
        An Embeddings instance that can generate embeddings for text.

    Raises:
        ValueError: If the model provider is not supported or cannot be determined
        ImportError: If the required provider package is not installed

    .. dropdown:: Example Usage
        :open:

        .. code-block:: python

            # Using a model string
            model = init_embeddings("openai:text-embedding-3-small")
            model.embed_query("Hello, world!")

            # Using explicit provider
            model = init_embeddings(
                model="text-embedding-3-small",
                provider="openai"
            )
            model.embed_documents(["Hello, world!", "Goodbye, world!"])

            # With additional parameters
            model = init_embeddings(
                "openai:text-embedding-3-small",
                api_key="sk-..."
            )

    .. versionadded:: 0.3.9
    """
    if not model:
        providers = _SUPPORTED_PROVIDERS.keys()
        raise ValueError(
            "Must specify model name. "
            f"Supported providers are: {', '.join(providers)}"
        )

    provider, model_name = _infer_model_and_provider(model, provider=provider)
    pkg = _SUPPORTED_PROVIDERS[provider]
    # Fail early with an install hint before attempting the provider import.
    _check_pkg(pkg)

    # Provider packages are imported lazily so that only the integration
    # actually requested needs to be installed.
    if provider == "openai":
        from langchain_openai import OpenAIEmbeddings

        return OpenAIEmbeddings(model=model_name, **kwargs)
    if provider == "azure_openai":
        from langchain_openai import AzureOpenAIEmbeddings

        return AzureOpenAIEmbeddings(model=model_name, **kwargs)
    if provider == "google_vertexai":
        from langchain_google_vertexai import VertexAIEmbeddings

        return VertexAIEmbeddings(model=model_name, **kwargs)
    if provider == "bedrock":
        from langchain_aws import BedrockEmbeddings

        # This class takes the model under the `model_id` keyword.
        return BedrockEmbeddings(model_id=model_name, **kwargs)
    if provider == "cohere":
        from langchain_cohere import CohereEmbeddings

        return CohereEmbeddings(model=model_name, **kwargs)
    if provider == "mistralai":
        from langchain_mistralai import MistralAIEmbeddings

        return MistralAIEmbeddings(model=model_name, **kwargs)
    if provider == "huggingface":
        from langchain_huggingface import HuggingFaceEmbeddings

        # This class takes the model under the `model_name` keyword.
        return HuggingFaceEmbeddings(model_name=model_name, **kwargs)

    # Defensive fallback for a provider that is listed in
    # _SUPPORTED_PROVIDERS but has no branch above.
    raise ValueError(
        f"Provider '{provider}' is not supported.\n"
        f"Supported providers and their required packages:\n"
        f"{_get_provider_list()}"
    )
# Names exported by a star-import of this module.
__all__ = [
    "init_embeddings",
    "Embeddings",  # This one is for backwards compatibility
]
|