Create aws_aiclient.py
aws_aiclient.py  ADDED  (+138, -0)
@@ -0,0 +1,138 @@
# aws_aiclient.py

import os
import time
import json
from typing import List, Dict, Optional, Union, AsyncGenerator
import boto3
from starlette.responses import StreamingResponse
from observability import log_execution, LLMObservabilityManager
import psycopg2
import logging
from langchain_aws import ChatBedrockConverse

logger = logging.getLogger(__name__)

text_models = {
    'Claude 3 Sonnet': {
        'model': 'anthropic.claude-3-sonnet-20240229-v1:0',
        'input_cost': 0.000003,    # $3 per million tokens = $0.000003 per token
        'output_cost': 0.000015    # $15 per million tokens = $0.000015 per token
    },
    'Claude 3 Haiku': {
        'model': 'anthropic.claude-3-haiku-20240307-v1:0',
        'input_cost': 0.00000025,  # $0.25 per million tokens
        'output_cost': 0.00000125  # $1.25 per million tokens
    },
    'Llama 3 8B': {
        'model': 'meta.llama3-8b-instruct-v1:0',
        'input_cost': 0.00000019,  # $0.19 per million tokens
        'output_cost': 0.00000019  # $0.19 per million tokens
    },
    'Llama 3 70B': {
        'model': 'meta.llama3-70b-instruct-v1:0',
        'input_cost': 0.00000143,  # $1.43 per million tokens
        'output_cost': 0.00000143  # $1.43 per million tokens
    }
}

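# Worked cost example: with the per-token rates above, a Claude 3 Sonnet call
# with a 1,000-token prompt and a 500-token completion costs
#   1,000 * 0.000003 + 500 * 0.000015 = 0.003 + 0.0075 = $0.0105.
# AIClient.generate_response() below applies the same arithmetic in its
# finally block when it records the observation.
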
class AIClient:
    def __init__(self):
        self.client = ChatBedrockConverse(
            model='meta.llama3-70b-instruct-v1:0',  # default model
            region_name="ap-south-1",
            aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY")
        )
        self.observability_manager = LLMObservabilityManager()
        self.models = text_models

    async def generate_response(
        self,
        messages: List[Dict[str, str]],
        model: str = "meta.llama3-70b-instruct-v1:0",
        max_tokens: int = 32000,  # accepted for API symmetry; not currently forwarded to the model
        conversation_id: str = "default",
        user: str = "anonymous"
    ) -> AsyncGenerator[str, None]:
        if not messages:
            return

        start_time = time.time()
        full_response = ""
        usage = {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0}
        status = "success"

        try:
            # Switch the client's model if the caller requested a different one.
            # ChatBedrockConverse stores the identifier as `model_id`
            # (`model` is only a constructor alias), so compare and set that field.
            if model != self.client.model_id:
                self.client.model_id = model

            # Stream the response chunk by chunk
            async for chunk in self.client.astream(messages):
                if chunk.content and chunk.content[0].get("text"):
                    content = chunk.content[0].get("text")
                    yield content
                    full_response += content

                if chunk.usage_metadata:
                    usage["prompt_tokens"] = chunk.usage_metadata.get("input_tokens", 0)
                    usage["completion_tokens"] = chunk.usage_metadata.get("output_tokens", 0)
                    usage["total_tokens"] = chunk.usage_metadata.get("total_tokens", 0)

        except Exception as e:
            status = "error"
            full_response = str(e)
            logger.error(f"Error in generate_response: {e}")

        finally:
            latency = time.time() - start_time

            # Calculate cost based on the per-token pricing for the model in use
            model_name = next((name for name, info in text_models.items()
                               if info['model'] == model), None)
            if model_name:
                model_info = text_models[model_name]
                cost = (usage["prompt_tokens"] * model_info["input_cost"] +
                        usage["completion_tokens"] * model_info["output_cost"])
            else:
                cost = 0  # default if the model is not in the pricing table

            try:
                self.observability_manager.insert_observation(
                    response=full_response,
                    model=model,
                    completion_tokens=usage["completion_tokens"],
                    prompt_tokens=usage["prompt_tokens"],
                    total_tokens=usage["total_tokens"],
                    cost=cost,
                    conversation_id=conversation_id,
                    status=status,
                    # Log the conversation without system prompts
                    request=json.dumps([msg for msg in messages if msg.get('role') != 'system']),
                    latency=latency,
                    user=user
                )
            except Exception as obs_error:
                logger.error(f"Error logging observation: {obs_error}")

class DatabaseManager:
    """Manages database operations."""

    def __init__(self):
        self.db_params = {
            "dbname": "postgres",
            "user": os.environ['SUPABASE_USER'],
            "password": os.environ['SUPABASE_PASSWORD'],
            "host": "aws-0-us-west-1.pooler.supabase.com",
            "port": "5432"
        }

    @log_execution
    def update_database(self, user_id: str, user_query: str, response: str) -> None:
        with psycopg2.connect(**self.db_params) as conn:
            with conn.cursor() as cur:
                insert_query = """
                    INSERT INTO ai_document_generator (user_id, user_query, response)
                    VALUES (%s, %s, %s);
                """
                cur.execute(insert_query, (user_id, user_query, response))