File size: 9,709 Bytes
f66af3d ec5ad7b f66af3d 91db358 f66af3d 48ee539 f66af3d 48ee539 f66af3d 4538fa2 3487766 f66af3d 48ee539 2b559c9 f66af3d b9b9709 f66af3d 48ee539 f66af3d 6182f9d dca8500 6182f9d d8c7d29 2fade7b 91db358 f1ebb39 028e695 91db358 dca8500 91db358 dca8500 4256990 2fade7b dca8500 4256990 dca8500 91db358 2fade7b f8f734d 91db358 4256990 4499016 4256990 dca8500 91db358 4256990 28522a8 91db358 2fade7b f8f734d dca8500 4256990 28522a8 dca8500 91db358 2fade7b f8f734d 2fade7b f66af3d 2fade7b f66af3d 48ee539 f66af3d 028e695 f66af3d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 |
import gradio as gr
import boto3
import json
import os
import numpy as np
import botocore
import time
from scipy.spatial.distance import cosine as cosine_similarity
# Base Gradio theme with small text, passed to the chat interface below
theme = gr.themes.Base(text_size='sm')

# AWS credentials and resource locations are read from the environment.
# Any variable that is not set resolves to None.
AWS_ACCESS_KEY_ID = os.environ.get('AWS_ACCESS_KEY_ID')
AWS_SECRET_ACCESS_KEY = os.environ.get('AWS_SECRET_ACCESS_KEY')
AWS_REGION = os.environ.get('REGION_NAME')
AWS_SESSION = os.environ.get('AWS_SESSION')
BUCKET_NAME = os.environ.get('BUCKET_NAME')
# S3 key of the extractions JSON; ask_ds substitutes the '{employee_type}'
# and '{division}' placeholders into it once the user has chosen both.
EXTRACTIONS_PATH = os.environ.get('EXTRACTIONS_PATH')

# Conversation state mutated by ask_ds through `global` statements
employee_type = division = None
authenticated = False
# Filled from S3 after sign-in; ask_ds reads 'embedding' and 'content' per entry
extractions = {}
# Create AWS Bedrock client using environment variables
def create_bedrock_client():
    """
    Build a boto3 'bedrock-runtime' client from the module-level AWS settings.

    Returns:
    The client returned by boto3.client, configured with AWS_REGION and the
    access key, secret key and session token read from the environment.
    """
    client_options = {
        'region_name': AWS_REGION,
        'aws_access_key_id': AWS_ACCESS_KEY_ID,
        'aws_secret_access_key': AWS_SECRET_ACCESS_KEY,
        'aws_session_token': AWS_SESSION,
    }
    return boto3.client('bedrock-runtime', **client_options)
# Create AWS S3 client using environment variables
def create_s3_client():
    """
    Build a boto3 's3' client from the module-level AWS credentials.

    No region is passed explicitly, so boto3's own region resolution applies.

    Returns:
    The client returned by boto3.client for the 's3' service.
    """
    credentials = dict(
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
        aws_session_token=AWS_SESSION,
    )
    return boto3.client('s3', **credentials)
# Read JSON directly into mem from S3
def read_json_from_s3():
    """
    Download the object at EXTRACTIONS_PATH in BUCKET_NAME and parse it as JSON.

    Uses the module-level s3_client. The whole object body is read into memory
    and decoded as UTF-8 before parsing.

    Returns:
    The decoded JSON value.
    """
    s3_object = s3_client.get_object(Bucket=BUCKET_NAME, Key=EXTRACTIONS_PATH)
    raw_bytes = s3_object['Body'].read()
    return json.loads(raw_bytes.decode('utf-8'))
