Zekun Wu commited on
Commit
97becca
1 Parent(s): f5c8eb4
Files changed (3) hide show
  1. pages/1_Injection.py +13 -6
  2. requirements.txt +2 -1
  3. util/model.py +47 -0
pages/1_Injection.py CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
2
  import pandas as pd
3
  from io import StringIO
4
  from util.injection import process_scores_multiple
5
- from util.model import AzureAgent, GPTAgent
6
  from util.prompt import PROMPT_TEMPLATE
7
  import os
8
 
@@ -49,14 +49,18 @@ else:
49
  st.sidebar.title('Model Settings')
50
  initialize_state()
51
 
 
 
52
  # Model selection and configuration
53
- model_type = st.sidebar.radio("Select the type of agent", ('GPTAgent', 'AzureAgent'))
54
  st.session_state.api_key = st.sidebar.text_input("API Key", type="password", value=st.session_state.api_key)
55
  st.session_state.endpoint_url = st.sidebar.text_input("Endpoint URL", value=st.session_state.endpoint_url)
56
  st.session_state.deployment_name = st.sidebar.text_input("Model Name", value=st.session_state.deployment_name)
57
- api_version = '2024-02-15-preview' if model_type == 'GPTAgent' else ''
58
- st.session_state.temperature = st.sidebar.slider("Temperature", 0.0, 1.0, st.session_state.temperature, 0.01)
59
- st.session_state.max_tokens = st.sidebar.number_input("Max Tokens", 1, 1000, st.session_state.max_tokens)
 
 
60
 
61
  if st.sidebar.button("Reset Model Info"):
62
  initialize_state() # Reset all state to defaults
@@ -111,9 +115,12 @@ else:
111
  if model_type == 'AzureAgent':
112
  agent = AzureAgent(st.session_state.api_key, st.session_state.endpoint_url,
113
  st.session_state.deployment_name)
114
- else:
115
  agent = GPTAgent(st.session_state.api_key, st.session_state.endpoint_url,
116
  st.session_state.deployment_name, api_version)
 
 
 
117
 
118
  with st.spinner('Processing data...'):
119
  parameters = {"temperature": st.session_state.temperature, "max_tokens": st.session_state.max_tokens}
 
2
  import pandas as pd
3
  from io import StringIO
4
  from util.injection import process_scores_multiple
5
+ from util.model import AzureAgent, GPTAgent,Claude3Agent
6
  from util.prompt import PROMPT_TEMPLATE
7
  import os
8
 
 
49
  st.sidebar.title('Model Settings')
50
  initialize_state()
51
 
52
+
53
+
54
  # Model selection and configuration
55
+ model_type = st.sidebar.radio("Select the type of agent", ('GPTAgent', 'AzureAgent','Claude3Agent'))
56
  st.session_state.api_key = st.sidebar.text_input("API Key", type="password", value=st.session_state.api_key)
57
  st.session_state.endpoint_url = st.sidebar.text_input("Endpoint URL", value=st.session_state.endpoint_url)
58
  st.session_state.deployment_name = st.sidebar.text_input("Model Name", value=st.session_state.deployment_name)
59
+
60
+ if model_type == 'GPTAgent' or model_type == 'AzureAgent':
61
+ api_version = '2024-02-15-preview' if model_type == 'GPTAgent' else ''
62
+ st.session_state.temperature = st.sidebar.slider("Temperature", 0.0, 1.0, st.session_state.temperature, 0.01)
63
+ st.session_state.max_tokens = st.sidebar.number_input("Max Tokens", 1, 1000, st.session_state.max_tokens)
64
 
65
  if st.sidebar.button("Reset Model Info"):
66
  initialize_state() # Reset all state to defaults
 
115
  if model_type == 'AzureAgent':
116
  agent = AzureAgent(st.session_state.api_key, st.session_state.endpoint_url,
117
  st.session_state.deployment_name)
118
+ elif model_type == 'GPTAgent':
119
  agent = GPTAgent(st.session_state.api_key, st.session_state.endpoint_url,
120
  st.session_state.deployment_name, api_version)
121
+ else:
122
+ agent = Claude3Agent(st.session_state.api_key,st.session_state.deployment_name)
123
+
124
 
125
  with st.spinner('Processing data...'):
126
  parameters = {"temperature": st.session_state.temperature, "max_tokens": st.session_state.max_tokens}
requirements.txt CHANGED
@@ -5,4 +5,5 @@ scipy
5
  statsmodels
6
  scikit-posthocs
7
  json-repair
8
- plotly
 
 
5
  statsmodels
6
  scikit-posthocs
7
  json-repair
8
+ plotly
9
+ boto3
util/model.py CHANGED
@@ -1,6 +1,49 @@
1
  import json
2
  import http.client
3
  from openai import AzureOpenAI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  class ContentFormatter:
6
  @staticmethod
@@ -53,3 +96,7 @@ class GPTAgent:
53
  **kwargs
54
  )
55
  return response.choices[0].message.content
 
 
 
 
 
1
  import json
2
  import http.client
3
  from openai import AzureOpenAI
4
+ import time
5
+ from tqdm import tqdm
6
+ from typing import Any, List
7
+ from botocore.exceptions import ClientError
8
+ from enum import Enum
9
+ import boto3
10
+ import json
11
+ import logging
12
+
13
+
14
+ class Model(Enum):
15
+ CLAUDE3_SONNET = "anthropic.claude-3-sonnet-20240229-v1:0"
16
+ CLAUDE3_HAIKU = "anthropic.claude-3-haiku-20240307-v1:0"
17
+
18
+
19
+ class Claude3Agent:
20
+ def __init__(self, aws_secret_access_key: str,model: str ):
21
+ self.client = boto3.client("bedrock-runtime", region_name="us-east-1", aws_access_key_id="AKIAZR6ZJPKTKJAMLP5W",
22
+ aws_secret_access_key=aws_secret_access_key)
23
+ if model == "SONNET":
24
+ self.model = Model.CLAUDE3_SONNET
25
+ elif model == "HAIKU":
26
+ self.model = Model.CLAUDE3_HAIKU
27
+ else:
28
+ raise ValueError("Invalid model type. Please choose from 'SONNET' or 'HAIKU' models.")
29
+
30
+ def invoke(self, text: str,**kwargs) -> str:
31
+ try:
32
+ body = json.dumps(
33
+ {
34
+ "anthropic_version": "bedrock-2023-05-31",
35
+ "messages": [
36
+ {"role": "user", "content": [{"type": "text", "text": text}]}
37
+ ],
38
+ **kwargs
39
+ }
40
+ )
41
+ response = self.client.invoke_model(modelId=self.model.value, body=body)
42
+ completion = json.loads(response["body"].read())["content"][0]["text"]
43
+ return completion
44
+ except ClientError:
45
+ logging.error("Couldn't invoke model")
46
+ raise
47
 
48
  class ContentFormatter:
49
  @staticmethod
 
96
  **kwargs
97
  )
98
  return response.choices[0].message.content
99
+
100
+ if __name__ == '__main__':
101
+ agent = Claude3Agent("TyzS1CYdvhYtes+V9u2qqS5sggS3asSeXAfYYvOS", "SONNET")
102
+ print(agent.invoke("I am a software engineer.", max_tokens=200, temperature=0.5))