# aws_aiclient.py

import os
import time
import json
import logging
from typing import List, Dict, Optional, AsyncGenerator

import psycopg2
from langchain_aws import ChatBedrockConverse

from observability import log_execution, LLMObservabilityManager

logger = logging.getLogger(__name__)

text_models = {
    'Claude 3 Sonnet': {
        'model': 'anthropic.claude-3-sonnet-20240229-v1:0',
        'input_cost': 0.000003,  # $3 per million tokens = $0.000003 per token
        'output_cost': 0.000015  # $15 per million tokens = $0.000015 per token
    },
    'Claude 3 Haiku': {
        'model': 'anthropic.claude-3-haiku-20240307-v1:0',
        'input_cost': 0.00000025,  # $0.25 per million tokens
        'output_cost': 0.00000125  # $1.25 per million tokens
    },
    'Llama 3 8B': {
        'model': 'meta.llama3-8b-instruct-v1:0',
        'input_cost': 0.00000019,  # $0.19 per million tokens
        'output_cost': 0.00000019  # $0.19 per million tokens
    },
    'Llama 3 70B': {
        'model': 'meta.llama3-70b-instruct-v1:0',
        'input_cost': 0.00000143,  # $1.43 per million tokens
        'output_cost': 0.00000143  # $1.43 per million tokens
    }
}
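
# Worked cost example (hypothetical figures): a Claude 3 Sonnet call that
# consumes 1,000 prompt tokens and produces 500 completion tokens would cost
#   1000 * 0.000003 + 500 * 0.000015 = $0.0105
# at the per-token rates above.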

class AIClient:
    def __init__(self):
        self.client = ChatBedrockConverse(
            model='meta.llama3-70b-instruct-v1:0',  # default model
            region_name="ap-south-1",
            aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY")
        )
        self.observability_manager = LLMObservabilityManager()
        self.models = text_models

    async def generate_response(
        self,
        messages: List[Dict[str, str]],
        model: str = "meta.llama3-70b-instruct-v1:0",
        max_tokens: Optional[int] = None,
        conversation_id: str = "default",
        user: str = "anonymous"
    ) -> AsyncGenerator[str, None]:
        if not messages:
            return

        start_time = time.time()
        full_response = ""
        usage = {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0}
        status = "success"

        try:
            # Point the shared client at the requested model if it differs.
            # ChatBedrockConverse exposes the model id as `model_id`
            # (`model` is only a constructor alias).
            if model != self.client.model_id:
                self.client.model_id = model
            # Apply an explicit token budget when the caller provides one.
            if max_tokens is not None:
                self.client.max_tokens = max_tokens

            # Stream the response. ChatBedrockConverse yields content as a
            # list of blocks, e.g. [{"type": "text", "text": "..."}].
            async for chunk in self.client.astream(messages):
                if chunk.content:
                    content = chunk.content[0].get("text")
                    if content:
                        yield content
                        full_response += content

                if chunk.usage_metadata:
                    usage["prompt_tokens"] = chunk.usage_metadata.get("input_tokens", 0)
                    usage["completion_tokens"] = chunk.usage_metadata.get("output_tokens", 0)
                    usage["total_tokens"] = chunk.usage_metadata.get("total_tokens", 0)

        except Exception as e:
            status = "error"
            full_response = str(e)
            logger.error(f"Error in generate_response: {e}")
            
        finally:
            latency = time.time() - start_time
            
            # Calculate cost based on the model being used
            model_name = next((name for name, info in text_models.items() 
                             if info['model'] == model), None)
            if model_name:
                model_info = text_models[model_name]
                cost = (usage["prompt_tokens"] * model_info["input_cost"] +
                       usage["completion_tokens"] * model_info["output_cost"])
            else:
                cost = 0.0  # Unknown model id: no pricing entry, so record zero cost

            try:
                self.observability_manager.insert_observation(
                    response=full_response,
                    model=model,
                    completion_tokens=usage["completion_tokens"],
                    prompt_tokens=usage["prompt_tokens"],
                    total_tokens=usage["total_tokens"],
                    cost=cost,
                    conversation_id=conversation_id,
                    status=status,
                    request=json.dumps([msg for msg in messages if msg.get('role') != 'system']),
                    latency=latency,
                    user=user
                )
            except Exception as obs_error:
                logger.error(f"Error logging observation: {obs_error}")

class DatabaseManager:
    """Manages database operations."""

    def __init__(self):
        self.db_params = {
            "dbname": "postgres",
            "user": os.environ['SUPABASE_USER'],
            "password": os.environ['SUPABASE_PASSWORD'],
            "host": "aws-0-us-west-1.pooler.supabase.com",
            "port": "5432"
        }

    @log_execution
    def update_database(self, user_id: str, user_query: str, response: str) -> None:
        # psycopg2's `with connection:` commits the transaction but does not
        # close the connection, so close it explicitly when done.
        conn = psycopg2.connect(**self.db_params)
        try:
            with conn:
                with conn.cursor() as cur:
                    insert_query = """
                    INSERT INTO ai_document_generator (user_id, user_query, response)
                    VALUES (%s, %s, %s);
                    """
                    cur.execute(insert_query, (user_id, user_query, response))
        finally:
            conn.close()
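

# Minimal usage sketch, not part of the module's public API. It assumes the
# AWS and Supabase environment variables referenced above are set, and that
# an `ai_document_generator` table with text columns (user_id, user_query,
# response) already exists -- both are assumptions, not guarantees.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        client = AIClient()
        messages = [{"role": "user", "content": "Say hello in one sentence."}]
        chunks = []
        # Consume the async generator, collecting the streamed text.
        async for piece in client.generate_response(messages, user="demo"):
            chunks.append(piece)
        # Persist the full exchange.
        DatabaseManager().update_database(
            user_id="demo",
            user_query=messages[0]["content"],
            response="".join(chunks),
        )

    asyncio.run(_demo())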