|
from __future__ import annotations |
|
|
|
import requests |
|
|
|
from ..helper import filter_none |
|
from ...typing import AsyncResult, Messages |
|
from ...requests import StreamSession, raise_for_status, sse_stream |
|
from ...providers.response import FinishReason, Usage |
|
from ...errors import MissingAuthError |
|
from ...tools.run_tools import AuthManager |
|
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin |
|
from ... import debug |
|
|
|
class Cohere(AsyncGeneratorProvider, ProviderModelMixin):
    """Provider that talks to the Cohere chat API (``/v2/chat``).

    Supports both streamed (server-sent events) and non-streamed completions
    and yields text chunks followed by ``FinishReason`` / ``Usage`` objects
    when the API reports them.
    """

    label = "Cohere API"
    url = "https://cohere.com"
    login_url = "https://dashboard.cohere.com/api-keys"
    api_endpoint = "https://api.cohere.ai/v2/chat"
    working = True
    active_by_default = True
    needs_auth = True
    models_needs_auth = True
    supports_stream = True
    supports_system_message = True
    supports_message_history = True

    default_model = "command-r-plus"

    # Cohere finish reasons translated to the values passed to FinishReason.
    # Reasons that are not listed here produce no FinishReason at all.
    _finish_reasons = {
        "COMPLETE": "stop",
        "STOP_SEQUENCE": "stop",
        "MAX_TOKENS": "length",
    }

    @classmethod
    def get_models(cls, api_key: str = None, **kwargs):
        """Return the chat-capable model names, fetching them on first use.

        The list is requested from Cohere only while ``cls.models`` is empty;
        afterwards the cached class attribute is returned.

        Args:
            api_key: Cohere API key. When falsy, the key is looked up via
                ``AuthManager.load_api_key(cls)``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``cls.models`` (left unchanged when the API returns no models).

        Raises:
            requests.RequestException: On network failure or timeout.
        """
        if not cls.models:
            if not api_key:
                api_key = AuthManager.load_api_key(cls)
            url = "https://api.cohere.com/v1/models?page_size=500&endpoint=chat"
            # Only send the header when a key exists; otherwise the literal
            # string "Bearer None" would be transmitted.
            headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
            # This is a blocking call, so bound it to avoid hanging forever.
            response = requests.get(url, headers=headers, timeout=30)
            models = response.json().get("models", [])
            if models:
                # NOTE(review): `live` is inherited; presumably it counts
                # successful live lookups - confirm in the base class.
                cls.live += 1
                cls.models = [
                    model.get("name")
                    for model in models
                    # `endpoints` may be missing or null for some entries.
                    if "chat" in (model.get("endpoints") or [])
                ]
                cls.vision_models = {
                    model.get("name")
                    for model in models
                    if model.get("supports_vision")
                }
        return cls.models

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        timeout: int = 120,
        api_key: str = None,
        temperature: float = None,
        max_tokens: int = None,
        top_k: int = None,
        top_p: float = None,
        stop: list[str] = None,
        stream: bool = True,
        headers: dict = None,
        impersonate: str = None,
        **kwargs
    ) -> AsyncResult:
        """Send a chat request to Cohere and yield the response pieces.

        Args:
            model: Requested model name; resolved through ``cls.get_model``.
            messages: Conversation messages, sent as-is in the request body.
            proxy: Proxy passed to ``StreamSession``.
            timeout: Timeout passed to ``StreamSession``.
            api_key: Cohere API key (required).
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.
            top_k: Sent to the API as ``k``.
            top_p: Sent to the API as ``p``.
            stop: Sent to the API as ``stop_sequences``.
            stream: Request a server-sent-event stream when true.
            headers: Extra request headers; they override the defaults.
            impersonate: Passed to ``StreamSession``.
            **kwargs: Accepted for interface compatibility; unused.

        Yields:
            Text chunks (``str``), then a ``FinishReason`` when the reported
            reason is a known one, then a ``Usage`` when usage data is present.

        Raises:
            MissingAuthError: If no ``api_key`` is supplied.
            RuntimeError: If a response payload contains an ``error`` key.
        """
        # An empty key would only produce a useless "Bearer " header.
        if not api_key:
            raise MissingAuthError('Add a "api_key"')

        async with StreamSession(
            proxy=proxy,
            headers=cls.get_headers(stream, api_key, headers),
            timeout=timeout,
            impersonate=impersonate,
        ) as session:
            # filter_none drops the parameters the caller left unset.
            payload = filter_none(
                messages=messages,
                model=cls.get_model(model, api_key=api_key),
                temperature=temperature,
                max_tokens=max_tokens,
                k=top_k,
                p=top_p,
                stop_sequences=stop,
                stream=stream,
            )
            async with session.post(cls.api_endpoint, json=payload) as response:
                await raise_for_status(response)

                if not stream:
                    result = await response.json()
                    cls.raise_error(result)
                    for text in cls._iter_response_text(result):
                        yield text
                    finish = cls._map_finish_reason(result.get("finish_reason"))
                    if finish is not None:
                        yield finish
                    if "usage" in result:
                        yield cls._build_usage(result.get("usage"))
                else:
                    async for event in sse_stream(response):
                        cls.raise_error(event)
                        event_type = event.get("type")
                        if event_type == "content-delta":
                            text = (
                                event.get("delta", {})
                                .get("message", {})
                                .get("content", {})
                                .get("text")
                            )
                            # Deltas without text must not surface as None.
                            if text:
                                yield text
                        elif event_type == "message-end":
                            delta = event.get("delta", {})
                            finish = cls._map_finish_reason(delta.get("finish_reason"))
                            if finish is not None:
                                yield finish
                            if "usage" in delta:
                                yield cls._build_usage(delta.get("usage"))

    @staticmethod
    def _iter_response_text(data: dict):
        """Yield the text of a non-streamed chat response.

        Handles a top-level ``text`` field as well as the v2 layout, where
        the text lives in ``message.content`` as a list of typed parts.
        """
        legacy_text = data.get("text")
        if legacy_text:
            yield legacy_text
            return
        content = (data.get("message") or {}).get("content") or []
        if isinstance(content, str):
            yield content
            return
        for part in content:
            # Skip non-text parts and parts without any text.
            if isinstance(part, dict) and part.get("type", "text") == "text" and part.get("text"):
                yield part["text"]

    @classmethod
    def _map_finish_reason(cls, reason):
        """Return a ``FinishReason`` for a known Cohere reason, else ``None``."""
        mapped = cls._finish_reasons.get(reason)
        return None if mapped is None else FinishReason(mapped)

    @staticmethod
    def _build_usage(usage):
        """Build a ``Usage`` object from Cohere's ``usage`` payload."""
        usage = usage or {}
        tokens = usage.get("tokens") or {}
        prompt_tokens = tokens.get("input_tokens")
        completion_tokens = tokens.get("output_tokens")
        return Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            # Treat missing or null counts as zero so the sum cannot fail.
            total_tokens=(prompt_tokens or 0) + (completion_tokens or 0),
            billed_units=usage.get("billed_units"),
        )

    @classmethod
    def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
        """Build the request headers.

        Args:
            stream: Selects ``text/event-stream`` (true) or
                ``application/json`` (false) for the ``Accept`` header.
            api_key: Added as a bearer ``Authorization`` header when not None.
            headers: Extra headers; merged last, so they override the defaults.

        Returns:
            The merged header dictionary.
        """
        return {
            "Accept": "text/event-stream" if stream else "application/json",
            "Content-Type": "application/json",
            **(
                {"Authorization": f"Bearer {api_key}"}
                if api_key is not None else {}
            ),
            **({} if headers is None else headers)
        }

    @classmethod
    def raise_error(cls, data: dict):
        """Raise ``RuntimeError`` if the payload contains an ``error`` key."""
        if "error" in data:
            raise RuntimeError(f"Cohere API Error: {data['error']}")