import asyncio
import os
import typing
from concurrent.futures import ThreadPoolExecutor

from tokenizers import Tokenizer  # type: ignore

import httpx

from cohere.types.detokenize_response import DetokenizeResponse
from cohere.types.tokenize_response import TokenizeResponse

from . import EmbedResponse, EmbedInputType, EmbeddingType, EmbedRequestTruncate
from .base_client import BaseCohere, AsyncBaseCohere, OMIT
from .config import embed_batch_size
from .core import RequestOptions
from .environment import ClientEnvironment
from .manually_maintained.cache import CacheMixin
from .manually_maintained import tokenizers as local_tokenizers
from .overrides import run_overrides
from .utils import wait, async_wait, merge_embed_responses, SyncSdkUtils, AsyncSdkUtils

run_overrides()

# Use NoReturn as a stand-in for the Never type (typing.Never requires Python 3.11+).
Never = typing.NoReturn


def validate_args(obj: typing.Any, method_name: str, check_fn: typing.Callable[[typing.Any], typing.Any]) -> None:
    """Wraps ``obj.<method_name>`` so that ``check_fn`` validates the arguments before each call."""
    method = getattr(obj, method_name)

    def wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
        check_fn(*args, **kwargs)
        return method(*args, **kwargs)

    setattr(obj, method_name, wrapped)

def throw_if_stream_is_true(*args, **kwargs) -> None:
    if kwargs.get("stream") is True:
        raise ValueError(
            "Since python sdk cohere==5.0.0, you must now use chat_stream(...) instead of chat(stream=True, ...)"
        )

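
# Migration sketch (values illustrative): since cohere==5.0.0, streaming goes
# through chat_stream rather than chat(stream=True, ...):
#
#     co = Client()
#     for event in co.chat_stream(message="hello"):   # instead of co.chat(stream=True, ...)
#         ...
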
def moved_function(fn_name: str, new_fn_name: str) -> typing.Any:
    """
    Returns a stub that raises on call, pointing users at the function's new location.
    """

    def fn(*args, **kwargs):
        raise ValueError(
            f"Since python sdk cohere==5.0.0, the function {fn_name}(...) has been moved to {new_fn_name}(...). "
            f"Please update your code. Issues may be filed in https://github.com/cohere-ai/cohere-python/issues."
        )

    return fn

def deprecated_function(fn_name: str) -> typing.Any:
    """
    Returns a stub that raises on call, noting that the function has been deprecated.
    """

    def fn(*args, **kwargs):
        raise ValueError(
            f"Since python sdk cohere==5.0.0, the function {fn_name}(...) has been deprecated. "
            f"Please update your code. Issues may be filed in https://github.com/cohere-ai/cohere-python/issues."
        )

    return fn

class Client(BaseCohere, CacheMixin):
    def __init__(
        self,
        api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None,
        *,
        base_url: typing.Optional[str] = os.getenv("CO_API_URL"),
        environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
        client_name: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
        httpx_client: typing.Optional[httpx.Client] = None,
    ):
        if api_key is None:
            api_key = _get_api_key_from_environment()
        BaseCohere.__init__(
            self,
            base_url=base_url,
            environment=environment,
            client_name=client_name,
            token=api_key,
            timeout=timeout,
            httpx_client=httpx_client,
        )
        validate_args(self, "chat", throw_if_stream_is_true)

    utils = SyncSdkUtils()

    # Support the context-manager protocol until Fern upstreams it:
    # https://linear.app/buildwithfern/issue/FER-1242/expose-a-context-manager-interface-or-the-http-client-easily
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._client_wrapper.httpx_client.httpx_client.close()
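
    # Usage sketch (api_key and model are illustrative): entering the client as a
    # context manager guarantees the underlying httpx client is closed on exit:
    #
    #     with Client(api_key="...") as co:
    #         co.embed(texts=["hello"], model="embed-english-v3.0", input_type="search_document")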

    wait = wait

    _executor = ThreadPoolExecutor(64)

    def embed(
        self,
        *,
        texts: typing.Sequence[str],
        model: typing.Optional[str] = OMIT,
        input_type: typing.Optional[EmbedInputType] = OMIT,
        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
        truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
        batching: typing.Optional[bool] = True,
    ) -> EmbedResponse:
        if batching is False:
            return BaseCohere.embed(
                self,
                texts=texts,
                model=model,
                input_type=input_type,
                embedding_types=embedding_types,
                truncate=truncate,
                request_options=request_options,
            )

        # Split the input into batches of `embed_batch_size`, embed them concurrently
        # on the shared thread pool, and merge the per-batch responses.
        texts_batches = [texts[i : i + embed_batch_size] for i in range(0, len(texts), embed_batch_size)]
        responses = [
            response
            for response in self._executor.map(
                lambda text_batch: BaseCohere.embed(
                    self,
                    texts=text_batch,
                    model=model,
                    input_type=input_type,
                    embedding_types=embedding_types,
                    truncate=truncate,
                    request_options=request_options,
                ),
                texts_batches,
            )
        ]
        return merge_embed_responses(responses)
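
    # Batching sketch (model name illustrative): with the default batching=True, a call like
    #
    #     co.embed(texts=many_texts, model="embed-english-v3.0", input_type="search_document")
    #
    # is split into chunks of `embed_batch_size`, embedded in parallel on `_executor`,
    # and merged back into a single EmbedResponse; batching=False sends one request.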
""" | |
The following methods have been moved or deprecated in cohere==5.0.0. Please update your usage. | |
Issues may be filed in https://github.com/cohere-ai/cohere-python/issues. | |
""" | |
check_api_key: Never = deprecated_function("check_api_key") | |
loglikelihood: Never = deprecated_function("loglikelihood") | |
batch_generate: Never = deprecated_function("batch_generate") | |
codebook: Never = deprecated_function("codebook") | |
batch_tokenize: Never = deprecated_function("batch_tokenize") | |
batch_detokenize: Never = deprecated_function("batch_detokenize") | |
detect_language: Never = deprecated_function("detect_language") | |
generate_feedback: Never = deprecated_function("generate_feedback") | |
generate_preference_feedback: Never = deprecated_function("generate_preference_feedback") | |
create_dataset: Never = moved_function("create_dataset", ".datasets.create") | |
get_dataset: Never = moved_function("get_dataset", ".datasets.get") | |
list_datasets: Never = moved_function("list_datasets", ".datasets.list") | |
delete_dataset: Never = moved_function("delete_dataset", ".datasets.delete") | |
get_dataset_usage: Never = moved_function("get_dataset_usage", ".datasets.get_usage") | |
wait_for_dataset: Never = moved_function("wait_for_dataset", ".wait") | |
_check_response: Never = deprecated_function("_check_response") | |
_request: Never = deprecated_function("_request") | |
create_cluster_job: Never = deprecated_function("create_cluster_job") | |
get_cluster_job: Never = deprecated_function("get_cluster_job") | |
list_cluster_jobs: Never = deprecated_function("list_cluster_jobs") | |
wait_for_cluster_job: Never = deprecated_function("wait_for_cluster_job") | |
create_embed_job: Never = moved_function("create_embed_job", ".embed_jobs.create") | |
list_embed_jobs: Never = moved_function("list_embed_jobs", ".embed_jobs.list") | |
get_embed_job: Never = moved_function("get_embed_job", ".embed_jobs.get") | |
cancel_embed_job: Never = moved_function("cancel_embed_job", ".embed_jobs.cancel") | |
wait_for_embed_job: Never = moved_function("wait_for_embed_job", ".wait") | |
create_custom_model: Never = deprecated_function("create_custom_model") | |
wait_for_custom_model: Never = deprecated_function("wait_for_custom_model") | |
_upload_dataset: Never = deprecated_function("_upload_dataset") | |
_create_signed_url: Never = deprecated_function("_create_signed_url") | |
get_custom_model: Never = deprecated_function("get_custom_model") | |
get_custom_model_by_name: Never = deprecated_function("get_custom_model_by_name") | |
get_custom_model_metrics: Never = deprecated_function("get_custom_model_metrics") | |
list_custom_models: Never = deprecated_function("list_custom_models") | |
create_connector: Never = moved_function("create_connector", ".connectors.create") | |
update_connector: Never = moved_function("update_connector", ".connectors.update") | |
get_connector: Never = moved_function("get_connector", ".connectors.get") | |
list_connectors: Never = moved_function("list_connectors", ".connectors.list") | |
delete_connector: Never = moved_function("delete_connector", ".connectors.delete") | |
oauth_authorize_connector: Never = moved_function("oauth_authorize_connector", ".connectors.o_auth_authorize") | |
    def tokenize(
        self,
        *,
        text: str,
        model: str,
        request_options: typing.Optional[RequestOptions] = None,
        offline: bool = True,
    ) -> TokenizeResponse:
        # The `offline` parameter controls whether to use a local tokenizer. If True, the
        # tokenizer config is downloaded (and cached) and the request is processed locally;
        # if False, the request is sent to the API. Defaults to True.
        opts: RequestOptions = request_options or {}  # type: ignore

        if offline:
            try:
                tokens = asyncio.run(local_tokenizers.local_tokenize(self, text=text, model=model))
                return TokenizeResponse(tokens=tokens, token_strings=[])
            except Exception:
                # Fall back to the API, flagging the failure for server-side telemetry.
                opts["additional_headers"] = opts.get("additional_headers", {})
                opts["additional_headers"]["sdk-api-warning-message"] = "offline_tokenizer_failed"

        return super().tokenize(text=text, model=model, request_options=opts)
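
    # Usage sketch (model name illustrative):
    #
    #     resp = co.tokenize(text="hello world", model="command-r")   # local tokenizer by default
    #     resp.tokens                                                 # list of token ids
    #
    # Passing offline=False forces the request through the API instead.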

    def detokenize(
        self,
        *,
        tokens: typing.Sequence[int],
        model: str,
        request_options: typing.Optional[RequestOptions] = None,
        offline: typing.Optional[bool] = True,
    ) -> DetokenizeResponse:
        # The `offline` parameter controls whether to use a local tokenizer. If True, the
        # tokenizer config is downloaded (and cached) and the request is processed locally;
        # if False, the request is sent to the API. Defaults to True.
        opts: RequestOptions = request_options or {}  # type: ignore

        if offline:
            try:
                text = asyncio.run(local_tokenizers.local_detokenize(self, model=model, tokens=tokens))
                return DetokenizeResponse(text=text)
            except Exception:
                # Fall back to the API, flagging the failure for server-side telemetry.
                opts["additional_headers"] = opts.get("additional_headers", {})
                opts["additional_headers"]["sdk-api-warning-message"] = "offline_tokenizer_failed"

        return super().detokenize(tokens=tokens, model=model, request_options=opts)
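
    # Round-trip sketch, continuing the tokenize example above:
    #
    #     co.detokenize(tokens=resp.tokens, model="command-r").text   # -> "hello world"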

    def fetch_tokenizer(self, *, model: str) -> Tokenizer:
        """
        Returns a Hugging Face tokenizer for the given model name.
        """
        return local_tokenizers.get_hf_tokenizer(self, model)
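
    # Sketch of using the returned tokenizer directly (model name illustrative);
    # `encode` and `.ids` are the standard Hugging Face `tokenizers` API:
    #
    #     tok = co.fetch_tokenizer(model="command-r")
    #     tok.encode("hello world", add_special_tokens=False).ids
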
class AsyncClient(AsyncBaseCohere, CacheMixin):
    def __init__(
        self,
        api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None,
        *,
        base_url: typing.Optional[str] = os.getenv("CO_API_URL"),
        environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
        client_name: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
        httpx_client: typing.Optional[httpx.AsyncClient] = None,
    ):
        if api_key is None:
            api_key = _get_api_key_from_environment()
        AsyncBaseCohere.__init__(
            self,
            base_url=base_url,
            environment=environment,
            client_name=client_name,
            token=api_key,
            timeout=timeout,
            httpx_client=httpx_client,
        )
        validate_args(self, "chat", throw_if_stream_is_true)

    utils = AsyncSdkUtils()

    # Support the async context-manager protocol until Fern upstreams it
    # (see the usage sketch after embed below):
    # https://linear.app/buildwithfern/issue/FER-1242/expose-a-context-manager-interface-or-the-http-client-easily
    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self._client_wrapper.httpx_client.httpx_client.aclose()

    wait = async_wait

    _executor = ThreadPoolExecutor(64)

    async def embed(
        self,
        *,
        texts: typing.Sequence[str],
        model: typing.Optional[str] = OMIT,
        input_type: typing.Optional[EmbedInputType] = OMIT,
        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
        truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
        batching: typing.Optional[bool] = True,
    ) -> EmbedResponse:
        if batching is False:
            return await AsyncBaseCohere.embed(
                self,
                texts=texts,
                model=model,
                input_type=input_type,
                embedding_types=embedding_types,
                truncate=truncate,
                request_options=request_options,
            )

        # Split the input into batches of `embed_batch_size`, embed them concurrently
        # with asyncio.gather, and merge the per-batch responses.
        texts_batches = [texts[i : i + embed_batch_size] for i in range(0, len(texts), embed_batch_size)]
        responses = typing.cast(
            typing.List[EmbedResponse],
            await asyncio.gather(
                *[
                    AsyncBaseCohere.embed(
                        self,
                        texts=text_batch,
                        model=model,
                        input_type=input_type,
                        embedding_types=embedding_types,
                        truncate=truncate,
                        request_options=request_options,
                    )
                    for text_batch in texts_batches
                ]
            ),
        )
        return merge_embed_responses(responses)
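
    # Async usage sketch (api_key and model illustrative), combining the context manager
    # above with batched embedding; batches are dispatched concurrently via asyncio.gather:
    #
    #     async with AsyncClient(api_key="...") as co:
    #         await co.embed(texts=many_texts, model="embed-english-v3.0", input_type="search_document")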
""" | |
The following methods have been moved or deprecated in cohere==5.0.0. Please update your usage. | |
Issues may be filed in https://github.com/cohere-ai/cohere-python/issues. | |
""" | |
check_api_key: Never = deprecated_function("check_api_key") | |
loglikelihood: Never = deprecated_function("loglikelihood") | |
batch_generate: Never = deprecated_function("batch_generate") | |
codebook: Never = deprecated_function("codebook") | |
batch_tokenize: Never = deprecated_function("batch_tokenize") | |
batch_detokenize: Never = deprecated_function("batch_detokenize") | |
detect_language: Never = deprecated_function("detect_language") | |
generate_feedback: Never = deprecated_function("generate_feedback") | |
generate_preference_feedback: Never = deprecated_function("generate_preference_feedback") | |
create_dataset: Never = moved_function("create_dataset", ".datasets.create") | |
get_dataset: Never = moved_function("get_dataset", ".datasets.get") | |
list_datasets: Never = moved_function("list_datasets", ".datasets.list") | |
delete_dataset: Never = moved_function("delete_dataset", ".datasets.delete") | |
get_dataset_usage: Never = moved_function("get_dataset_usage", ".datasets.get_usage") | |
wait_for_dataset: Never = moved_function("wait_for_dataset", ".wait") | |
_check_response: Never = deprecated_function("_check_response") | |
_request: Never = deprecated_function("_request") | |
create_cluster_job: Never = deprecated_function("create_cluster_job") | |
get_cluster_job: Never = deprecated_function("get_cluster_job") | |
list_cluster_jobs: Never = deprecated_function("list_cluster_jobs") | |
wait_for_cluster_job: Never = deprecated_function("wait_for_cluster_job") | |
create_embed_job: Never = moved_function("create_embed_job", ".embed_jobs.create") | |
list_embed_jobs: Never = moved_function("list_embed_jobs", ".embed_jobs.list") | |
get_embed_job: Never = moved_function("get_embed_job", ".embed_jobs.get") | |
cancel_embed_job: Never = moved_function("cancel_embed_job", ".embed_jobs.cancel") | |
wait_for_embed_job: Never = moved_function("wait_for_embed_job", ".wait") | |
create_custom_model: Never = deprecated_function("create_custom_model") | |
wait_for_custom_model: Never = deprecated_function("wait_for_custom_model") | |
_upload_dataset: Never = deprecated_function("_upload_dataset") | |
_create_signed_url: Never = deprecated_function("_create_signed_url") | |
get_custom_model: Never = deprecated_function("get_custom_model") | |
get_custom_model_by_name: Never = deprecated_function("get_custom_model_by_name") | |
get_custom_model_metrics: Never = deprecated_function("get_custom_model_metrics") | |
list_custom_models: Never = deprecated_function("list_custom_models") | |
create_connector: Never = moved_function("create_connector", ".connectors.create") | |
update_connector: Never = moved_function("update_connector", ".connectors.update") | |
get_connector: Never = moved_function("get_connector", ".connectors.get") | |
list_connectors: Never = moved_function("list_connectors", ".connectors.list") | |
delete_connector: Never = moved_function("delete_connector", ".connectors.delete") | |
oauth_authorize_connector: Never = moved_function("oauth_authorize_connector", ".connectors.o_auth_authorize") | |
    async def tokenize(
        self,
        *,
        text: str,
        model: str,
        request_options: typing.Optional[RequestOptions] = None,
        offline: typing.Optional[bool] = True,
    ) -> TokenizeResponse:
        # The `offline` parameter controls whether to use a local tokenizer. If True, the
        # tokenizer config is downloaded (and cached) and the request is processed locally;
        # if False, the request is sent to the API. Defaults to True.
        opts: RequestOptions = request_options or {}  # type: ignore

        if offline:
            try:
                tokens = await local_tokenizers.local_tokenize(self, model=model, text=text)
                return TokenizeResponse(tokens=tokens, token_strings=[])
            except Exception:
                # Fall back to the API, flagging the failure for server-side telemetry.
                opts["additional_headers"] = opts.get("additional_headers", {})
                opts["additional_headers"]["sdk-api-warning-message"] = "offline_tokenizer_failed"

        return await super().tokenize(text=text, model=model, request_options=opts)

    async def detokenize(
        self,
        *,
        tokens: typing.Sequence[int],
        model: str,
        request_options: typing.Optional[RequestOptions] = None,
        offline: typing.Optional[bool] = True,
    ) -> DetokenizeResponse:
        # The `offline` parameter controls whether to use a local tokenizer. If True, the
        # tokenizer config is downloaded (and cached) and the request is processed locally;
        # if False, the request is sent to the API. Defaults to True.
        opts: RequestOptions = request_options or {}  # type: ignore

        if offline:
            try:
                text = await local_tokenizers.local_detokenize(self, model=model, tokens=tokens)
                return DetokenizeResponse(text=text)
            except Exception:
                # Fall back to the API, flagging the failure for server-side telemetry.
                opts["additional_headers"] = opts.get("additional_headers", {})
                opts["additional_headers"]["sdk-api-warning-message"] = "offline_tokenizer_failed"

        return await super().detokenize(tokens=tokens, model=model, request_options=opts)
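
    # Async round-trip sketch (model name illustrative):
    #
    #     resp = await co.tokenize(text="hello world", model="command-r")
    #     (await co.detokenize(tokens=resp.tokens, model="command-r")).text   # -> "hello world"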

    async def fetch_tokenizer(self, *, model: str) -> Tokenizer:
        """
        Returns a Hugging Face tokenizer for the given model name.
        """
        return await local_tokenizers.get_hf_tokenizer(self, model)

def _get_api_key_from_environment() -> typing.Optional[str]:
    """
    Retrieves the Cohere API key from specific environment variables.

    CO_API_KEY is preferred (and documented); COHERE_API_KEY is also accepted (but not documented).
    """
    return os.getenv("CO_API_KEY", os.getenv("COHERE_API_KEY"))
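
# Configuration sketch: with CO_API_KEY (or COHERE_API_KEY) set in the environment,
# no api_key argument is needed:
#
#     os.environ["CO_API_KEY"] = "..."   # illustrative; normally exported in the shell
#     co = Client()                      # the key is picked up automatically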