from __future__ import annotations import base64 import json import requests from typing import Optional from aiohttp import ClientSession, BaseConnector from ...typing import AsyncResult, Messages, MediaListType from ...image import to_bytes, is_data_an_media from ...errors import MissingAuthError, ModelNotFoundError from ...requests import raise_for_status, iter_lines from ...providers.response import Usage, FinishReason from ...image.copy_images import save_response_media from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..helper import get_connector, to_string, format_media_prompt, get_system_prompt from ... import debug class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): label = "Google Gemini API" url = "https://ai.google.dev" login_url = "https://aistudio.google.com/u/0/apikey" api_base = "https://generativelanguage.googleapis.com/v1beta" active_by_default = True working = True supports_message_history = True supports_system_message = True needs_auth = True default_model = "gemini-2.5-flash" default_vision_model = default_model fallback_models = [ "gemini-2.0-flash", "gemini-2.0-flash-lite", "gemini-2.0-flash-thinking-exp", "gemini-2.5-flash", "gemma-3-1b-it", "gemma-3-12b-it", "gemma-3-27b-it", "gemma-3-4b-it", "gemma-3n-e2b-it", "gemma-3n-e4b-it", ] @classmethod def get_models(cls, api_key: str = None, api_base: str = api_base) -> list[str]: if not api_key: return cls.fallback_models if not cls.models: try: url = f"{cls.api_base if not api_base else api_base}/models" response = requests.get(url, params={"key": api_key}) raise_for_status(response) data = response.json() cls.models = [ model.get("name").split("/").pop() for model in data.get("models") if "generateContent" in model.get("supportedGenerationMethods") ] cls.models.sort() cls.live += 1 except Exception as e: debug.error(e) if api_key is not None: raise MissingAuthError("Invalid API key") return cls.fallback_models return cls.models @classmethod async def create_async_generator( cls, model: str, messages: Messages, stream: bool = False, proxy: str = None, api_key: str = None, api_base: str = api_base, use_auth_header: bool = False, media: MediaListType = None, tools: Optional[list] = None, connector: BaseConnector = None, **kwargs ) -> AsyncResult: if not api_key: raise MissingAuthError('Add a "api_key"') try: model = cls.get_model(model, api_key=api_key, api_base=api_base) except ModelNotFoundError: pass headers = params = None if use_auth_header: headers = {"Authorization": f"Bearer {api_key}"} else: params = {"key": api_key} method = "streamGenerateContent" if stream else "generateContent" url = f"{api_base.rstrip('/')}/models/{model}:{method}" async with ClientSession(headers=headers, connector=get_connector(connector, proxy)) as session: contents = [ { "role": "model" if message["role"] == "assistant" else "user", "parts": [{"text": to_string(message["content"])}] } for message in messages if message["role"] not in ["system", "developer"] ] if media is not None: if not contents: contents.append({"role": "user", "parts": []}) for media_data, filename in media: media_data = to_bytes(media_data) contents[-1]["parts"].append({ "inline_data": { "mime_type": is_data_an_media(media_data, filename), "data": base64.b64encode(media_data).decode() } }) responseModalities = {"responseModalities": ["AUDIO"]} if "tts" in model else {} data = { "contents": contents, "generationConfig": { "stopSequences": kwargs.get("stop"), "temperature": kwargs.get("temperature"), "maxOutputTokens": kwargs.get("max_tokens"), "topP": kwargs.get("top_p"), "topK": kwargs.get("top_k"), **responseModalities, }, "tools": [{ "function_declarations": [{ "name": tool["function"]["name"], "description": tool["function"]["description"], "parameters": { "type": "object", "properties": {key: { "type": value["type"], "description": value["title"] } for key, value in tool["function"]["parameters"]["properties"].items()} }, } for tool in tools] }] if tools else None } system_prompt = get_system_prompt(messages) if system_prompt: data["system_instruction"] = {"parts": {"text": system_prompt}} async with session.post(url, params=params, json=data) as response: if not response.ok: data = await response.json() data = data[0] if isinstance(data, list) else data raise RuntimeError(f"Response {response.status}: {data['error']['message']}") if stream: lines = [] buffer = b"" async for chunk in iter_lines(response.content.iter_any()): buffer += chunk if chunk == b"[{": lines = [b"{"] elif chunk == b"," or chunk == b"]": try: data = json.loads(b"".join(lines)) content = data["candidates"][0]["content"] if "parts" in content and content["parts"]: if "text" in content["parts"][0]: yield content["parts"][0]["text"] elif "inlineData" in content["parts"][0]: async for media in save_response_media( content["parts"][0]["inlineData"], format_media_prompt(messages) ): yield media if "finishReason" in data["candidates"][0]: yield FinishReason(data["candidates"][0]["finishReason"].lower()) usage = data.get("usageMetadata") if usage: yield Usage( prompt_tokens=usage.get("promptTokenCount"), completion_tokens=usage.get("candidatesTokenCount"), total_tokens=usage.get("totalTokenCount") ) except Exception as e: raise RuntimeError(f"Read chunk failed") from e lines = [] else: lines.append(chunk) else: data = await response.json() candidate = data["candidates"][0] if "content" in candidate: content = candidate["content"] if "parts" in content and content["parts"]: for part in content["parts"]: if "text" in part: yield part["text"] elif "inlineData" in part: async for media in save_response_media( part["inlineData"], format_media_prompt(messages) ): yield media if "finishReason" in candidate: yield FinishReason(candidate["finishReason"].lower())