# Get AWS Titan embedding of text
def get_titan_embedding(bedrock_client, doc_name, text, attempt=0, cutoff=10000):
    """
    Retrieves a text embedding for a given document using the Amazon Titan Embedding model.

    The text is sent to the 'amazon.titan-embed-text-v1' model. Two failure modes are
    handled by recursively calling this function with adjusted parameters:
    - ThrottlingException: wait 2 ** attempt seconds and retry, up to 5 retries.
    - ValidationException: if the text has more than `cutoff` words, retry once with
      only the first `cutoff` words (whitespace is normalised to single spaces).

    Parameters:
    bedrock_client: A boto3 'bedrock-runtime' client (anything exposing invoke_model and
        an `exceptions` namespace with ThrottlingException / ValidationException).
    doc_name (str): The name of the document. Not used by the request itself; kept so
        existing callers keep working.
    text (str): The text content to be processed by the Titan embedding model.
    attempt (int): The current throttling retry number (used in recursive calls). Defaults to 0.
    cutoff (int): The maximum number of words to keep from the input text if a
        ValidationException occurs. Defaults to 10000.

    Returns:
    The value of the 'embedding' field of the model response, or None if the response
    has no such field.

    Raises:
    The client's ThrottlingException once the retry budget is exhausted, its
    ValidationException when truncation cannot shorten the text any further, and any
    other error raised by invoke_model unchanged.
    """
    max_retries = 5
    request_body = json.dumps({"inputText": text})

    try:
        response = bedrock_client.invoke_model(
            body=request_body,
            modelId='amazon.titan-embed-text-v1',
            accept='application/json',
            contentType='application/json',
        )
    except bedrock_client.exceptions.ThrottlingException:
        if attempt >= max_retries:
            raise
        # Exponential backoff: 1s, 2s, 4s, ... before asking again
        time.sleep(2 ** attempt)
        return get_titan_embedding(bedrock_client, doc_name, text, attempt + 1, cutoff)
    except bedrock_client.exceptions.ValidationException:
        words = text.split()
        # Already within the word budget: truncating would send the same input
        # again, so surface the error instead of looping.
        if len(words) <= cutoff:
            raise
        shortened = ' '.join(words[:cutoff])
        return get_titan_embedding(bedrock_client, doc_name, shortened, attempt, cutoff)

    response_body = json.loads(response['body'].read())
    return response_body.get('embedding')
# Main Chat
# Original EXTRACTIONS_PATH value (still containing its placeholders), captured
# the first time a user signs in so that later sign-ins can resolve it again.
_extractions_path_template = None


def _resolve_choice(message, options):
    """Map a menu number or a typed option name to its option value, or None if invalid."""
    choice = message.lower().strip()
    if choice in options:
        return options[choice]
    if choice in options.values():
        return choice
    return None


def _build_system_prompt(similar_content, question):
    """Build the system prompt embedding the retrieved content and the user's question."""
    return f"""Here is some relevant information that may help answer the user's upcoming question:
<relevant_information>
{similar_content}
</relevant_information>
The user's question is:
<question>{question}</question>
Please carefully review the relevant information provided above.
Your task is to review the provided relevant information and answer the user's question to the best of your ability.
Aim to use information from the relevant information section to directly address the question asked, and refrain from saying
things like 'According to the relevant information provided'.
Format your output nicely with sentences that are not too long, in a professional and kind tone. You should prefer lists or
bullet points when applicable. Begin by thanking the user for their question, and at the end of your answer, say "Thank you for using Ask Dane Street!"
Remember, aim to only use information from the relevant information section in your response, without explicitly referring
to that section. Return your answer immediately and without preamble."""


