pvanand commited on
Commit
e577d93
·
verified ·
1 Parent(s): 003f154

Create aws_aiclient.py

Browse files
Files changed (1) hide show
  1. aws_aiclient.py +138 -0
aws_aiclient.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # aws_aiclient.py
2
+
3
+ import os
4
+ import time
5
+ import json
6
+ from typing import List, Dict, Optional, Union, AsyncGenerator
7
+ import boto3
8
+ from starlette.responses import StreamingResponse
9
+ from observability import log_execution, LLMObservabilityManager
10
+ import psycopg2
11
+ import logging
12
+ from langchain_aws import ChatBedrockConverse
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
# Pricing table for the supported Bedrock text models.
# Keys are human-readable names; each entry maps to the Bedrock model id and
# the per-token USD cost (vendor list prices are quoted per million tokens
# in the inline comments). Consumed by AIClient to price each call.
text_models = {
    'Claude 3 Sonnet': {
        'model': 'anthropic.claude-3-sonnet-20240229-v1:0',
        'input_cost': 0.000003,   # $3 per million tokens = $0.000003 per token
        'output_cost': 0.000015   # $15 per million tokens = $0.000015 per token
    },
    'Claude 3 Haiku': {
        'model': 'anthropic.claude-3-haiku-20240307-v1:0',
        'input_cost': 0.00000025,  # $0.25 per million tokens
        'output_cost': 0.00000125  # $1.25 per million tokens
    },
    'Llama 3 8B': {
        'model': 'meta.llama3-8b-instruct-v1:0',
        'input_cost': 0.00000019,  # $0.19 per million tokens
        'output_cost': 0.00000019  # $0.19 per million tokens
    },
    'Llama 3 70B': {
        'model': 'meta.llama3-70b-instruct-v1:0',
        'input_cost': 0.00000143,  # $1.43 per million tokens
        'output_cost': 0.00000143  # $1.43 per million tokens
    }
}
38
+
39
class AIClient:
    """Streams chat completions from AWS Bedrock and records usage/cost telemetry.

    Wraps a shared ``ChatBedrockConverse`` client (region ``ap-south-1``,
    credentials read from the ``AWS_ACCESS_KEY_ID`` / ``AWS_SECRET_ACCESS_KEY``
    environment variables) and logs every call through
    ``LLMObservabilityManager``.
    """

    def __init__(self):
        self.client = ChatBedrockConverse(
            model='meta.llama3-70b-instruct-v1:0',  # default model; may be swapped per call
            region_name="ap-south-1",
            aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY")
        )
        self.observability_manager = LLMObservabilityManager()
        # Pricing table, exposed for callers that want to enumerate models.
        self.models = text_models

    async def generate_response(
        self,
        messages: List[Dict[str, str]],
        model: str = "meta.llama3-70b-instruct-v1:0",
        max_tokens: int = 32000,
        conversation_id: str = "default",
        user: str = "anonymous"
    ) -> AsyncGenerator[str, None]:
        """Yield streamed response text for ``messages``, then log an observation.

        Args:
            messages: Chat messages as ``[{"role": ..., "content": ...}, ...]``.
            model: Bedrock model id; the shared client is switched to it if it
                differs from the client's current model.
            max_tokens: NOTE(review): accepted but never forwarded to the
                client — confirm whether it should set a completion limit.
            conversation_id: Grouping key for the observability record.
            user: User label for the observability record.

        Yields:
            Text fragments of the completion as they stream in.

        Errors are swallowed (best-effort): on failure nothing further is
        yielded, and the error text is recorded as the response with
        ``status="error"``.
        """
        if not messages:
            return

        start_time = time.time()
        full_response = ""
        usage = {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0}
        status = "success"

        try:
            # Switch the shared client's model if the caller asked for another one.
            if model != self.client.model:
                self.client.model = model

            # Stream the response; chunk.content is a list of content parts
            # and only parts carrying a "text" key are surfaced.
            async for chunk in self.client.astream(messages):
                if chunk.content and chunk.content[0].get("text"):
                    content = chunk.content[0].get("text")
                    yield content
                    full_response += content

                # Usage metadata arrives on (at least) the final chunk;
                # later chunks overwrite earlier values.
                if chunk.usage_metadata:
                    usage["prompt_tokens"] = chunk.usage_metadata.get("input_tokens", 0)
                    usage["completion_tokens"] = chunk.usage_metadata.get("output_tokens", 0)
                    usage["total_tokens"] = chunk.usage_metadata.get("total_tokens", 0)

        except Exception as e:
            # Best-effort: record the failure in the observation instead of raising.
            # Use the module logger (with traceback) rather than print().
            status = "error"
            full_response = str(e)
            logger.exception("Error in generate_response: %s", e)

        finally:
            latency = time.time() - start_time

            # Price the call from the pricing table; unknown model ids cost 0.
            model_name = next((name for name, info in text_models.items()
                               if info['model'] == model), None)
            if model_name:
                model_info = text_models[model_name]
                cost = (usage["prompt_tokens"] * model_info["input_cost"] +
                        usage["completion_tokens"] * model_info["output_cost"])
            else:
                cost = 0  # Default if model not found

            try:
                # System messages are excluded from the logged request payload.
                self.observability_manager.insert_observation(
                    response=full_response,
                    model=model,
                    completion_tokens=usage["completion_tokens"],
                    prompt_tokens=usage["prompt_tokens"],
                    total_tokens=usage["total_tokens"],
                    cost=cost,
                    conversation_id=conversation_id,
                    status=status,
                    request=json.dumps([msg for msg in messages if msg.get('role') != 'system']),
                    latency=latency,
                    user=user
                )
            except Exception as obs_error:
                # Telemetry failures must never break the response stream.
                logger.error("Error logging observation: %s", obs_error)
118
class DatabaseManager:
    """Manages database operations against the Supabase PostgreSQL instance."""

    def __init__(self):
        # Connection parameters; credentials come from environment variables
        # (raises KeyError at construction time if they are missing).
        self.db_params = {
            "dbname": "postgres",
            "user": os.environ['SUPABASE_USER'],
            "password": os.environ['SUPABASE_PASSWORD'],
            "host": "aws-0-us-west-1.pooler.supabase.com",
            "port": "5432"
        }

    @log_execution
    def update_database(self, user_id: str, user_query: str, response: str) -> None:
        """Insert one (user_id, user_query, response) row into ai_document_generator.

        Fix: psycopg2's ``with connection:`` block commits/rolls back the
        transaction but does NOT close the connection, so the original leaked
        one pooled connection per call. The connection is now closed explicitly.
        """
        conn = psycopg2.connect(**self.db_params)
        try:
            # ``with conn`` wraps the statements in a transaction
            # (commit on success, rollback on error).
            with conn:
                with conn.cursor() as cur:
                    insert_query = """
                    INSERT INTO ai_document_generator (user_id, user_query, response)
                    VALUES (%s, %s, %s);
                    """
                    cur.execute(insert_query, (user_id, user_query, response))
        finally:
            conn.close()