import os
import time
import warnings
from typing import Dict, List, Optional

import httpx
from smolagents import ApiModel, ChatMessage, Tool


class GeminiApiModel(ApiModel):
    """
    ApiModel implementation using the Google Gemini API via direct HTTP requests.
    """

    def __init__(
        self,
        model_id: str = "gemini-pro",
        api_key: Optional[str] = None,
        **kwargs,
    ):
        """
        Initializes the GeminiApiModel.

        Args:
            model_id (str): The Gemini model ID to use (e.g., "gemini-pro").
            api_key (str, optional): Google AI Studio API key. Defaults to the
                GEMINI_API_KEY environment variable.
            **kwargs: Additional keyword arguments passed to the parent ApiModel.
        """
        self.model_id = model_id
        # Prefer an explicitly passed key; fall back to the environment variable.
        self.api_key = api_key if api_key else os.environ.get("GEMINI_API_KEY")
        if not self.api_key:
            warnings.warn(
                "GEMINI_API_KEY not provided via argument or environment variable. "
                "API calls will likely fail.",
                UserWarning,
            )
        # The Gemini API doesn't inherently support complex role structures or
        # function calling like OpenAI, so we flatten messages for simplicity.
        super().__init__(
            model_id=model_id,
            flatten_messages_as_text=True,  # Flatten messages to a single text prompt
            **kwargs,
        )

    def create_client(self):
        """No dedicated client needed, as we use httpx directly."""
        return None  # Could instead return a reusable httpx.Client.
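    # Flattening example, for illustration only (hypothetical messages):
    #   [{"role": "system", "content": "Be concise."},
    #    {"role": "user", "content": "What is 2 + 2?"}]
    # is joined into the single prompt "Be concise.\nWhat is 2 + 2?" by
    # _messages_to_prompt below, before __call__ appends the answer-format
    # instructions.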
    def __call__(
        self,
        messages: List[Dict[str, str]],
        # Note: the Gemini API may not support stop sequences directly here.
        stop_sequences: Optional[List[str]] = None,
        # Note: the Gemini API doesn't support grammar constraints.
        grammar: Optional[str] = None,
        # Note: the basic Gemini API doesn't support tools.
        tools_to_call_from: Optional[List[Tool]] = None,
        **kwargs,
    ) -> ChatMessage:
        """
        Calls the Google Gemini API with the provided messages.

        Args:
            messages: A list of message dictionaries
                (e.g., [{'role': 'user', 'content': '...'}]).
            stop_sequences: Optional stop sequences (may not be supported).
            grammar: Optional grammar constraint (not supported).
            tools_to_call_from: Optional list of tools (not supported).
            **kwargs: Additional keyword arguments.

        Returns:
            A ChatMessage object containing the response.
        """
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY is not set.")

        # Prepare the prompt by concatenating message content:
        # the basic Gemini Pro API expects a simple text prompt.
        prompt = self._messages_to_prompt(messages)
        prompt += (
            "\n\n"
            "If you have a result from a web search that looks helpful, please "
            "use httpx to get the HTML from the URL listed. "
            "You are a general AI assistant. I will ask you a question. Report "
            "your thoughts, and finish your answer with the following template: "
            "FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a "
            "number OR as few words as possible OR a comma separated list of "
            "numbers and/or strings. If you are asked for a number, don't use "
            "commas to write your number, and don't use units such as $ or "
            "percent signs unless specified otherwise. If you are asked for a "
            "string, don't use articles or abbreviations (e.g. for cities), and "
            "write digits in plain text unless specified otherwise. If you are "
            "asked for a comma separated list, apply the above rules depending "
            "on whether each element of the list is a number or a string."
        )
        # print(f"--- Gemini API prompt: ---\n{prompt}\n--- End of prompt ---")

        url = (
            "https://generativelanguage.googleapis.com/v1beta/models/"
            f"{self.model_id}:generateContent?key={self.api_key}"
        )
        headers = {"Content-Type": "application/json"}

        # Construct the payload according to the Gemini API requirements.
        data = {"contents": [{"parts": [{"text": prompt}]}]}

        # Add a generation config if provided via kwargs (optional).
        generation_config = {}
        if "temperature" in kwargs:
            generation_config["temperature"] = kwargs["temperature"]
        if "max_output_tokens" in kwargs:
            generation_config["maxOutputTokens"] = kwargs["max_output_tokens"]
        # Add other relevant config parameters here if needed.
        if generation_config:
            data["generationConfig"] = generation_config

        # Handle stop sequences if provided. This is best-effort: check the
        # Gemini API docs for formal support of 'stopSequences' in
        # generationConfig.
        if stop_sequences:
            data.setdefault("generationConfig", {})["stopSequences"] = stop_sequences

        raw_response = None
        try:
            response = httpx.post(url, headers=headers, json=data, timeout=120.0)
            time.sleep(6)  # Add a delay to respect rate limits.
            response.raise_for_status()
            response_json = response.json()
            raw_response = response_json  # Store the raw response.
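            # For reference, a successful generateContent response typically
            # resembles this abridged sketch (illustrative, not a schema):
            #   {
            #     "candidates": [
            #       {"content": {"parts": [{"text": "..."}], "role": "model"}}
            #     ],
            #     "usageMetadata": {"promptTokenCount": 12, "candidatesTokenCount": 34}
            #   }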
            # Parse the response; adjust if the actual Gemini API structure differs.
            if "candidates" in response_json and response_json["candidates"]:
                part = response_json["candidates"][0]["content"]["parts"][0]
                if "text" in part:
                    content = part["text"]
                    # If "FINAL ANSWER: " is present, keep only what follows it.
                    final_answer_marker = "FINAL ANSWER: "
                    if final_answer_marker in content:
                        content = content.split(final_answer_marker)[-1].strip()
                    # The basic generateContent call doesn't reliably return
                    # usage data in the main response, so default the token
                    # counts to 0. If usage data becomes available in
                    # response_json, parse it here, e.g.:
                    #   if response_json.get("usageMetadata"):
                    #       self.last_input_token_count = response_json["usageMetadata"].get("promptTokenCount", 0)
                    #       self.last_output_token_count = response_json["usageMetadata"].get("candidatesTokenCount", 0)
                    self.last_input_token_count = 0
                    self.last_output_token_count = 0
                    return ChatMessage(role="assistant", content=content, raw=raw_response)

            # Handle cases where the expected response structure isn't found.
            error_content = f"Error or unexpected response format: {response_json}"
            return ChatMessage(role="assistant", content=error_content, raw=raw_response)

        except httpx.RequestError as exc:
            error_content = f"An error occurred while requesting {exc.request.url!r}: {exc}"
            return ChatMessage(role="assistant", content=error_content, raw={"error": str(exc)})
        except httpx.HTTPStatusError as exc:
            error_content = (
                f"Error response {exc.response.status_code} while requesting "
                f"{exc.request.url!r}: {exc.response.text}"
            )
            return ChatMessage(
                role="assistant",
                content=error_content,
                raw={
                    "error": str(exc),
                    "status_code": exc.response.status_code,
                    "response_text": exc.response.text,
                },
            )
        except Exception as e:
            error_content = f"An unexpected error occurred: {e}"
            return ChatMessage(role="assistant", content=error_content, raw={"error": str(e)})

    def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Converts a list of messages into a single string prompt."""
        # Simple concatenation; could be made role-aware if needed.
        # str() guards against non-string content (though it should be a string).
        return "\n".join(str(msg.get("content", "")) for msg in messages)
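

# Minimal usage sketch, assuming a valid GEMINI_API_KEY in the environment;
# the model id and question below are illustrative only.
if __name__ == "__main__":
    model = GeminiApiModel(model_id="gemini-pro")
    reply = model([{"role": "user", "content": "What is the capital of France?"}])
    print(reply.content)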