LoRDxdd's picture
Add gpt4free API for Hugging Face
a4b70d9
from __future__ import annotations
import requests
from ..helper import filter_none
from ...typing import AsyncResult, Messages
from ...requests import StreamSession, raise_for_status, sse_stream
from ...providers.response import FinishReason, Usage
from ...errors import MissingAuthError
from ...tools.run_tools import AuthManager
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ... import debug
class Cohere(AsyncGeneratorProvider, ProviderModelMixin):
label = "Cohere API"
url = "https://cohere.com"
login_url = "https://dashboard.cohere.com/api-keys"
api_endpoint = "https://api.cohere.ai/v2/chat"
working = True
active_by_default = True
needs_auth = True
models_needs_auth = True
supports_stream = True
supports_system_message = True
supports_message_history = True
default_model = "command-r-plus"
@classmethod
def get_models(cls, api_key: str = None, **kwargs):
if not cls.models:
if not api_key:
api_key = AuthManager.load_api_key(cls)
url = "https://api.cohere.com/v1/models?page_size=500&endpoint=chat"
models = requests.get(url, headers={"Authorization": f"Bearer {api_key}" }).json().get("models", [])
if models:
cls.live += 1
cls.models = [model.get("name") for model in models if "chat" in model.get("endpoints")]
cls.vision_models = {model.get("name") for model in models if model.get("supports_vision")}
return cls.models
@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
proxy: str = None,
timeout: int = 120,
api_key: str = None,
temperature: float = None,
max_tokens: int = None,
top_k: int = None,
top_p: float = None,
stop: list[str] = None,
stream: bool = True,
headers: dict = None,
impersonate: str = None,
**kwargs
) -> AsyncResult:
if api_key is None:
raise MissingAuthError('Add a "api_key"')
async with StreamSession(
proxy=proxy,
headers=cls.get_headers(stream, api_key, headers),
timeout=timeout,
impersonate=impersonate,
) as session:
data = filter_none(
messages=messages,
model=cls.get_model(model, api_key=api_key),
temperature=temperature,
max_tokens=max_tokens,
k=top_k,
p=top_p,
stop_sequences=stop,
stream=stream,
)
async with session.post(cls.api_endpoint, json=data) as response:
await raise_for_status(response)
if not stream:
data = await response.json()
cls.raise_error(data)
if "text" in data:
yield data["text"]
if "finish_reason" in data:
if data["finish_reason"] == "COMPLETE":
yield FinishReason("stop")
elif data["finish_reason"] == "MAX_TOKENS":
yield FinishReason("length")
if "usage" in data:
tokens = data.get("usage", {}).get("tokens", {})
yield Usage(
prompt_tokens=tokens.get("input_tokens"),
completion_tokens=tokens.get("output_tokens"),
total_tokens=tokens.get("input_tokens", 0) + tokens.get("output_tokens", 0),
billed_units=data.get("usage", {}).get("billed_units")
)
else:
async for data in sse_stream(response):
cls.raise_error(data)
if "type" in data:
if data["type"] == "content-delta":
yield data.get("delta", {}).get("message", {}).get("content", {}).get("text")
elif data["type"] == "message-end":
delta = data.get("delta", {})
if "finish_reason" in delta:
if delta["finish_reason"] == "COMPLETE":
yield FinishReason("stop")
elif delta["finish_reason"] == "MAX_TOKENS":
yield FinishReason("length")
if "usage" in delta:
tokens = delta.get("usage", {}).get("tokens", {})
yield Usage(
prompt_tokens=tokens.get("input_tokens"),
completion_tokens=tokens.get("output_tokens"),
total_tokens=tokens.get("input_tokens", 0) + tokens.get("output_tokens", 0),
billed_units=delta.get("usage", {}).get("billed_units")
)
@classmethod
def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
return {
"Accept": "text/event-stream" if stream else "application/json",
"Content-Type": "application/json",
**(
{"Authorization": f"Bearer {api_key}"}
if api_key is not None else {}
),
**({} if headers is None else headers)
}
@classmethod
def raise_error(cls, data: dict):
if "error" in data:
raise RuntimeError(f"Cohere API Error: {data['error']}")