def ask_ds(message, history):
    """
    Chat handler: walks the user through division / role selection, then answers
    questions with retrieval-augmented generation.

    Flow:
    - An empty message yields None and does nothing else.
    - 'admin_reset' clears the selection state and yields the division menu again.
    - Until signed in, the message is interpreted as a menu choice (number or name),
      first for the division and then for the employee type. After both are chosen the
      extractions JSON is loaded from S3 and a welcome message is yielded.
    - Once signed in, the message is embedded, the five closest documents in
      `extractions` are selected, and the model's streamed answer is yielded as it
      grows, followed by a final message listing the source documents.

    Parameters:
    message (str): The text the user just sent.
    history: Prior conversation supplied by the chat interface; not used here.

    Yields:
    str or None: Successive versions of the reply to display.

    NOTE(review): all state lives in module globals, so it appears to be shared by
    every user of the process rather than kept per session - confirm this is intended.
    """
    global employee_type, division, authenticated, extractions
    global EXTRACTIONS_PATH, _extractions_path_template

    if not message:
        yield None
        return

    if message == 'admin_reset':
        # Clear the whole selection, otherwise no branch below would match
        # (not authenticated, but division and employee type still set).
        authenticated = False
        division = None
        employee_type = None
        extractions = {}
        yield "Select your division:\n[1] IME \n[2] PAS\n[3] Peer Disability"
        return

    if not authenticated:
        if division is None:
            divisions = {'1': 'ime', '2': 'pas', '3': 'peer disability'}
            chosen_division = _resolve_choice(message, divisions)
            if chosen_division is None:
                yield "Please select a valid choice."
            else:
                division = chosen_division
                yield "[1] CSR\n[2] QA"
            return

        employee_types = {'1': 'csr', '2': 'qa'}
        chosen_type = _resolve_choice(message, employee_types)
        if chosen_type is None:
            yield "Please select a valid choice."
            return

        # The storage path uses the short code 'dis' for the Peer Disability division
        division_code = 'dis' if division == 'peer disability' else division
        if _extractions_path_template is None:
            _extractions_path_template = EXTRACTIONS_PATH
        # read_json_from_s3 reads the EXTRACTIONS_PATH global, so publish the resolved
        # key there while keeping the template intact for the next sign-in. The
        # resolved key is deliberately not echoed to the chat.
        EXTRACTIONS_PATH = (
            _extractions_path_template
            .replace('{employee_type}', chosen_type)
            .replace('{division}', division_code)
        )
        extractions = read_json_from_s3()
        # Only mark the user as signed in once the documents actually loaded
        employee_type = chosen_type
        authenticated = True
        yield "Welcome to Ask Dane Street! Whether you're new to the team or just looking for some quick information, I'm here to guide you through our company's literature and platform. Simply ask your question, and I'll provide you with the most relevant information I can."
        return

    question = message

    # RAG: rank documents by distance to the question embedding.
    # `cosine_similarity` is SciPy's cosine *distance* (imported under an alias),
    # so smaller is closer and an ascending sort puts the best matches first.
    question_embedding = get_titan_embedding(bedrock_client, 'question', question)
    ranked_documents = sorted(
        (
            (file, cosine_similarity(question_embedding, np.array(data['embedding'])))
            for file, data in extractions.items()
        ),
        key=lambda pair: pair[1],
    )
    top_docs = ranked_documents[:5]
    similar_content = ''.join(extractions[file]['content'] + '\n' for file, _ in top_docs)

    # Invoke
    response = bedrock_client.invoke_model_with_response_stream(
        modelId="anthropic.claude-3-sonnet-20240229-v1:0",
        body=json.dumps(
            {
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": 4096,
                "system": _build_system_prompt(similar_content, question),
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": message
                            }
                        ]
                    }
                ],
            }
        ),
    )

    # Stream the response
    all_text = ''
    stream = response.get('body')
    if stream:
        for event in stream:
            chunk = event.get('chunk')
            if not chunk:
                continue
            payload = json.loads(chunk.get('bytes').decode())
            # Only some events carry a text delta; the rest are skipped
            delta = payload.get('delta') if isinstance(payload, dict) else None
            this_text = delta.get('text') if isinstance(delta, dict) else None
            if this_text:
                all_text += this_text
                yield all_text  # Stream the text back to the UI

    # List the source documents after the answer
    output = '\n\nCheck out the following documents for more information:\n'
    for file, _ in top_docs:
        output += f"\n{file.replace('.txt', '.pdf')}"
    yield all_text + output
# Create necessary services and collect data
# Module-level AWS clients used by get_titan_embedding / ask_ds and read_json_from_s3
bedrock_client = create_bedrock_client()
s3_client = create_s3_client()

# Chat UI: the chatbot is pre-seeded with the division menu so that the user's
# first message is handled by ask_ds as the division choice.
demo = gr.ChatInterface(
    fn=ask_ds,
    title="Ask DS",
    multimodal=False,
    chatbot=gr.Chatbot(
        value=[(None, "Please select your division:\n[1] IME \n[2] PAS\n[3] Peer Disability")],
    ),
    theme=theme,
)
demo.launch()
|