File size: 5,851 Bytes
a4b70d9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
from __future__ import annotations
import requests
from ..helper import filter_none
from ...typing import AsyncResult, Messages
from ...requests import StreamSession, raise_for_status, sse_stream
from ...providers.response import FinishReason, Usage
from ...errors import MissingAuthError
from ...tools.run_tools import AuthManager
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ... import debug
class Cohere(AsyncGeneratorProvider, ProviderModelMixin):
    """Provider that sends chat requests to the Cohere endpoint in ``api_endpoint``.

    Model names are fetched lazily from Cohere's model listing and cached on
    the class. Chat completions are yielded as text chunks, followed by a
    ``FinishReason`` and a ``Usage`` object when the API reports them.
    """

    label = "Cohere API"
    url = "https://cohere.com"
    login_url = "https://dashboard.cohere.com/api-keys"
    api_endpoint = "https://api.cohere.ai/v2/chat"
    working = True
    active_by_default = True
    needs_auth = True
    models_needs_auth = True
    supports_stream = True
    supports_system_message = True
    supports_message_history = True
    default_model = "command-r-plus"

    # Cohere finish reasons that are translated for callers. Any other value
    # is ignored and no FinishReason is emitted for it.
    _finish_reasons = {"COMPLETE": "stop", "MAX_TOKENS": "length"}

    @classmethod
    def get_models(cls, api_key: str = None, timeout: int = 30, **kwargs):
        """Return the chat-capable model names, fetching them on first use.

        The list is stored in ``cls.models``; the HTTP request is only made
        while that list is still empty.

        Args:
            api_key: Cohere API key. When falsy, the key is looked up with
                ``AuthManager.load_api_key(cls)``.
            timeout: Seconds to wait for the model listing before
                ``requests`` gives up, so a stalled connection cannot block
                the caller forever.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The cached list of model names (possibly empty).
        """
        if not cls.models:
            if not api_key:
                api_key = AuthManager.load_api_key(cls)
            url = "https://api.cohere.com/v1/models?page_size=500&endpoint=chat"
            # Do not send a literal "Bearer None" when no key is available.
            auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
            response = requests.get(url, headers=auth_headers, timeout=timeout)
            # An error payload has no "models" entry; treat it as an empty list.
            models = response.json().get("models") or []
            if models:
                cls.live += 1
            cls.models = [
                model.get("name")
                for model in models
                # "endpoints" may be missing or null; guard the membership test.
                if "chat" in (model.get("endpoints") or [])
            ]
            cls.vision_models = {
                model.get("name") for model in models if model.get("supports_vision")
            }
        return cls.models

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        timeout: int = 120,
        api_key: str = None,
        temperature: float = None,
        max_tokens: int = None,
        top_k: int = None,
        top_p: float = None,
        stop: list[str] = None,
        stream: bool = True,
        headers: dict = None,
        impersonate: str = None,
        **kwargs
    ) -> AsyncResult:
        """Send a chat request and yield the reply.

        Args:
            model: Requested model name, resolved through ``cls.get_model``.
            messages: Conversation sent as the ``messages`` field.
            proxy: Forwarded to ``StreamSession``.
            timeout: Forwarded to ``StreamSession``.
            api_key: Cohere API key; required.
            temperature: Sent as ``temperature``.
            max_tokens: Sent as ``max_tokens``.
            top_k: Sent as ``k``.
            top_p: Sent as ``p``.
            stop: Sent as ``stop_sequences``.
            stream: When true, request a server-sent event stream and yield
                text deltas as they arrive; otherwise yield the full reply.
            headers: Extra headers merged over the defaults.
            impersonate: Forwarded to ``StreamSession``.
            **kwargs: Accepted for interface compatibility; unused here.

        Yields:
            Text chunks (``str``), then a ``FinishReason`` and a ``Usage``
            object when the response contains them.

        Raises:
            MissingAuthError: If no API key is given.
            RuntimeError: If a response payload contains an ``error`` entry.
        """
        # An empty string is as unusable as None and would be sent as "Bearer ".
        if not api_key:
            raise MissingAuthError('Add a "api_key"')
        async with StreamSession(
            proxy=proxy,
            headers=cls.get_headers(stream, api_key, headers),
            timeout=timeout,
            impersonate=impersonate,
        ) as session:
            # filter_none presumably drops the None-valued options (per its
            # name) so only explicitly set parameters are sent.
            payload = filter_none(
                messages=messages,
                model=cls.get_model(model, api_key=api_key),
                temperature=temperature,
                max_tokens=max_tokens,
                k=top_k,
                p=top_p,
                stop_sequences=stop,
                stream=stream,
            )
            async with session.post(cls.api_endpoint, json=payload) as response:
                await raise_for_status(response)
                if not stream:
                    data = await response.json()
                    cls.raise_error(data)
                    text = cls._extract_text(data)
                    if text:
                        yield text
                    for item in cls._completion_metadata(data):
                        yield item
                    return
                async for event in sse_stream(response):
                    cls.raise_error(event)
                    event_type = event.get("type")
                    delta = event.get("delta") or {}
                    if event_type == "content-delta":
                        content = (delta.get("message") or {}).get("content") or {}
                        text = content.get("text")
                        # Deltas without text must not surface as None chunks.
                        if text:
                            yield text
                    elif event_type == "message-end":
                        for item in cls._completion_metadata(delta):
                            yield item

    @staticmethod
    def _extract_text(data: dict) -> str:
        """Return the reply text of a non-streamed chat response.

        A top-level ``text`` field is used when present. Otherwise the text
        parts of ``message.content`` (the shape returned by the v2 chat
        endpoint) are concatenated.
        """
        if "text" in data:
            return data["text"]
        content = (data.get("message") or {}).get("content") or []
        return "".join(
            part.get("text") or ""
            for part in content
            if isinstance(part, dict) and part.get("type", "text") == "text"
        )

    @classmethod
    def _completion_metadata(cls, data: dict) -> list:
        """Build the ``FinishReason`` / ``Usage`` items described by *data*.

        *data* is either a full non-streamed response or the ``delta`` of a
        ``message-end`` stream event; both carry ``finish_reason`` and
        ``usage`` under the same keys.
        """
        items = []
        reason = cls._finish_reasons.get(data.get("finish_reason"))
        if reason is not None:
            items.append(FinishReason(reason))
        usage = data.get("usage")
        if usage is not None:
            tokens = usage.get("tokens") or {}
            input_tokens = tokens.get("input_tokens")
            output_tokens = tokens.get("output_tokens")
            items.append(Usage(
                prompt_tokens=input_tokens,
                completion_tokens=output_tokens,
                # Missing or null counts are treated as zero in the total.
                total_tokens=(input_tokens or 0) + (output_tokens or 0),
                billed_units=usage.get("billed_units"),
            ))
        return items

    @classmethod
    def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
        """Build the request headers.

        Args:
            stream: Selects the ``Accept`` type: event stream or JSON.
            api_key: Added as a bearer token unless it is ``None``.
            headers: Extra headers; they override the defaults on conflict.

        Returns:
            A new dict with the merged headers.
        """
        result = {
            "Accept": "text/event-stream" if stream else "application/json",
            "Content-Type": "application/json",
        }
        if api_key is not None:
            result["Authorization"] = f"Bearer {api_key}"
        if headers is not None:
            result.update(headers)
        return result

    @classmethod
    def raise_error(cls, data: dict):
        """Raise ``RuntimeError`` when *data* contains an ``error`` entry."""
        if "error" in data:
            raise RuntimeError(f"Cohere API Error: {data['error']}")