File size: 5,851 Bytes
a4b70d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from __future__ import annotations

import requests

from ..helper import filter_none
from ...typing import AsyncResult, Messages
from ...requests import StreamSession, raise_for_status, sse_stream
from ...providers.response import FinishReason, Usage
from ...errors import MissingAuthError
from ...tools.run_tools import AuthManager
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ... import debug

class Cohere(AsyncGeneratorProvider, ProviderModelMixin):
    label = "Cohere API"
    url = "https://cohere.com"
    login_url = "https://dashboard.cohere.com/api-keys"
    api_endpoint = "https://api.cohere.ai/v2/chat"
    working = True
    active_by_default = True
    needs_auth = True
    models_needs_auth = True
    supports_stream = True
    supports_system_message = True
    supports_message_history = True
    
    default_model = "command-r-plus"

    @classmethod
    def get_models(cls, api_key: str = None, **kwargs):
        if not cls.models:
            if not api_key:
                api_key = AuthManager.load_api_key(cls)
            url = "https://api.cohere.com/v1/models?page_size=500&endpoint=chat"
            models = requests.get(url, headers={"Authorization": f"Bearer {api_key}" }).json().get("models", [])
            if models:
                cls.live += 1
            cls.models = [model.get("name") for model in models if "chat" in model.get("endpoints")]
            cls.vision_models = {model.get("name") for model in models if model.get("supports_vision")}
        return cls.models

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        timeout: int = 120,
        api_key: str = None,
        temperature: float = None,
        max_tokens: int = None,
        top_k: int = None,
        top_p: float = None,
        stop: list[str] = None,
        stream: bool = True,
        headers: dict = None,
        impersonate: str = None,
        **kwargs
    ) -> AsyncResult:
        if api_key is None:
            raise MissingAuthError('Add a "api_key"')

        async with StreamSession(
            proxy=proxy,
            headers=cls.get_headers(stream, api_key, headers),
            timeout=timeout,
            impersonate=impersonate,
        ) as session:
            data = filter_none(
                messages=messages,
                model=cls.get_model(model, api_key=api_key),
                temperature=temperature,
                max_tokens=max_tokens,
                k=top_k,
                p=top_p,
                stop_sequences=stop,
                stream=stream,
            )
            async with session.post(cls.api_endpoint, json=data) as response:
                await raise_for_status(response)
                
                if not stream:
                    data = await response.json()
                    cls.raise_error(data)
                    if "text" in data:
                        yield data["text"]
                    if "finish_reason" in data:
                        if data["finish_reason"] == "COMPLETE":
                            yield FinishReason("stop")
                        elif data["finish_reason"] == "MAX_TOKENS":
                            yield FinishReason("length")
                    if "usage" in data:
                        tokens = data.get("usage", {}).get("tokens", {})
                        yield Usage(
                            prompt_tokens=tokens.get("input_tokens"),
                            completion_tokens=tokens.get("output_tokens"),
                            total_tokens=tokens.get("input_tokens", 0) + tokens.get("output_tokens", 0),
                            billed_units=data.get("usage", {}).get("billed_units")
                        )
                else:
                    async for data in sse_stream(response):
                        cls.raise_error(data)
                        if "type" in data:
                            if data["type"] == "content-delta":
                                yield data.get("delta", {}).get("message", {}).get("content", {}).get("text")
                            elif data["type"] == "message-end":
                                delta = data.get("delta", {})
                                if "finish_reason" in delta:
                                    if delta["finish_reason"] == "COMPLETE":
                                        yield FinishReason("stop")
                                    elif delta["finish_reason"] == "MAX_TOKENS":
                                        yield FinishReason("length")
                                if "usage" in delta:
                                    tokens = delta.get("usage", {}).get("tokens", {})
                                    yield Usage(
                                        prompt_tokens=tokens.get("input_tokens"),
                                        completion_tokens=tokens.get("output_tokens"),
                                        total_tokens=tokens.get("input_tokens", 0) + tokens.get("output_tokens", 0),
                                        billed_units=delta.get("usage", {}).get("billed_units")
                                    )

    @classmethod
    def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
        return {
            "Accept": "text/event-stream" if stream else "application/json",
            "Content-Type": "application/json",
            **(
                {"Authorization": f"Bearer {api_key}"}
                if api_key is not None else {}
            ),
            **({} if headers is None else headers)
        }

    @classmethod
    def raise_error(cls, data: dict):
        if "error" in data:
            raise RuntimeError(f"Cohere API Error: {data['error']}")