Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 3,396 Bytes
15eda99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import logging
from typing import Optional, Dict, Any
import time
from src.db.schemas.models import ModelUsage
from src.db.init_db import session_factory
from datetime import datetime
import tiktoken
from src.routes.analytics_routes import handle_new_model_usage
import asyncio
from src.utils.logger import Logger
from src.utils.model_registry import get_provider_for_model, calculate_cost
logger = Logger(name="ai_manager", see_time=True, console_log=True)
class AI_Manager:
    """Manages AI model interactions and usage tracking.

    Holds a tokenizer (tiktoken's ``cl100k_base`` when available, otherwise
    the whitespace-based ``SimpleTokenizer`` fallback) and provides helpers
    to persist per-request usage rows and compute costs via the centralized
    model registry.
    """

    def __init__(self):
        self.tokenizer = None
        # Prefer tiktoken for accurate token counting; fall back to the
        # simple whitespace tokenizer if it is not installed.
        # NOTE(review): tiktoken is also imported at module level, so if it
        # is truly missing this module fails to import before reaching here;
        # the fallback only helps if that top-level import is removed.
        try:
            import tiktoken
            self.tokenizer = tiktoken.get_encoding("cl100k_base")
        except ImportError:
            logger.log_message("Tiktoken not available, using simple tokenizer", level=logging.WARNING)
            self.tokenizer = SimpleTokenizer()

    def save_usage_to_db(self, user_id, chat_id, model_name, provider,
                         prompt_tokens, completion_tokens, total_tokens,
                         query_size, response_size, cost, request_time_ms,
                         is_streaming=False):
        """Save one model-usage record to the database.

        Args:
            user_id: Owner of the request.
            chat_id: Chat/session the request belongs to.
            model_name: Name of the model that served the request.
            provider: Provider string for the model.
            prompt_tokens / completion_tokens / total_tokens: Token counts.
            query_size / response_size: Payload sizes.
            cost: Computed cost for the request.
            request_time_ms: Request latency in milliseconds.
            is_streaming: Whether the response was streamed.

        Errors are logged, never raised: persistence is best-effort.
        """
        # Pre-bind so the except/finally clauses are safe even if
        # session_factory() itself raises (previously this produced an
        # UnboundLocalError that masked the real failure).
        session = None
        try:
            session = session_factory()
            usage = ModelUsage(
                user_id=user_id,
                chat_id=chat_id,
                model_name=model_name,
                provider=provider,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
                query_size=query_size,
                response_size=response_size,
                cost=cost,
                # NOTE(review): naive UTC timestamp, kept for schema
                # compatibility; datetime.utcnow() is deprecated in 3.12+.
                timestamp=datetime.utcnow(),
                is_streaming=is_streaming,
                request_time_ms=request_time_ms
            )
            session.add(usage)
            session.commit()
            # Broadcast the event asynchronously. create_task() requires a
            # running event loop; guard it separately so a missing loop does
            # not roll back / mis-report an already-committed save.
            try:
                asyncio.get_running_loop()
                asyncio.create_task(handle_new_model_usage(usage))
            except RuntimeError:
                logger.log_message(
                    f"No running event loop; skipped usage broadcast for chat {chat_id}",
                    level=logging.WARNING,
                )
        except Exception as e:
            if session is not None:
                session.rollback()
            logger.log_message(f"Error saving usage data to database for chat {chat_id}: {str(e)}", level=logging.ERROR)
        finally:
            if session is not None:
                session.close()

    def calculate_cost(self, model_name, input_tokens, output_tokens):
        """Calculate the cost for using the model based on tokens.

        Returns 0 for a falsy model name; otherwise delegates to the
        centralized ``calculate_cost`` from the model registry.
        """
        if not model_name:
            return 0
        # Resolve the provider purely for the log line below.
        model_provider = get_provider_for_model(model_name)
        logger.log_message(f"[> ] Model Name: {model_name}, Model Provider: {model_provider}", level=logging.INFO)
        # Use the centralized calculate_cost function
        return calculate_cost(model_name, input_tokens, output_tokens)

    def get_provider_for_model(self, model_name):
        """Determine the provider based on model name.

        Thin wrapper around the centralized registry lookup, kept for
        backward compatibility with callers using the instance API.
        """
        return get_provider_for_model(model_name)
class SimpleTokenizer:
    """A very simple whitespace tokenizer used as a tiktoken fallback.

    Mirrors the ``tiktoken.Encoding.encode`` contract by returning a
    sequence of tokens, so callers can uniformly do
    ``len(tokenizer.encode(text))`` regardless of which tokenizer is in use.
    """

    def encode(self, text):
        """Return the list of whitespace-separated tokens in *text*.

        Fix: the previous version returned ``len(text.split())`` (an int),
        which breaks the tiktoken-compatible ``len(encode(...))`` usage
        pattern; returning the token list preserves the count while
        matching the real encoder's interface.
        """
        return text.split()
|