"""Gemini-backed Toolformer: wraps google.generativeai chat sessions with per-token cost tracking."""
import datetime
import os
import time
import traceback
from random import random
from typing import List

import google.generativeai as genai
from google.generativeai.generative_models import ChatSession

from toolformers.base import Conversation, Tool, Toolformer
from utils import register_cost

# Configure the Gemini client once at import time.
# NOTE(review): raises KeyError if GOOGLE_API_KEY is unset — intentional fail-fast.
genai.configure(api_key=os.environ['GOOGLE_API_KEY'])
# Per-token USD pricing, keyed by model name then usage field.
# NOTE(review): only 'gemini-1.5-pro' is priced, although the factory also
# accepts 'gemini-1.5-flash' — a flash conversation would KeyError on billing.
COSTS = {
    'gemini-1.5-pro': {
        'prompt_tokens': 1.25e-6,
        'completion_tokens': 5e-6
    }
}
class GeminiConversation(Conversation):
    """A chat session with a Gemini model that tracks token costs.

    Wraps a google.generativeai ChatSession; every successful exchange has
    its USD cost (per COSTS) registered via utils.register_cost.
    """

    def __init__(self, model_name, chat_agent: ChatSession, category=None):
        self.model_name = model_name  # key into COSTS for pricing
        self.chat_agent = chat_agent
        self.category = category  # cost-accounting bucket passed to register_cost

    def chat(self, message, role='user', print_output=True):
        """Send one message and return the model's text reply.

        Retries up to 5 times: rate-limit (429) errors wait with randomized
        exponential backoff, transient 500 errors wait a flat 5 seconds.
        Returns the literal string 'No response' when Gemini raises the
        "candidates[0]" error (i.e. it produced no candidates).

        Raises:
            RuntimeError: if all retry attempts are exhausted.
            Exception: any non-retryable API error is re-raised as-is.
        """
        backoff_lower = 30
        backoff_higher = 60

        response = None
        for attempt in range(5):
            try:
                response = self.chat_agent.send_message({
                    'role': role,
                    'parts': [
                        message
                    ]
                })
                break
            except Exception as e:
                print(e)
                if '429' in str(e):
                    print('Rate limit exceeded. Waiting with random exponential backoff.')
                    if attempt < 4:
                        time.sleep(random() * (backoff_higher - backoff_lower) + backoff_lower)
                        backoff_lower *= 2
                        backoff_higher *= 2
                elif 'candidates[0]' in traceback.format_exc():
                    # When Gemini has nothing to say, it raises an error
                    # whose traceback mentions candidates[0].
                    print('No response')
                    return 'No response'
                elif '500' in str(e):
                    # Gemini occasionally returns spurious 500 errors; wait
                    # briefly and retry.
                    print('500 error')
                    time.sleep(5)
                    traceback.print_exc()
                else:
                    raise e

        if response is None:
            # BUG FIX: the original fell through here and crashed with
            # UnboundLocalError on `response` when every retry failed;
            # fail loudly with a meaningful error instead.
            raise RuntimeError('Gemini request failed after 5 attempts')

        usage_info = {
            'prompt_tokens': response.usage_metadata.prompt_token_count,
            'completion_tokens': response.usage_metadata.candidates_token_count
        }
        # NOTE(review): KeyError here if self.model_name has no COSTS entry
        # (e.g. 'gemini-1.5-flash') — confirm intended pricing coverage.
        total_cost = sum(
            COSTS[self.model_name][name] * usage_info[name]
            for name in ('prompt_tokens', 'completion_tokens')
        )
        register_cost(self.category, total_cost)

        reply = response.text
        if print_output:
            print(reply)
        return reply
class GeminiToolformer(Toolformer):
    """Toolformer backed by a Gemini model with automatic function calling."""

    def __init__(self, model_name):
        self.model_name = model_name  # e.g. 'gemini-1.5-pro'

    def new_conversation(self, system_prompt, tools: List[Tool], category=None) -> Conversation:
        """Start a new chat session with the given system prompt and tools.

        Args:
            system_prompt: system instruction for the Gemini model.
            tools: tools exposed to the model (converted via as_gemini_tool).
            category: optional cost-accounting category for the conversation.

        Returns:
            A GeminiConversation wrapping the freshly started chat.
        """
        print('Tools:')
        print('\n'.join([str(tool.as_openai_info()) for tool in tools]))

        model = genai.GenerativeModel(
            model_name=self.model_name,
            system_instruction=system_prompt,
            tools=[tool.as_gemini_tool() for tool in tools]
        )

        # Automatic function calling lets the SDK execute tool calls and
        # feed results back to the model without manual round-trips.
        chat = model.start_chat(enable_automatic_function_calling=True)

        return GeminiConversation(self.model_name, chat, category)
def make_gemini_toolformer(model_name):
    """Return a GeminiToolformer for a supported model.

    Args:
        model_name: one of 'gemini-1.5-flash' or 'gemini-1.5-pro'.

    Raises:
        ValueError: if model_name is not a supported Gemini model.
    """
    if model_name not in ('gemini-1.5-flash', 'gemini-1.5-pro'):
        raise ValueError(f"Unknown model name: {model_name}")

    return GeminiToolformer(model_name)