Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import boto3
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import botocore
|
7 |
+
import time
|
8 |
+
|
9 |
+
theme = gr.themes.Base(text_size='sm')
|
10 |
+
|
11 |
+
# Retrieve AWS credentials from environment variables
|
12 |
+
|
13 |
+
AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
|
14 |
+
AWS_SECRET_ACCESS_KEY = os.getenv('AWS_SECRET_ACCESS_KEY')
|
15 |
+
AWS_REGION = os.getenv('REGION_NAME')
|
16 |
+
AWS_SESSION = os.getenv('AWS_SESSION')
|
17 |
+
BUCKET_NAME = os.getenv('BUCKET_NAME')
|
18 |
+
EXTRACTIONS_PATH = os.getenv('EXTRACTIONS_PATH')
|
19 |
+
|
20 |
+
# Create AWS Bedrock client using environment variables
|
21 |
+
def create_bedrock_client():
|
22 |
+
|
23 |
+
return boto3.client(
|
24 |
+
'bedrock-runtime',
|
25 |
+
region_name=AWS_REGION,
|
26 |
+
aws_access_key_id=AWS_ACCESS_KEY_ID,
|
27 |
+
aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
|
28 |
+
aws_session_token=AWS_SESSION
|
29 |
+
)
|
30 |
+
|
31 |
+
def create_s3_client():
|
32 |
+
|
33 |
+
# Create an S3 client
|
34 |
+
return boto3.client(
|
35 |
+
's3',
|
36 |
+
aws_access_key_id=AWS_ACCESS_KEY_ID,
|
37 |
+
aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
|
38 |
+
aws_session_token=AWS_SESSION
|
39 |
+
)
|
40 |
+
|
41 |
+
def read_json_from_s3():
|
42 |
+
|
43 |
+
try:
|
44 |
+
response = s3_client.get_object(Bucket=BUCKET_NAME, Key=EXTRACTIONS_PATH)
|
45 |
+
file_content = response['Body'].read().decode('utf-8')
|
46 |
+
json_content = json.loads(file_content)
|
47 |
+
except Exception as e:
|
48 |
+
yield f"Error reading JSON file from S3: {e}"
|
49 |
+
return None
|
50 |
+
|
51 |
+
return s3_file_path, json_content
|
52 |
+
|
53 |
+
def get_titan_embedding(bedrock, doc_name, text, attempt=0, cutoff=10000):
|
54 |
+
"""
|
55 |
+
Retrieves a text embedding for a given document using the Amazon Titan Embedding model.
|
56 |
+
|
57 |
+
This function sends the provided text to the Amazon Titan text embedding model
|
58 |
+
and retrieves the resulting embedding. It handles retries for throttling exceptions
|
59 |
+
and input size limitations by recursively calling itself with adjusted parameters.
|
60 |
+
|
61 |
+
Parameters:
|
62 |
+
doc_name (str): The name of the document, used for logging and error messages.
|
63 |
+
text (str): The text content to be processed by the Titan embedding model.
|
64 |
+
attempt (int): The current attempt number (used in recursive calls to handle retries). Defaults to 0.
|
65 |
+
cutoff (int): The maximum number of words to include from the input text if a ValidationException occurs due to input size limits. Defaults to 5000.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
dict or None: The embedding response from the Titan model as a dictionary, or None if the operation fails or exceeds the retry limits.
|
69 |
+
"""
|
70 |
+
|
71 |
+
retries = 5
|
72 |
+
|
73 |
+
try:
|
74 |
+
model_id = 'amazon.titan-embed-text-v1'
|
75 |
+
accept = 'application/json'
|
76 |
+
content_type = 'application/json'
|
77 |
+
|
78 |
+
body = json.dumps({
|
79 |
+
"inputText": text,
|
80 |
+
})
|
81 |
+
|
82 |
+
# Invoke model
|
83 |
+
response = bedrock.invoke_model(
|
84 |
+
body=body,
|
85 |
+
modelId=model_id,
|
86 |
+
accept=accept,
|
87 |
+
contentType=content_type
|
88 |
+
)
|
89 |
+
|
90 |
+
# Print response
|
91 |
+
response_body = json.loads(response['body'].read())
|
92 |
+
|
93 |
+
|
94 |
+
# Handle a few common client exceptions
|
95 |
+
except botocore.exceptions.ClientError as error:
|
96 |
+
if error.response['Error']['Code'] == 'ThrottlingException':
|
97 |
+
if attempt + 1 == retries:
|
98 |
+
return None
|
99 |
+
|
100 |
+
delay = 2 ** (attempt + 1);
|
101 |
+
time.sleep(delay)
|
102 |
+
return get_titan_embedding(doc_name, text, attempt=attempt + 1)
|
103 |
+
|
104 |
+
elif error.response['Error']['Code'] == 'ValidationException':
|
105 |
+
# get chunks of text length 20000 characters
|
106 |
+
text_chunks = [text[i:i+cutoff] for i in range(0, len(text), cutoff)]
|
107 |
+
embeddings = []
|
108 |
+
for chunk in text_chunks:
|
109 |
+
embeddings.append(get_titan_embedding(bedrock, doc_name, chunk))
|
110 |
+
|
111 |
+
# return the average of the embeddinngs
|
112 |
+
return np.mean(embeddings, axis=0)
|
113 |
+
|
114 |
+
else:
|
115 |
+
yield f"Unhandled Exception when processing {doc_name}! : {error.response['Error']['Code']}"
|
116 |
+
return None
|
117 |
+
|
118 |
+
# Catch-all for any other exceptions
|
119 |
+
except Exception as error:
|
120 |
+
yield f"Unhandled Exception when processing {doc_name}: {type(error).__name__}"
|
121 |
+
return None
|
122 |
+
|
123 |
+
return response_body.get('embedding')
|
124 |
+
|
125 |
+
def ask_ds(message, history):
|
126 |
+
|
127 |
+
|
128 |
+
question = message
|
129 |
+
|
130 |
+
# RAG
|
131 |
+
question_embedding = get_titan_embedding(bedrock_client, 'question', question)
|
132 |
+
|
133 |
+
similar_documents = []
|
134 |
+
for file, data in extractions.items():
|
135 |
+
similarity = cosine_similarity(question_embedding, data['embedding'])
|
136 |
+
similar_documents.append((file, similarity))
|
137 |
+
|
138 |
+
similar_documents.sort(key=lambda x: x[1], reverse=False)
|
139 |
+
|
140 |
+
similar_content = ''
|
141 |
+
for file, _ in similar_documents[:5]:
|
142 |
+
similar_content += extractions[file]['content'] + '\n'
|
143 |
+
|
144 |
+
|
145 |
+
# Invoke
|
146 |
+
response = bedrock_client.invoke_model_with_response_stream(
|
147 |
+
modelId="anthropic.claude-3-sonnet-20240229-v1:0",
|
148 |
+
body=json.dumps(
|
149 |
+
{
|
150 |
+
"anthropic_version": "bedrock-2023-05-31",
|
151 |
+
"max_tokens": 4096,
|
152 |
+
"system": f"""You are a helpful, excited assistant that answers questions about certain provided documents.
|
153 |
+
<Task>
|
154 |
+
Your task is to review the provided relevant information and answer the user's question to the best of your ability.
|
155 |
+
Try to use only the information in the document to answer. Refrain from saying things like 'According to the relevant information provided'.
|
156 |
+
|
157 |
+
Format your output nicely with sentences that are not too long. You should prefer lists or bullet points when applicable.
|
158 |
+
Begin by thanking the user for their question, and at the end of your answer, say "Thank you for using Ask Dane Street!"
|
159 |
+
</Task>
|
160 |
+
|
161 |
+
<Relevant Information>
|
162 |
+
{similar_content}
|
163 |
+
</Relevant Information>""",
|
164 |
+
"messages": [
|
165 |
+
{
|
166 |
+
"role": "user",
|
167 |
+
"content": [
|
168 |
+
{
|
169 |
+
"type": "text",
|
170 |
+
"text": message
|
171 |
+
}
|
172 |
+
]
|
173 |
+
}
|
174 |
+
],
|
175 |
+
}
|
176 |
+
),
|
177 |
+
)
|
178 |
+
|
179 |
+
all_text = ''
|
180 |
+
stream = response.get('body')
|
181 |
+
if stream:
|
182 |
+
for event in stream:
|
183 |
+
chunk = event.get('chunk')
|
184 |
+
if chunk and json.loads(chunk.get('bytes').decode()):
|
185 |
+
# check if delta is present
|
186 |
+
try:
|
187 |
+
this_text = json.loads(chunk.get('bytes').decode()).get('delta').get('text')
|
188 |
+
all_text += this_text
|
189 |
+
yield all_text # Stream the text back to the UI
|
190 |
+
except:
|
191 |
+
pass
|
192 |
+
|
193 |
+
output = '\n\nCheck out the following documents for more information:\n'
|
194 |
+
for file, _ in similar_documents[:5]:
|
195 |
+
output += f"\n{file.replace('.txt', '.pdf')}"
|
196 |
+
|
197 |
+
yield all_text + output
|
198 |
+
|
199 |
+
|
200 |
+
bedrock_client = create_bedrock_client()
|
201 |
+
s3_client = create_s3_client()
|
202 |
+
_, extractions = read_json_from_s3(s3_client, BUCKET_NAME, EXTRACTIONS_PATH)
|
203 |
+
|
204 |
+
demo = gr.ChatInterface(fn=ask_ds, title="AskDS_HR", multimodal=False, chatbot=gr.Chatbot(value=[(None, "")],),theme=theme)
|
205 |
+
demo.launch()
|