import datetime
import os
import time
import traceback
from random import random
from typing import List

import google.generativeai as genai
from google.generativeai.generative_models import ChatSession

from toolformers.base import Conversation, Tool, Toolformer
from utils import register_cost

genai.configure(api_key=os.environ['GOOGLE_API_KEY'])

COSTS = {
    'gemini-1.5-pro': {
        'prompt_tokens': 1.25e-6,
        'completion_tokens': 5e-6
    }
}
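
# Worked example of how these per-token rates are applied below: a gemini-1.5-pro call
# with 1,000 prompt tokens and 200 completion tokens costs
# 1000 * 1.25e-6 + 200 * 5e-6 = 0.00225 (in the currency the rates are quoted in,
# presumably USD). Note that 'gemini-1.5-flash' is accepted by make_gemini_toolformer
# below but has no entry here, so cost tracking for it would raise a KeyError until
# its rates are added.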


class GeminiConversation(Conversation):
    def __init__(self, model_name, chat_agent: ChatSession, category=None):
        self.model_name = model_name
        self.chat_agent = chat_agent
        self.category = category

    def chat(self, message, role='user', print_output=True):
        # Only used by the commented-out send_usage_to_db call below.
        agent_id = os.environ.get('AGENT_ID', None)
        time_start = datetime.datetime.now()

        # Retry with randomized exponential backoff on rate limits and a short fixed
        # wait on transient server errors; give up after 5 attempts.
        exponential_backoff_lower = 30
        exponential_backoff_higher = 60
        response = None
        for i in range(5):
            try:
                response = self.chat_agent.send_message({
                    'role': role,
                    'parts': [
                        message
                    ]
                })
                break
            except Exception as e:
                print(e)
                if '429' in str(e):
                    print('Rate limit exceeded. Waiting with random exponential backoff.')
                    if i < 4:
                        time.sleep(random() * (exponential_backoff_higher - exponential_backoff_lower) + exponential_backoff_lower)
                        exponential_backoff_lower *= 2
                        exponential_backoff_higher *= 2
                elif 'candidates[0]' in traceback.format_exc():
                    # When Gemini has nothing to say, it raises an error with this message
                    print('No response')
                    return 'No response'
                elif '500' in str(e):
                    # Sometimes Gemini returns a 500 error for no apparent reason. Retry after a short wait.
                    print('500 error')
                    time.sleep(5)
                    traceback.print_exc()
                else:
                    raise e

        if response is None:
            raise RuntimeError('No response from Gemini after 5 attempts.')

        time_end = datetime.datetime.now()

        usage_info = {
            'prompt_tokens': response.usage_metadata.prompt_token_count,
            'completion_tokens': response.usage_metadata.candidates_token_count
        }

        # Convert the reported token usage into a cost and register it under this
        # conversation's category.
        total_cost = 0
        for cost_name in ['prompt_tokens', 'completion_tokens']:
            total_cost += COSTS[self.model_name][cost_name] * usage_info[cost_name]

        register_cost(self.category, total_cost)

        #send_usage_to_db(
        #    usage_info,
        #    time_start,
        #    time_end,
        #    agent_id,
        #    self.category,
        #    self.model_name
        #)

        reply = response.text

        if print_output:
            print(reply)

        return reply


class GeminiToolformer(Toolformer):
    def __init__(self, model_name):
        self.model_name = model_name

    def new_conversation(self, system_prompt, tools: List[Tool], category=None) -> Conversation:
        print('Tools:')
        print('\n'.join([str(tool.as_openai_info()) for tool in tools]))

        model = genai.GenerativeModel(
            model_name=self.model_name,
            system_instruction=system_prompt,
            tools=[tool.as_gemini_tool() for tool in tools]
        )

        chat = model.start_chat(enable_automatic_function_calling=True)

        return GeminiConversation(self.model_name, chat, category)


def make_gemini_toolformer(model_name):
    if model_name not in ['gemini-1.5-flash', 'gemini-1.5-pro']:
        raise ValueError(f"Unknown model name: {model_name}")

    return GeminiToolformer(model_name)
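

# Minimal usage sketch, under the following assumptions: GOOGLE_API_KEY is set in the
# environment, and register_cost from utils accepts the hypothetical 'demo' category
# used here. Tools are omitted because Tool construction is defined elsewhere in
# toolformers.base.
if __name__ == '__main__':
    toolformer = make_gemini_toolformer('gemini-1.5-pro')
    conversation = toolformer.new_conversation(
        system_prompt='You are a helpful assistant.',
        tools=[],
        category='demo'
    )
    # chat() prints the reply (print_output defaults to True) and registers its cost.
    conversation.chat('Hello! What can you do, in one sentence?')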