1st push
- LICENSE +21 -0
- app.py +259 -0
- figs/false-insurance-policy.jpeg +0 -0
- figs/labcorp_accessioning.jpg +0 -0
- figs/system-architect.drawio +82 -0
- figs/system-architect.png +0 -0
- lambda/my_textract.py +95 -0
- models/cnn_transformer/tf_keras_image_captioning_cnn+transformer_flicker8k.index +0 -0
- requirements.txt +13 -0
- utils/cnn_transformer.py +379 -0
- utils/helpers.py +192 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Yiqiao Yin

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
app.py
ADDED
@@ -0,0 +1,259 @@
import base64
import io
import json
import os
from typing import Any, Dict, List

import chromadb
import google.generativeai as palm
import pandas as pd
import requests
import streamlit as st
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    SentenceTransformersTokenTextSplitter,
)
from PIL import Image, ImageDraw, ImageFont
from pypdf import PdfReader
from transformers import pipeline

from utils.cnn_transformer import *
from utils.helpers import *

# API key (set this in your Streamlit secrets)
api_key = st.secrets["PALM_API_KEY"]
palm.configure(api_key=api_key)


# Load the YOLO object-detection pipeline
yolo_pipe = pipeline("object-detection", model="hustvl/yolos-small")


# Function to draw bounding boxes and labels on an image
def draw_boxes(image, predictions):
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()

    for pred in predictions:
        label = pred["label"]
        score = pred["score"]
        box = pred["box"]
        xmin, ymin, xmax, ymax = box.values()
        draw.rectangle([xmin, ymin, xmax, ymax], outline="red", width=2)
        draw.text((xmin, ymin), f"{label} ({score:.2f})", fill="red", font=font)

    return image


# Main function of the Streamlit app
def main():
    st.title("Generative AI Demo on Camera Input/Image/PDF 💻")

    # Dropdown for the user to choose the input method
    input_method = st.sidebar.selectbox(
        "Choose input method:", ["Camera", "Upload Image", "Upload PDF"]
    )

    image, uploaded_file = None, None
    if input_method == "Camera":
        # Streamlit widget to capture an image from the user's webcam
        image = st.sidebar.camera_input("Take a picture 📸")
    elif input_method == "Upload Image":
        # Create a file uploader in the sidebar
        image = st.sidebar.file_uploader("Upload a JPG image", type=["jpg"])
    elif input_method == "Upload PDF":
        # File uploader widget
        uploaded_file = st.sidebar.file_uploader("Choose a PDF file", type="pdf")

    # Add instructions
    st.sidebar.markdown(
        """
        # 🌟 How to Use the App 🌟

        1) **🌈 User Input Magic**:
           - 📸 **Camera Snap**: Tap to capture a moment with your device's camera. Say cheese!
           - 🖼️ **Image Upload Extravaganza**: Got a cool pic? Upload it from your computer and let the magic begin!
           - 📄 **PDF Adventure**: Use gen AI like Ctrl+F to search for information in any PDF, like opening a treasure chest of information!
           - 📄 **YOLO Algorithm**: Want to detect the objects in the image? Use our object detection algorithm to see whether they can be detected.

        2) **🤖 AI Interaction Wonderland**:
           - 🌟 **Gemini's AI**: Google's Gemini AI is your companion, ready to dive deep into your uploads.
           - 🌐 **Chroma Database**: As you upload, we're crafting a colorful Chroma database in our secret lab, making your interaction even more awesome!

        3) **💬 Chit-Chat with AI Post-Upload**:
           - 🌍 Once your content is up in the app, ask away! Any question, any time.
           - 💡 Light up the conversation with Gemini AI. It is like having a chat with a wise wizard from the digital realm!

        Enjoy exploring and have fun! 😄🎉
        """
    )

    if image is not None:
        # Display the captured image
        st.image(image, caption="Captured Image", use_column_width=True)

        # Convert the image to PIL format and resize
        pil_image = Image.open(image)
        resized_image = resize_image(pil_image)

        # Convert the resized image to base64
        image_base64 = convert_image_to_base64(resized_image)

        # OCR via a POST request to AWS Textract (through API Gateway)
        if input_method == "Upload Image":
            st.success("Running Textract!")
            url = "https://2tsig211e0.execute-api.us-east-1.amazonaws.com/my_textract"
            payload = {"image": image_base64}
            result_dict = post_request_and_parse_response(url, payload)
            output_data = extract_line_items(result_dict)
            df = pd.DataFrame(output_data)

            # Using an expander to hide the raw JSON
            with st.expander("Show/Hide Raw JSON"):
                st.write(result_dict)

            # Using an expander to hide the table
            with st.expander("Show/Hide Table"):
                st.table(df)

        if api_key:
            # Make the API call
            st.success("Running Gemini!")
            with st.spinner("Wait for it..."):
                response = call_gemini_api(image_base64, api_key)

            with st.expander("Raw output from Gemini"):
                st.write(response)

            # Display the response
            if response["candidates"][0]["content"]["parts"][0]["text"]:
                text_from_response = response["candidates"][0]["content"]["parts"][0][
                    "text"
                ]
                with st.spinner("Wait for it..."):
                    st.write(text_from_response)

                # Text input for the question
                input_prompt = st.text_input(
                    "Type your question here:",
                )

                # Display the entered question
                if input_prompt:
                    updated_text_from_response = call_gemini_api(
                        image_base64, api_key, prompt=input_prompt
                    )

                    if updated_text_from_response is not None:
                        # Do something with the text
                        updated_ans = updated_text_from_response["candidates"][0][
                            "content"
                        ]["parts"][0]["text"]
                        with st.spinner("Wait for it..."):
                            st.write(f"Gemini: {updated_ans}")
                    else:
                        st.warning("Check Gemini's API.")

            else:
                st.write("No response from API.")
        else:
            st.write("API Key is not set. Please set the API Key.")

    # YOLO
    if image is not None:
        st.sidebar.success("Check the following box to run the YOLO algorithm if desired!")
        use_yolo = st.sidebar.checkbox("Use YOLO!", value=False)

        if use_yolo:
            # Process the image with YOLO
            image = Image.open(image)
            with st.spinner("Wait for it..."):
                st.success("Running YOLO algorithm!")
                predictions = yolo_pipe(image)
                st.success("YOLO ran successfully.")

            # Draw bounding boxes and labels
            image_with_boxes = draw_boxes(image.copy(), predictions)
            st.success("Bounding boxes drawn.")

            # Display the annotated image
            st.image(image_with_boxes, caption="Annotated Image", use_column_width=True)

    # File uploader widget
    if uploaded_file is not None:
        # To read the file as bytes:
        bytes_data = uploaded_file.getvalue()
        st.success("Your PDF was uploaded successfully.")

        # Get the file name
        file_name = uploaded_file.name

        # Save the file temporarily
        with open(file_name, "wb") as f:
            f.write(uploaded_file.getbuffer())

        # Display PDF
        # displayPDF(file_name)

        # Read the file
        reader = PdfReader(file_name)
        pdf_texts = [p.extract_text().strip() for p in reader.pages]

        # Filter out empty strings
        pdf_texts = [text for text in pdf_texts if text]
        st.success("PDF extracted successfully.")

        # Split the texts
        character_splitter = RecursiveCharacterTextSplitter(
            separators=["\n\n", "\n", ". ", " ", ""], chunk_size=1000, chunk_overlap=0
        )
        character_split_texts = character_splitter.split_text("\n\n".join(pdf_texts))
        st.success("Texts split successfully.")

        # Tokenize it
        st.warning("Start tokenizing ...")
        token_splitter = SentenceTransformersTokenTextSplitter(
            chunk_overlap=0, tokens_per_chunk=256
        )
        token_split_texts = []
        for text in character_split_texts:
            token_split_texts += token_splitter.split_text(text)
        st.success("Tokenized successfully.")

        # Add to the vector database
        embedding_function = SentenceTransformerEmbeddingFunction()
        chroma_client = chromadb.Client()
        chroma_collection = chroma_client.create_collection(
            "tmp", embedding_function=embedding_function
        )
        ids = [str(i) for i in range(len(token_split_texts))]
        chroma_collection.add(ids=ids, documents=token_split_texts)
        st.success("Vector database loaded successfully.")

        # User input
        query = st.text_input("Ask me anything!", "What is the document about?")
        results = chroma_collection.query(query_texts=[query], n_results=5)
        retrieved_documents = results["documents"][0]
        results_as_table = pd.DataFrame(
            {
                "ids": results["ids"][0],
                "documents": results["documents"][0],
                "distances": results["distances"][0],
            }
        )

        # API of a foundation model
        output = rag(query=query, retrieved_documents=retrieved_documents)
        st.write(output)
        st.success(
            "Please see below where in the document the chatbot got the information. 👇"
        )
        with st.expander("Raw query outputs:"):
            st.write(results)
        with st.expander("Processed tabular form query outputs:"):
            st.table(results_as_table)


if __name__ == "__main__":
    main()
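The app is launched with "streamlit run app.py" and reads PALM_API_KEY from Streamlit secrets (for example, an entry in .streamlit/secrets.toml). The snippet below is a minimal sketch, not part of this push, of the Gemini REST response shape that the indexing in app.py assumes; the literal caption text is made up for illustration.

# A sketch of the nesting that response["candidates"][0]["content"]["parts"][0]["text"]
# expects; the content string here is illustrative only.
example_response = {
    "candidates": [
        {"content": {"parts": [{"text": "A sample description of the image."}]}}
    ]
}
assert (
    example_response["candidates"][0]["content"]["parts"][0]["text"]
    == "A sample description of the image."
)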
figs/false-insurance-policy.jpeg
ADDED
figs/labcorp_accessioning.jpg
ADDED
figs/system-architect.drawio
ADDED
@@ -0,0 +1,82 @@
<mxfile host="65bd71144e">
    <diagram id="6I0VWqCgP7JPpdnrNpuH" name="Page-1">
        <mxGraphModel dx="721" dy="917" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="850" pageHeight="1100" math="0" shadow="0">
            <root>
                <mxCell id="0"/>
                <mxCell id="1" parent="0"/>
                <mxCell id="37" value="" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
                    <mxGeometry x="80" y="110" width="720" height="480" as="geometry"/>
                </mxCell>
                <mxCell id="32" style="edgeStyle=none;html=1;" parent="1" source="2" target="29" edge="1">
                    <mxGeometry relative="1" as="geometry">
                        <mxPoint x="237.5" y="380" as="targetPoint"/>
                    </mxGeometry>
                </mxCell>
                <mxCell id="2" value="<b>PDF</b>" style="html=1;verticalLabelPosition=bottom;align=center;labelBackgroundColor=#ffffff;verticalAlign=top;strokeWidth=2;strokeColor=#0080F0;shadow=0;dashed=0;shape=mxgraph.ios7.icons.documents;" parent="1" vertex="1">
                    <mxGeometry x="137.5" y="195" width="55" height="60" as="geometry"/>
                </mxCell>
                <mxCell id="12" style="html=1;entryX=1;entryY=0.5;entryDx=0;entryDy=0;entryPerimeter=0;" parent="1" source="3" target="11" edge="1">
                    <mxGeometry relative="1" as="geometry"/>
                </mxCell>
                <mxCell id="3" value="<b>Textract</b>" style="sketch=0;points=[[0,0,0],[0.25,0,0],[0.5,0,0],[0.75,0,0],[1,0,0],[0,1,0],[0.25,1,0],[0.5,1,0],[0.75,1,0],[1,1,0],[0,0.25,0],[0,0.5,0],[0,0.75,0],[1,0.25,0],[1,0.5,0],[1,0.75,0]];outlineConnect=0;fontColor=#232F3E;gradientColor=#4AB29A;gradientDirection=north;fillColor=#116D5B;strokeColor=#ffffff;dashed=0;verticalLabelPosition=bottom;verticalAlign=top;align=center;html=1;fontSize=12;fontStyle=0;aspect=fixed;shape=mxgraph.aws4.resourceIcon;resIcon=mxgraph.aws4.textract;" parent="1" vertex="1">
                    <mxGeometry x="700" y="337.5" width="78" height="78" as="geometry"/>
                </mxCell>
                <mxCell id="15" style="edgeStyle=none;html=1;" parent="1" source="10" edge="1">
                    <mxGeometry relative="1" as="geometry">
                        <mxPoint x="590" y="420" as="targetPoint"/>
                        <Array as="points">
                            <mxPoint x="460" y="520"/>
                            <mxPoint x="520" y="520"/>
                            <mxPoint x="570" y="520"/>
                        </Array>
                    </mxGeometry>
                </mxCell>
                <mxCell id="31" style="edgeStyle=none;html=1;entryX=1.04;entryY=0.492;entryDx=0;entryDy=0;entryPerimeter=0;" parent="1" source="10" target="29" edge="1">
                    <mxGeometry relative="1" as="geometry"/>
                </mxCell>
                <mxCell id="10" value="<b>API Gateway</b>" style="outlineConnect=0;dashed=0;verticalLabelPosition=bottom;verticalAlign=top;align=center;html=1;shape=mxgraph.aws3.api_gateway;fillColor=#D9A741;gradientColor=none;" parent="1" vertex="1">
                    <mxGeometry x="400" y="330" width="76.5" height="93" as="geometry"/>
                </mxCell>
                <mxCell id="13" style="edgeStyle=none;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;entryPerimeter=0;" parent="1" source="11" target="3" edge="1">
                    <mxGeometry relative="1" as="geometry"/>
                </mxCell>
                <mxCell id="16" style="edgeStyle=none;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;exitPerimeter=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;entryPerimeter=0;" parent="1" source="11" target="10" edge="1">
                    <mxGeometry relative="1" as="geometry">
                        <Array as="points">
                            <mxPoint x="570" y="240"/>
                            <mxPoint x="510" y="240"/>
                            <mxPoint x="450" y="240"/>
                        </Array>
                    </mxGeometry>
                </mxCell>
                <mxCell id="11" value="<b>AWS Lambda</b>" style="sketch=0;points=[[0,0,0],[0.25,0,0],[0.5,0,0],[0.75,0,0],[1,0,0],[0,1,0],[0.25,1,0],[0.5,1,0],[0.75,1,0],[1,1,0],[0,0.25,0],[0,0.5,0],[0,0.75,0],[1,0.25,0],[1,0.5,0],[1,0.75,0]];outlineConnect=0;fontColor=#232F3E;gradientColor=#F78E04;gradientDirection=north;fillColor=#D05C17;strokeColor=#ffffff;dashed=0;verticalLabelPosition=bottom;verticalAlign=top;align=center;html=1;fontSize=12;fontStyle=0;aspect=fixed;shape=mxgraph.aws4.resourceIcon;resIcon=mxgraph.aws4.lambda;" parent="1" vertex="1">
                    <mxGeometry x="546" y="337.5" width="78" height="78" as="geometry"/>
                </mxCell>
                <mxCell id="22" value="<b>OCR Output</b>" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" parent="1" vertex="1">
                    <mxGeometry x="491" y="208" width="60" height="30" as="geometry"/>
                </mxCell>
                <mxCell id="25" value="<b>base64&nbsp;<br>Encoded<br>Image<br></b>" style="text;html=1;align=center;verticalAlign=middle;resizable=0;points=[];autosize=1;strokeColor=none;fillColor=none;" parent="1" vertex="1">
                    <mxGeometry x="486" y="513" width="70" height="60" as="geometry"/>
                </mxCell>
                <mxCell id="26" value="<b>base64&nbsp;<br>Encoded<br>Image<br></b>" style="text;html=1;align=center;verticalAlign=middle;resizable=0;points=[];autosize=1;strokeColor=none;fillColor=none;" parent="1" vertex="1">
                    <mxGeometry x="265" y="374.25" width="70" height="60" as="geometry"/>
                </mxCell>
                <mxCell id="28" value="<b>Extracted<br>Text<br></b>" style="text;html=1;align=center;verticalAlign=middle;resizable=0;points=[];autosize=1;strokeColor=none;fillColor=none;" parent="1" vertex="1">
                    <mxGeometry x="260" y="334.25" width="80" height="40" as="geometry"/>
                </mxCell>
                <mxCell id="30" style="edgeStyle=none;html=1;" parent="1" source="29" target="10" edge="1">
                    <mxGeometry relative="1" as="geometry"/>
                </mxCell>
                <mxCell id="29" value="<b>User</b>" style="html=1;verticalLabelPosition=bottom;align=center;labelBackgroundColor=#ffffff;verticalAlign=top;strokeWidth=2;strokeColor=#0080F0;shadow=0;dashed=0;shape=mxgraph.ios7.icons.user;" parent="1" vertex="1">
                    <mxGeometry x="125" y="334.25" width="80" height="84.5" as="geometry"/>
                </mxCell>
                <mxCell id="33" value="Streamlit App" style="swimlane;whiteSpace=wrap;html=1;align=left;" vertex="1" parent="1">
                    <mxGeometry x="100" y="150" width="690" height="430" as="geometry"/>
                </mxCell>
                <mxCell id="34" value="<b>EC2</b>" style="outlineConnect=0;dashed=0;verticalLabelPosition=bottom;verticalAlign=top;align=center;html=1;shape=mxgraph.aws3.ec2;fillColor=#F58534;gradientColor=none;" vertex="1" parent="1">
                    <mxGeometry x="95" y="80" width="35" height="43" as="geometry"/>
                </mxCell>
            </root>
        </mxGraphModel>
    </diagram>
</mxfile>
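The diagram shows the Streamlit app (hosted on EC2) posting a base64-encoded image through API Gateway to the Lambda function below, which calls Amazon Textract and returns the extracted text to the user. The snippet is a minimal sketch of that round trip, not part of this push; it assumes the endpoint hard-coded in app.py and mirrors utils/helpers.post_request_and_parse_response.

import base64
import json
import requests

# Encode a sample image from this repository and post it to the API Gateway endpoint.
with open("figs/labcorp_accessioning.jpg", "rb") as f:
    image_base64 = base64.b64encode(f.read()).decode()

resp = requests.post(
    "https://2tsig211e0.execute-api.us-east-1.amazonaws.com/my_textract",
    json={"image": image_base64},
    headers={"Content-Type": "application/json"},
)

# The Lambda wraps the Textract Blocks as a JSON string under the "body" key.
result = resp.json()
blocks = json.loads(result.get("body", "[]"))
print(len(blocks), "blocks returned")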
figs/system-architect.png
ADDED
lambda/my_textract.py
ADDED
@@ -0,0 +1,95 @@
"""
Purpose
An AWS Lambda function that analyzes documents with Amazon Textract.
"""
import json
import base64
import logging
import boto3

from botocore.exceptions import ClientError

# Set up logging.
logger = logging.getLogger(__name__)

# Get the boto3 client.
textract_client = boto3.client("textract")


def lambda_handler(event, context):
    """
    Lambda handler function
    param: event: The event object for the Lambda function.
    param: context: The context object for the Lambda function.
    return: The list of Block objects recognized in the document
    passed in the event object.
    """

    # raw_image = json.loads(event['body'])['image']
    # message = f"i love {country}"

    # return message

    try:
        # Determine the document source.
        # event['image'] = event["queryStringParameters"]['image']
        # event['image'] = json.loads(event['body'])["queryStringParameters"]['image']
        event["image"] = json.loads(event["body"])["image"]
        if "image" in event:
            # Decode the image
            image_bytes = event["image"].encode("utf-8")
            img_b64decoded = base64.b64decode(image_bytes)
            image = {"Bytes": img_b64decoded}

        elif "S3Object" in event:
            image = {
                "S3Object": {
                    "Bucket": event["S3Object"]["Bucket"],
                    "Name": event["S3Object"]["Name"],
                }
            }

        else:
            raise ValueError(
                "Invalid source. Only base64-encoded image bytes or an S3Object are supported."
            )

        # Analyze the document.
        response = textract_client.detect_document_text(Document=image)

        # Get the Blocks
        blocks = response["Blocks"]

        lambda_response = {"statusCode": 200, "body": json.dumps(blocks)}

    except ClientError as err:
        error_message = "Couldn't analyze image. " + err.response["Error"]["Message"]

        lambda_response = {
            "statusCode": 400,
            "body": {
                "Error": err.response["Error"]["Code"],
                "ErrorMessage": error_message,
            },
        }
        logger.error(
            "Error function %s: %s", context.invoked_function_arn, error_message
        )

    except ValueError as val_error:
        lambda_response = {
            "statusCode": 400,
            "body": {"Error": "ValueError", "ErrorMessage": format(val_error)},
        }
        logger.error(
            "Error function %s: %s", context.invoked_function_arn, format(val_error)
        )

    # Create the return body
    http_resp = {}
    http_resp["statusCode"] = 200
    http_resp["headers"] = {}
    http_resp["headers"]["Content-Type"] = "application/json"
    http_resp["body"] = json.dumps(lambda_response)

    return http_resp
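The handler expects the API Gateway proxy event to carry a JSON string body with an "image" field holding a base64-encoded document. Below is a minimal local test sketch, not part of this push; a real run requires AWS credentials with Textract access, and the context object is stubbed with only the attribute the error logging reads.

import base64
import json
from types import SimpleNamespace

# Run from inside the lambda/ directory so the module is importable.
from my_textract import lambda_handler

# Build an event in the shape the handler parses: json.loads(event["body"])["image"].
with open("../figs/labcorp_accessioning.jpg", "rb") as f:
    event = {"body": json.dumps({"image": base64.b64encode(f.read()).decode("utf-8")})}

# Stub the context with the single attribute referenced by the logger calls.
context = SimpleNamespace(invoked_function_arn="arn:aws:lambda:local:0:function:my_textract")

print(lambda_handler(event, context))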
models/cnn_transformer/tf_keras_image_captioning_cnn+transformer_flicker8k.index
ADDED
Binary file (28.9 kB)
requirements.txt
ADDED
@@ -0,0 +1,13 @@
chromadb==0.3.29
langchain==0.0.343
matplotlib
numpy
google-generativeai>=0.1.0
pandas
pypdf==3.17.1
Pillow
sentence-transformers==2.2.2
streamlit
transformers
torch
tensorflow
utils/cnn_transformer.py
ADDED
@@ -0,0 +1,379 @@
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import re
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import keras
from keras import layers
from keras.applications import efficientnet
from keras.layers import TextVectorization

keras.utils.set_random_seed(111)


# Desired image dimensions
IMAGE_SIZE = (299, 299)

# Dimension for the image embeddings and token embeddings
EMBED_DIM = 512

# Per-layer units in the feed-forward network
FF_DIM = 512

# Fixed length allowed for any sequence
SEQ_LENGTH = 25

# Vocabulary size
VOCAB_SIZE = 10000

# Data augmentation for image data
image_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.2),
        layers.RandomContrast(0.3),
    ]
)


def get_cnn_model():
    base_model = efficientnet.EfficientNetB0(
        input_shape=(*IMAGE_SIZE, 3),
        include_top=False,
        weights="imagenet",
    )
    # We freeze our feature extractor
    base_model.trainable = False
    base_model_out = base_model.output
    base_model_out = layers.Reshape((-1, base_model_out.shape[-1]))(base_model_out)
    cnn_model = keras.models.Model(base_model.input, base_model_out)
    return cnn_model


class TransformerEncoderBlock(layers.Layer):
    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.0
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.dense_1 = layers.Dense(embed_dim, activation="relu")

    def call(self, inputs, training, mask=None):
        inputs = self.layernorm_1(inputs)
        inputs = self.dense_1(inputs)

        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=None,
            training=training,
        )
        out_1 = self.layernorm_2(inputs + attention_output_1)
        return out_1


class PositionalEmbedding(layers.Layer):
    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=embed_dim
        )
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.embed_scale = tf.math.sqrt(tf.cast(embed_dim, tf.float32))

    def call(self, inputs):
        length = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=length, delta=1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_tokens = embedded_tokens * self.embed_scale
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs, mask=None):
        return tf.math.not_equal(inputs, 0)


class TransformerDecoderBlock(layers.Layer):
    def __init__(self, embed_dim, ff_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.ff_dim = ff_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.ffn_layer_1 = layers.Dense(ff_dim, activation="relu")
        self.ffn_layer_2 = layers.Dense(embed_dim)

        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()

        self.embedding = PositionalEmbedding(
            embed_dim=EMBED_DIM,
            sequence_length=SEQ_LENGTH,
            vocab_size=VOCAB_SIZE,
        )
        self.out = layers.Dense(VOCAB_SIZE, activation="softmax")

        self.dropout_1 = layers.Dropout(0.3)
        self.dropout_2 = layers.Dropout(0.5)
        self.supports_masking = True

    def call(self, inputs, encoder_outputs, training, mask=None):
        inputs = self.embedding(inputs)
        causal_mask = self.get_causal_attention_mask(inputs)

        if mask is not None:
            padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
            combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
            combined_mask = tf.minimum(combined_mask, causal_mask)

        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=combined_mask,
            training=training,
        )
        out_1 = self.layernorm_1(inputs + attention_output_1)

        attention_output_2 = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
            attention_mask=padding_mask,
            training=training,
        )
        out_2 = self.layernorm_2(out_1 + attention_output_2)

        ffn_out = self.ffn_layer_1(out_2)
        ffn_out = self.dropout_1(ffn_out, training=training)
        ffn_out = self.ffn_layer_2(ffn_out)

        ffn_out = self.layernorm_3(ffn_out + out_2, training=training)
        ffn_out = self.dropout_2(ffn_out, training=training)
        preds = self.out(ffn_out)
        return preds

    def get_causal_attention_mask(self, inputs):
        input_shape = tf.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = tf.range(sequence_length)[:, tf.newaxis]
        j = tf.range(sequence_length)
        mask = tf.cast(i >= j, dtype="int32")
        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = tf.concat(
            [
                tf.expand_dims(batch_size, -1),
                tf.constant([1, 1], dtype=tf.int32),
            ],
            axis=0,
        )
        return tf.tile(mask, mult)


class ImageCaptioningModel(keras.Model):
    def __init__(
        self,
        cnn_model,
        encoder,
        decoder,
        num_captions_per_image=5,
        image_aug=None,
    ):
        super().__init__()
        self.cnn_model = cnn_model
        self.encoder = encoder
        self.decoder = decoder
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.acc_tracker = keras.metrics.Mean(name="accuracy")
        self.num_captions_per_image = num_captions_per_image
        self.image_aug = image_aug

    def calculate_loss(self, y_true, y_pred, mask):
        loss = self.loss(y_true, y_pred)
        mask = tf.cast(mask, dtype=loss.dtype)
        loss *= mask
        return tf.reduce_sum(loss) / tf.reduce_sum(mask)

    def calculate_accuracy(self, y_true, y_pred, mask):
        accuracy = tf.equal(y_true, tf.argmax(y_pred, axis=2))
        accuracy = tf.math.logical_and(mask, accuracy)
        accuracy = tf.cast(accuracy, dtype=tf.float32)
        mask = tf.cast(mask, dtype=tf.float32)
        return tf.reduce_sum(accuracy) / tf.reduce_sum(mask)

    def _compute_caption_loss_and_acc(self, img_embed, batch_seq, training=True):
        encoder_out = self.encoder(img_embed, training=training)
        batch_seq_inp = batch_seq[:, :-1]
        batch_seq_true = batch_seq[:, 1:]
        mask = tf.math.not_equal(batch_seq_true, 0)
        batch_seq_pred = self.decoder(
            batch_seq_inp, encoder_out, training=training, mask=mask
        )
        loss = self.calculate_loss(batch_seq_true, batch_seq_pred, mask)
        acc = self.calculate_accuracy(batch_seq_true, batch_seq_pred, mask)
        return loss, acc

    def train_step(self, batch_data):
        batch_img, batch_seq = batch_data
        batch_loss = 0
        batch_acc = 0

        if self.image_aug:
            batch_img = self.image_aug(batch_img)

        # 1. Get image embeddings
        img_embed = self.cnn_model(batch_img)

        # 2. Pass each of the five captions one by one to the decoder
        # along with the encoder outputs and compute the loss as well as accuracy
        # for each caption.
        for i in range(self.num_captions_per_image):
            with tf.GradientTape() as tape:
                loss, acc = self._compute_caption_loss_and_acc(
                    img_embed, batch_seq[:, i, :], training=True
                )

                # 3. Update loss and accuracy
                batch_loss += loss
                batch_acc += acc

            # 4. Get the list of all the trainable weights
            train_vars = (
                self.encoder.trainable_variables + self.decoder.trainable_variables
            )

            # 5. Get the gradients
            grads = tape.gradient(loss, train_vars)

            # 6. Update the trainable weights
            self.optimizer.apply_gradients(zip(grads, train_vars))

        # 7. Update the trackers
        batch_acc /= float(self.num_captions_per_image)
        self.loss_tracker.update_state(batch_loss)
        self.acc_tracker.update_state(batch_acc)

        # 8. Return the loss and accuracy values
        return {
            "loss": self.loss_tracker.result(),
            "acc": self.acc_tracker.result(),
        }

    def test_step(self, batch_data):
        batch_img, batch_seq = batch_data
        batch_loss = 0
        batch_acc = 0

        # 1. Get image embeddings
        img_embed = self.cnn_model(batch_img)

        # 2. Pass each of the five captions one by one to the decoder
        # along with the encoder outputs and compute the loss as well as accuracy
        # for each caption.
        for i in range(self.num_captions_per_image):
            loss, acc = self._compute_caption_loss_and_acc(
                img_embed, batch_seq[:, i, :], training=False
            )

            # 3. Update batch loss and batch accuracy
            batch_loss += loss
            batch_acc += acc

        batch_acc /= float(self.num_captions_per_image)

        # 4. Update the trackers
        self.loss_tracker.update_state(batch_loss)
        self.acc_tracker.update_state(batch_acc)

        # 5. Return the loss and accuracy values
        return {
            "loss": self.loss_tracker.result(),
            "acc": self.acc_tracker.result(),
        }

    @property
    def metrics(self):
        # We need to list our metrics here so `reset_states()` can be
        # called automatically.
        return [self.loss_tracker, self.acc_tracker]


strip_chars = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
strip_chars = strip_chars.replace("<", "")
strip_chars = strip_chars.replace(">", "")


def custom_standardization(input_string):
    lowercase = tf.strings.lower(input_string)
    return tf.strings.regex_replace(lowercase, "[%s]" % re.escape(strip_chars), "")


vectorization = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode="int",
    output_sequence_length=SEQ_LENGTH,
    standardize=custom_standardization,
)


def generate_caption(caption_model, img):
    """Generate a caption for a preprocessed image batch using `caption_model`.

    `img` is expected to be a batch of one image of shape (1, 299, 299, 3),
    i.e. an image decoded and resized to IMAGE_SIZE with the batch dimension
    expanded (the original TODO left the image source unresolved).
    """
    # Pass the image to the CNN
    img = caption_model.cnn_model(img)

    # Pass the image features to the Transformer encoder
    encoded_img = caption_model.encoder(img, training=False)

    # Generate the caption using the Transformer decoder
    decoded_caption = "<start> "
    vocab = vectorization.get_vocabulary()
    index_lookup = dict(zip(range(len(vocab)), vocab))
    max_decoded_sentence_length = SEQ_LENGTH - 1
    for i in range(max_decoded_sentence_length):
        tokenized_caption = vectorization([decoded_caption])[:, :-1]
        mask = tf.math.not_equal(tokenized_caption, 0)
        predictions = caption_model.decoder(
            tokenized_caption, encoded_img, training=False, mask=mask
        )
        sampled_token_index = np.argmax(predictions[0, i, :])
        sampled_token = index_lookup[sampled_token_index]
        if sampled_token == "<end>":
            break
        decoded_caption += " " + sampled_token

    decoded_caption = decoded_caption.replace("<start> ", "")
    decoded_caption = decoded_caption.replace(" <end>", "").strip()
    print("Predicted Caption: ", decoded_caption)
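This module only defines the building blocks, so below is a minimal assembly sketch that is not part of this push. The attention head counts are assumptions, and the checkpoint prefix is the committed .index path without its extension; with the TF checkpoint format, restoration is deferred until the layers are built on a first call, and the vectorization layer must still be adapted on the training captions before generate_caption can look up tokens.

# Assemble the captioning model from the components above (head counts are assumed).
cnn_model = get_cnn_model()
encoder = TransformerEncoderBlock(embed_dim=EMBED_DIM, dense_dim=FF_DIM, num_heads=1)
decoder = TransformerDecoderBlock(embed_dim=EMBED_DIM, ff_dim=FF_DIM, num_heads=2)
caption_model = ImageCaptioningModel(
    cnn_model=cnn_model,
    encoder=encoder,
    decoder=decoder,
    image_aug=image_augmentation,
)

# Restore the committed weights; the prefix is the .index file path minus the extension.
caption_model.load_weights(
    "models/cnn_transformer/tf_keras_image_captioning_cnn+transformer_flicker8k"
)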
utils/helpers.py
ADDED
@@ -0,0 +1,192 @@
import base64
import io
import json
import os
from typing import Any, Dict, List

import pandas as pd
import requests
import streamlit as st
from PIL import Image
import google.generativeai as palm
from pypdf import PdfReader
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    SentenceTransformersTokenTextSplitter,
)
import chromadb
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction


# API key (set this in your Streamlit secrets)
api_key = st.secrets["PALM_API_KEY"]
palm.configure(api_key=api_key)


# Function to convert the image to bytes for download
def convert_image_to_bytes(image):
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return buffered.getvalue()


# Function to resize the image
def resize_image(image):
    return image.resize((512, int(image.height * 512 / image.width)))


# Function to convert the image to base64
def convert_image_to_base64(image):
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode()


# Function to make an API call to PaLM
def call_palm(prompt: str) -> str:
    completion = palm.generate_text(
        model="models/text-bison-001",
        prompt=prompt,
        temperature=0,
        max_output_tokens=800,
    )

    return completion.result


# Function to make an API call to Google's Gemini API
def call_gemini_api(image_base64, api_key=api_key, prompt="What is this picture?"):
    headers = {
        "Content-Type": "application/json",
    }
    data = {
        "contents": [
            {
                "parts": [
                    {"text": prompt},
                    {"inline_data": {"mime_type": "image/jpeg", "data": image_base64}},
                ]
            }
        ]
    }
    response = requests.post(
        f"https://generativelanguage.googleapis.com/v1beta/models/gemini-pro-vision:generateContent?key={api_key}",
        headers=headers,
        json=data,
    )
    return response.json()


def safely_get_text(response):
    """Return the response if it can be evaluated; otherwise return None."""
    try:
        return response
    except Exception as e:
        print(f"An error occurred: {e}")

    # Return None or a default value if the path does not exist
    return None


def post_request_and_parse_response(
    url: str, payload: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Sends a POST request to the specified URL with the given payload,
    then parses the byte response to a dictionary.

    Args:
        url (str): The URL to which the POST request is sent.
        payload (Dict[str, Any]): The payload to send in the POST request.

    Returns:
        Dict[str, Any]: The parsed dictionary from the response.
    """
    # Set headers for the POST request
    headers = {"Content-Type": "application/json"}

    # Send the POST request and get the response
    response = requests.post(url, json=payload, headers=headers)

    # Extract the byte data from the response
    byte_data = response.content

    # Decode the byte data to a string
    decoded_string = byte_data.decode("utf-8")

    # Convert the JSON string to a dictionary
    dict_data = json.loads(decoded_string)

    return dict_data


def extract_line_items(input_data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Extracts items with "BlockType": "LINE" from the provided JSON data.

    Args:
        input_data (Dict[str, Any]): The input JSON data as a dictionary.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries with the extracted data.
    """
    # Initialize an empty list to hold the extracted line items
    line_items: List[Dict[str, Any]] = []

    # Get the list of items from the 'body' key in the input data
    body_items = json.loads(input_data.get("body", "[]"))

    # Iterate through each item in the body
    for item in body_items:
        # Check if the BlockType of the item is 'LINE'
        if item.get("BlockType") == "LINE":
            # Add the item to the line_items list
            line_items.append(item)

    return line_items


def rag(query: str, retrieved_documents: list, api_key: str = api_key) -> str:
    """
    Processes a query and a list of retrieved documents with a foundation model
    (the PaLM text model via call_palm).

    Args:
        query (str): The user's query or question.
        retrieved_documents (list): A list of documents retrieved as relevant information to the query.
        api_key (str): API key for accessing the model API. Defaults to the predefined 'api_key'.

    Returns:
        str: The cleaned output from the model response.
    """
    # Combine the retrieved documents into a single string, separated by two newlines.
    information = "\n\n".join(retrieved_documents)

    # Format the query and combined information into a single message.
    messages = f"Question: {query}. \n Information: {information}"

    # Call the PaLM text model with the formatted message.
    gemini_output = call_palm(prompt=messages)

    # Placeholder for further processing. Currently, the raw output is returned as-is.
    cleaned_output = gemini_output  # ["candidates"][0]["content"]["parts"][0]["text"]

    return cleaned_output


def displayPDF(file: str) -> None:
    """
    Displays a PDF file in a Streamlit application.

    Parameters:
    - file (str): The path to the PDF file to be displayed.
    """

    # Opening the PDF file in binary read mode
    with open(file, "rb") as f:
        # Encoding the PDF file content to base64
        base64_pdf: str = base64.b64encode(f.read()).decode("utf-8")

    # Creating an HTML embed string for displaying the PDF
    pdf_display: str = (
        f'<embed src="data:application/pdf;base64,{base64_pdf}" '
        f'width="700" height="1000" type="application/pdf">'
    )

    # Using Streamlit to display the HTML embed string as unsafe HTML
    st.markdown(pdf_display, unsafe_allow_html=True)
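A hypothetical usage sketch of the retrieval-augmented helper above, not part of this push. Importing this module requires PALM_API_KEY in Streamlit secrets, since the key is read at import time, and the documents below are made up for illustration.

from utils.helpers import rag

# Pretend these were retrieved from the Chroma collection for the query.
docs = [
    "The quarterly report covers revenue and expenses.",
    "Revenue grew 12% year over year.",
]
print(rag(query="How fast did revenue grow?", retrieved_documents=docs))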