# aws_aiclient.py
import json
import logging
import os
import time
from typing import AsyncGenerator, Dict, List

import psycopg2
from langchain_aws import ChatBedrockConverse

from observability import log_execution, LLMObservabilityManager

logger = logging.getLogger(__name__)

text_models = {
    'Claude 3 Sonnet': {
        'model': 'anthropic.claude-3-sonnet-20240229-v1:0',
        'input_cost': 0.000003,    # $3 per million tokens = $0.000003 per token
        'output_cost': 0.000015    # $15 per million tokens = $0.000015 per token
    },
    'Claude 3 Haiku': {
        'model': 'anthropic.claude-3-haiku-20240307-v1:0',
        'input_cost': 0.00000025,  # $0.25 per million tokens
        'output_cost': 0.00000125  # $1.25 per million tokens
    },
    'Llama 3 8B': {
        'model': 'meta.llama3-8b-instruct-v1:0',
        'input_cost': 0.00000019,  # $0.19 per million tokens
        'output_cost': 0.00000019  # $0.19 per million tokens
    },
    'Llama 3 70B': {
        'model': 'meta.llama3-70b-instruct-v1:0',
        'input_cost': 0.00000143,  # $1.43 per million tokens
        'output_cost': 0.00000143  # $1.43 per million tokens
    }
}
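
# Example cost check: a request with 1,000 prompt tokens and 500 completion
# tokens on Claude 3 Sonnet costs 1000 * 0.000003 + 500 * 0.000015 = $0.0105.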


class AIClient:
    """Streams chat completions from Amazon Bedrock and records usage,
    cost, and latency for each call."""

    def __init__(self):
        # Credentials come from the environment; the region is fixed here,
        # while the model can be overridden per call via generate_response.
        self.client = ChatBedrockConverse(
            model='meta.llama3-70b-instruct-v1:0',  # default model
            region_name="ap-south-1",
            aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY")
        )
        self.observability_manager = LLMObservabilityManager()
        self.models = text_models

    async def generate_response(
        self,
        messages: List[Dict[str, str]],
        model: str = "meta.llama3-70b-instruct-v1:0",
        max_tokens: int = 32000,
        conversation_id: str = "default",
        user: str = "anonymous"
    ) -> AsyncGenerator[str, None]:
        """Stream a model response chunk-by-chunk, then record the call
        (tokens, cost, latency) via the observability manager."""
        if not messages:
            return

        start_time = time.time()
        full_response = ""
        usage = {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0}
        status = "success"
        try:
            # Repoint the shared client at the requested model; ChatBedrockConverse
            # sends the model id and max_tokens per request, so mutating the
            # fields is sufficient.
            if model != self.client.model:
                self.client.model = model
            # Apply the caller's token budget; the model itself may cap the
            # effective output length below this value.
            self.client.max_tokens = max_tokens

            # Stream the response, accumulating text and usage as chunks arrive.
            async for chunk in self.client.astream(messages):
                if chunk.content and chunk.content[0].get("text"):
                    content = chunk.content[0].get("text")
                    yield content
                    full_response += content
                if chunk.usage_metadata:
                    usage["prompt_tokens"] = chunk.usage_metadata.get("input_tokens", 0)
                    usage["completion_tokens"] = chunk.usage_metadata.get("output_tokens", 0)
                    usage["total_tokens"] = chunk.usage_metadata.get("total_tokens", 0)
        except Exception as e:
            status = "error"
            full_response = str(e)
            logger.error("Error in generate_response: %s", e)
        finally:
            latency = time.time() - start_time

            # Price the call using the table entry for the model actually used.
            model_name = next((name for name, info in text_models.items()
                               if info['model'] == model), None)
            if model_name:
                model_info = text_models[model_name]
                cost = (usage["prompt_tokens"] * model_info["input_cost"] +
                        usage["completion_tokens"] * model_info["output_cost"])
            else:
                cost = 0  # no pricing entry for this model

            try:
                self.observability_manager.insert_observation(
                    response=full_response,
                    model=model,
                    completion_tokens=usage["completion_tokens"],
                    prompt_tokens=usage["prompt_tokens"],
                    total_tokens=usage["total_tokens"],
                    cost=cost,
                    conversation_id=conversation_id,
                    status=status,
                    request=json.dumps([msg for msg in messages if msg.get('role') != 'system']),
                    latency=latency,
                    user=user
                )
            except Exception as obs_error:
                logger.error("Error logging observation: %s", obs_error)


class DatabaseManager:
    """Manages database operations."""

    def __init__(self):
        self.db_params = {
            "dbname": "postgres",
            "user": os.environ['SUPABASE_USER'],
            "password": os.environ['SUPABASE_PASSWORD'],
            "host": "aws-0-us-west-1.pooler.supabase.com",
            "port": "5432"
        }
    @log_execution
    def update_database(self, user_id: str, user_query: str, response: str) -> None:
        """Insert one query/response pair into the ai_document_generator table."""
        with psycopg2.connect(**self.db_params) as conn:
            with conn.cursor() as cur:
                insert_query = """
                    INSERT INTO ai_document_generator (user_id, user_query, response)
                    VALUES (%s, %s, %s);
                """
                cur.execute(insert_query, (user_id, user_query, response))
        # psycopg2's connection context manager commits the transaction but
        # does not close the connection, so close it explicitly.
        conn.close()
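
# Usage sketch (illustrative): assumes SUPABASE_USER and SUPABASE_PASSWORD are
# set and the ai_document_generator table exists with these three columns.
#
#     db = DatabaseManager()
#     db.update_database("user-123", "What is Bedrock?", "Amazon Bedrock is ...")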