import gradio as gr
import boto3
import json
import os
import numpy as np
import botocore
import time

theme = gr.themes.Base(text_size='sm')

AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
AWS_SECRET_ACCESS_KEY = os.getenv('AWS_SECRET_ACCESS_KEY')
AWS_REGION = os.getenv('REGION_NAME')
AWS_SESSION = os.getenv('AWS_SESSION')
BUCKET_NAME = os.getenv('BUCKET_NAME')
EXTRACTIONS_PATH = os.getenv('EXTRACTIONS_PATH')

def create_bedrock_client():
    """Create a Bedrock Runtime client with the configured credentials and region."""
    return boto3.client(
        'bedrock-runtime',
        region_name=AWS_REGION,
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
        aws_session_token=AWS_SESSION
    )

def create_s3_client():
    """Create an S3 client with the configured credentials."""
    return boto3.client(
        's3',
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
        aws_session_token=AWS_SESSION
    )

def read_json_from_s3(s3_client, bucket_name, key):
    """Read a JSON file from S3 and return (key, parsed content), or (key, None) on failure."""
    try:
        response = s3_client.get_object(Bucket=bucket_name, Key=key)
        file_content = response['Body'].read().decode('utf-8')
        json_content = json.loads(file_content)
    except Exception as e:
        print(f"Error reading JSON file from S3: {e}")
        return key, None

    return key, json_content

def get_titan_embedding(bedrock, doc_name, text, attempt=0, cutoff=10000):
    """
    Retrieves a text embedding for a given document using the Amazon Titan Embedding model.

    This function sends the provided text to the Amazon Titan text embedding model
    and retrieves the resulting embedding. It handles retries for throttling exceptions
    and input size limitations by recursively calling itself with adjusted parameters.

    Parameters:
        bedrock: The Bedrock Runtime client used to invoke the model.
        doc_name (str): The name of the document, used for logging and error messages.
        text (str): The text content to be processed by the Titan embedding model.
        attempt (int): The current attempt number (used in recursive calls to handle retries). Defaults to 0.
        cutoff (int): The maximum number of characters per chunk when a ValidationException forces the input to be split. Defaults to 10000.

    Returns:
        list, numpy.ndarray, or None: The embedding from the Titan model, the mean of the chunk
        embeddings when the input had to be split, or None if the operation fails or exceeds
        the retry limit.
    """
    retries = 5

    try:
        model_id = 'amazon.titan-embed-text-v1'
        accept = 'application/json'
        content_type = 'application/json'

        body = json.dumps({
            "inputText": text,
        })

        response = bedrock.invoke_model(
            body=body,
            modelId=model_id,
            accept=accept,
            contentType=content_type
        )

        response_body = json.loads(response['body'].read())

    except botocore.exceptions.ClientError as error:
        if error.response['Error']['Code'] == 'ThrottlingException':
            if attempt + 1 == retries:
                return None

            # Exponential backoff before retrying.
            delay = 2 ** (attempt + 1)
            time.sleep(delay)
            return get_titan_embedding(bedrock, doc_name, text, attempt=attempt + 1)

        elif error.response['Error']['Code'] == 'ValidationException':
            # The input exceeded the model's size limit: split it into chunks,
            # embed each chunk, and return the mean of the chunk embeddings.
            text_chunks = [text[i:i + cutoff] for i in range(0, len(text), cutoff)]
            embeddings = []
            for chunk in text_chunks:
                chunk_embedding = get_titan_embedding(bedrock, doc_name, chunk)
                if chunk_embedding is not None:
                    embeddings.append(chunk_embedding)

            if not embeddings:
                return None
            return np.mean(embeddings, axis=0)

        else:
            print(f"Unhandled Exception when processing {doc_name}! : {error.response['Error']['Code']}")
            return None

    except Exception as error:
        print(f"Unhandled Exception when processing {doc_name}: {type(error).__name__}")
        return None

    return response_body.get('embedding')

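
# ask_ds() below calls cosine_similarity(), which this script never defines. The helper
# here is a minimal NumPy-based sketch of that assumed function, added so the document
# ranking can run; swap in your own implementation if one exists elsewhere.
def cosine_similarity(a, b):
    """Return the cosine similarity between two embedding vectors."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    denominator = np.linalg.norm(a) * np.linalg.norm(b)
    if denominator == 0:
        return 0.0
    return float(np.dot(a, b) / denominator)
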

def ask_ds(message, history):
    """Gradio chat handler: find the most relevant documents and stream Claude's answer."""
    question = message

    # Embed the question so it can be compared against the stored document embeddings.
    question_embedding = get_titan_embedding(bedrock_client, 'question', question)

    similar_documents = []
    for file, data in extractions.items():
        similarity = cosine_similarity(question_embedding, data['embedding'])
        similar_documents.append((file, similarity))

    # Sort so the most similar documents come first.
    similar_documents.sort(key=lambda x: x[1], reverse=True)

    # Concatenate the content of the five most similar documents as context for the model.
    similar_content = ''
    for file, _ in similar_documents[:5]:
        similar_content += extractions[file]['content'] + '\n'

    response = bedrock_client.invoke_model_with_response_stream(
        modelId="anthropic.claude-3-sonnet-20240229-v1:0",
        body=json.dumps(
            {
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": 4096,
                "system": f"""You are a helpful, excited assistant that answers questions about certain provided documents.
<Task>
Your task is to review the provided relevant information and answer the user's question to the best of your ability.
Try to use only the information in the document to answer. Refrain from saying things like 'According to the relevant information provided'.

Format your output nicely with sentences that are not too long. You should prefer lists or bullet points when applicable.
Begin by thanking the user for their question, and at the end of your answer, say "Thank you for using Ask Dane Street!"
</Task>

<Relevant Information>
{similar_content}
</Relevant Information>""",
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": message
                            }
                        ]
                    }
                ],
            }
        ),
    )

    # Stream the model's answer back to the chat interface as it is generated.
    all_text = ''
    stream = response.get('body')
    if stream:
        for event in stream:
            chunk = event.get('chunk')
            if chunk:
                chunk_json = json.loads(chunk.get('bytes').decode())
                try:
                    this_text = chunk_json.get('delta').get('text')
                    all_text += this_text
                    yield all_text
                except (AttributeError, TypeError):
                    # Events without a text delta (e.g. message start/stop) are skipped.
                    pass

    # Point the user at the source documents behind the answer.
    output = '\n\nCheck out the following documents for more information:\n'
    for file, _ in similar_documents[:5]:
        output += f"\n{file.replace('.txt', '.pdf')}"

    yield all_text + output


bedrock_client = create_bedrock_client()
s3_client = create_s3_client()
_, extractions = read_json_from_s3(s3_client, BUCKET_NAME, EXTRACTIONS_PATH)

demo = gr.ChatInterface(
    fn=ask_ds,
    title="AskDS_HR",
    multimodal=False,
    chatbot=gr.Chatbot(value=[(None, "")]),
    theme=theme,
)
demo.launch()