santhosh97's picture
Update app.py
3b015d2
raw
history blame
17.2 kB
import base64
import boto3
from botocore.config import Config
from dotenv import load_dotenv
import os
import shutil
from typing import List, Tuple
import uuid
import zipfile
import argparse
import logging
import sendgrid
from sendgrid.helpers.mail import Mail, Email, To, Content
from glob import glob
from io import BytesIO
from itertools import cycle
import requests
import banana_dev as banana
import streamlit as st
from PIL import Image
from st_btn_select import st_btn_select
from streamlit_image_select import image_select
import smart_open
# Install the default logging configuration, then use a module-level logger
# that emits INFO and above.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Looks for .env file in current directory to pull environment variables. Should
# not overwrite already set environment variables. Used for S3 credentials.
load_dotenv()
# Template for the S3 location of every zip archive this app reads or writes.
# It is formatted with the session identifier and an image batch type
# ("face", "theme" or "generated" in this file).
_S3_PATH_OUTPUT = "s3://gretel-image-synthetics-use2/data/{identifier}/{image_type}_images.zip"
# Command-line arguments to control some stuff for easier local testing.
# Eventually may want to move everything into functions and have an
# `if __name__ == "__main__":` setup instead of everything inline.
parser = argparse.ArgumentParser()
# --dry-run: train_model() only logs its inputs instead of contacting a backend.
parser.add_argument(
    "--dry-run", action="store_true",
    help="Skip sending train request to backend server.",
)
# --train-endpoint-url: when given, train_model() POSTs to this URL instead of
# going through banana.
parser.add_argument(
    "--train-endpoint-url", default=None,
    help="URL of backend server to send train request to. If None, use hardcoded banana setup.",
)
# Parsed once at import time; read by train_model().
cli_args = parser.parse_args()
# Seed st.session_state with defaults. Every default is written only when its
# key is absent, so values stored earlier in the session are kept.
if "key" not in st.session_state:
    # Random per-session identifier; the rest of the file also uses it as the
    # one-time login code and inside S3 paths.
    st.session_state["key"] = uuid.uuid4().hex

for _state_name, _initial_value in (
    ("captcha_bool", False),
    ("model_inputs", None),
    ("view", False),
    ("train_view", False),
    ("captcha_response", None),
    ("captcha", {}),
    ("login", None),
    ("user_auth_sess", False),
):
    if _state_name not in st.session_state:
        st.session_state[_state_name] = _initial_value

# The two S3 paths are initialised together, and only when neither exists yet.
if not (
    "s3_face_file_path" in st.session_state
    or "s3_theme_file_path" in st.session_state
):
    st.session_state["s3_face_file_path"] = None
    st.session_state["s3_theme_file_path"] = None

# NOTE(review): the test is on "user_email" but the key written is
# "email_provided". Nothing visible in this file stores "user_email", so this
# appears to reset "email_provided" to True on every run. The later UI code
# seems to rely on that (it uses `user_email_input` unconditionally), so the
# mismatch is kept as is -- confirm before changing.
if "user_email" not in st.session_state:
    st.session_state["email_provided"] = True
def callback():
    """Record in the session state that a button has been clicked."""
    session = st.session_state
    session["button_clicked"] = True
def bucket_parts(s3_path: str) -> Tuple[str, str]:
    """Split an S3 path into bucket and key.

    Args:
        s3_path: path of the form "s3://<bucket>/<key>"

    Returns:
        Tuple of bucket and key for the path. The key is "" when the path
        names only a bucket.

    Raises:
        ValueError: if s3_path is not a string starting with "s3://" or does
            not contain a bucket name.
    """
    scheme = "s3://"
    # Reject None and non-S3 strings up front with a readable message rather
    # than failing later on an index or attribute lookup.
    if not isinstance(s3_path, str) or s3_path[: len(scheme)].lower() != scheme:
        raise ValueError(
            f"Expected an S3 path starting with {scheme!r}, got {s3_path!r}"
        )
    # Everything up to the first "/" after the scheme is the bucket; the
    # remainder (possibly empty) is the key.
    bucket, _, key = s3_path[len(scheme):].partition("/")
    if not bucket:
        raise ValueError(f"S3 path {s3_path!r} does not contain a bucket name")
    return bucket, key
def _generate_presigned_url(client_method: str, s3_path: str, expiration_seconds: int) -> str:
    """Presign the given S3 client method for the object at ``s3_path``.

    Args:
        client_method: boto3 S3 client method to presign, e.g. "get_object"
        s3_path: path starting with "s3:"
        expiration_seconds: how long the url will be valid
    Returns:
        The presigned url
    """
    bucket, key = bucket_parts(s3_path)
    s3_client = boto3.client(
        "s3",
        config=Config(signature_version="s3v4", s3={"addressing_style": "path"}),
    )
    return s3_client.generate_presigned_url(
        client_method,
        Params={"Bucket": bucket, "Key": key},
        ExpiresIn=expiration_seconds,
    )


def generate_s3_get_url(s3_path: str, expiration_seconds: int) -> str:
    """Generate a presigned S3 url to read from an S3 path.

    A presigned url allows anyone accessing that url to read the s3 path without
    needing s3 credentials until the url expires.

    Args:
        s3_path: path starting with "s3:"
        expiration_seconds: how long the url will be valid (does not influence
            lifetime of the underlying s3 object, only the presigned url)
    Returns:
        The presigned url
    """
    return _generate_presigned_url("get_object", s3_path, expiration_seconds)


def generate_s3_put_url(s3_path: str, expiration_seconds: int) -> str:
    """Generate a presigned S3 url to write to an S3 path.

    A presigned url allows anyone accessing that url to write to the s3 path
    without needing s3 credentials until the url expires.

    Args:
        s3_path: path starting with "s3:"
        expiration_seconds: how long the url will be valid (does not influence
            lifetime of the underlying s3 object, only the presigned url)
    Returns:
        The presigned url
    """
    return _generate_presigned_url("put_object", s3_path, expiration_seconds)
def zip_and_upload_images(identifier: str, uploaded_files: List[str], image_type: str) -> str:
    """Save images as zip file to s3 for use in backend.

    Blocks until images are processed, added to zip file, and uploaded to S3.
    Each image is converted to RGB and stored as "<index>_test.png" inside the
    archive.

    Args:
        identifier: unique identifier for the run, used in s3 link
        uploaded_files: images to archive; each entry is handed to
            PIL.Image.open
        image_type: string to identify different batches of images used in the
            backend model/training. Currently used values: "face", "theme"
    Returns:
        S3 location of zip file containing png images.
    """
    # Imported here so this function does not depend on a change to the
    # file-level import block.
    import tempfile

    s3_path = _S3_PATH_OUTPUT.format(identifier=identifier, image_type=image_type)
    # A fresh scratch directory per call keeps images written for another batch
    # of the same identifier (e.g. "face" before "theme") out of this archive,
    # and removes the local copies once the upload has finished.
    with tempfile.TemporaryDirectory() as work_dir:
        image_dir = os.path.join(work_dir, "images")
        os.makedirs(image_dir)
        logger.info("Processing uploaded images")
        for num, uploaded_file in enumerate(uploaded_files):
            file_ = Image.open(uploaded_file).convert("RGB")
            file_.save(os.path.join(image_dir, f"{num}_test.png"))
        logger.info("Making zip archive")
        local_zip_filename = shutil.make_archive(
            os.path.join(work_dir, f"{identifier}_{image_type}_images"),
            "zip",
            image_dir,
        )
        logger.info("Uploading zip file to s3")
        # TODO: can we define expiration when making the s3 path?
        # Probably if we use the boto3 library instead of smart open
        with open(local_zip_filename, "rb") as fin:
            with smart_open.open(s3_path, "wb") as fout:
                # Stream in chunks instead of loading the whole zip in memory.
                shutil.copyfileobj(fin, fout)
    logger.info(f"Completed upload to {s3_path}")
    return s3_path
def send_email(to_email, user_code):
    """Email a one-time code through SendGrid.

    The API key is read from the SENDGRID_API_KEY environment variable.

    Args:
        to_email: address of the recipient
        user_code: code placed in the plain-text message body
    """
    client = sendgrid.SendGridAPIClient(api_key=os.environ.get('SENDGRID_API_KEY'))
    message = Mail(
        Email("santhosh@gretel.ai"),
        To(to_email),
        "One Time Code",
        Content("text/plain", f"Here is your one-time code: {user_code}"),
    )
    client.client.mail.send.post(request_body=message.get())
# Endpoints of the external captcha service: generate_captcha() sends a GET to
# the first, verify_captcha() sends a POST to the second.
CAPTCHA_ENDPOINT = "https://captcha-api.akshit.me/v2/generate"
VERIFY_ENDPOINT = "https://captcha-api.akshit.me/v2/verify"
# Create a function to generate a captcha
def generate_captcha():
    """Request a new captcha from the captcha service.

    Returns:
        The decoded JSON body of the response on HTTP 200; otherwise (non-200
        status or a failed request) {"error": "Failed to generate captcha"}.
    """
    failure = {"error": "Failed to generate captcha"}
    try:
        # Bound the wait so an unresponsive service cannot hang the page.
        response = requests.get(CAPTCHA_ENDPOINT, timeout=10)
    except requests.RequestException as exc:
        logger.warning(f"Generate captcha request failed: {exc}")
        return failure
    # If the request was successful, return the API response
    if response.status_code == 200:
        return response.json()
    # Log the raw text: an error body is not guaranteed to be valid JSON.
    logger.warning(f"Error from generate captcha request: {response.text}")
    # Otherwise, return an error message
    return failure
# Create a function to verify the captcha
def verify_captcha(captcha_id, captcha_response):
    """Ask the captcha service whether a response solves a captcha.

    Args:
        captcha_id: value sent as "uuid" in the request body
        captcha_response: the user's answer, sent as "captcha"
    Returns:
        The decoded JSON body of the response on HTTP 200; otherwise (non-200
        status or a failed request) {"error": "Failed to verify captcha"}.
    """
    failure = {"error": "Failed to verify captcha"}
    # Make a POST request to the API endpoint with the captcha ID and response
    verify_json = {"uuid": captcha_id, "captcha": captcha_response}
    try:
        # Bound the wait so an unresponsive service cannot hang the page.
        response = requests.post(VERIFY_ENDPOINT, json=verify_json, timeout=10)
    except requests.RequestException as exc:
        logger.warning(f"Verify captcha request failed: {exc}")
        return failure
    logger.info(f"Response from captcha verify: {response}")
    # If the request was successful, return the API response
    if response.status_code == 200:
        return response.json()
    # Otherwise, return an error message
    return failure
def train_model(model_inputs):
    """Send the training request for ``model_inputs`` to the backend.

    With --dry-run the inputs are only logged. Without --train-endpoint-url
    the request goes through banana; otherwise it is POSTed as JSON to the
    given URL.

    Args:
        model_inputs: payload forwarded unchanged to the backend
    """
    if cli_args.dry_run:
        logger.info("Skipping model training since --dry-run is enabled.")
        logger.info(f"model_inputs: {model_inputs}")
        return
    if cli_args.train_endpoint_url is None:
        # Use banana backend
        # NOTE(review): these credentials are committed in source. They remain
        # as fallbacks so existing deployments keep working, but they should be
        # rotated and supplied through BANANA_API_KEY / BANANA_MODEL_KEY.
        api_key = os.environ.get("BANANA_API_KEY", "03cdd72e-5c04-4207-bd6a-fd5712c1740e")
        model_key = os.environ.get("BANANA_MODEL_KEY", "1a3b4ce5-164f-4efb-9f4a-c2ad3a930d0b")
        st.markdown(str(model_inputs))
        banana.run(api_key, model_key, model_inputs)
    else:
        # Send request directly to specified url
        response = requests.post(cli_args.train_endpoint_url, json=model_inputs)
        # Surface the outcome in the logs instead of silently discarding it.
        logger.info(f"Train endpoint responded with status {response.status_code}")
# Login step 1: ask for an email address and send it the one-time code.
if st.session_state["email_provided"]:
# Placeholder container so the form can be cleared by later steps.
user_email_input = st.empty()
with user_email_input.form(key='user_auth'):
text_input = st.text_input(label='Please Enter Your Email')
submit_button = st.form_submit_button(label='Submit')
if submit_button:
# Switch to the one-time-code step and stop asking for the email.
st.session_state["user_auth_sess"] = True
st.session_state["email_provided"] = False
# The code that is emailed is the session key itself.
send_email(text_input, str(st.session_state["key"]))
# Login step 2: ask for the one-time code and compare it to the session key.
if st.session_state["user_auth_sess"]:
user_auth = st.empty()
# Clear the email form shown by step 1.
# NOTE(review): `user_email_input` is only assigned in the email-form step
# above -- confirm it is always defined when this line runs.
user_email_input.empty()
with user_auth.form("one-code"):
text_input = st.text_input(label='Please Input One Time Code')
submit_button = st.form_submit_button(label='Submit')
if submit_button:
if text_input == st.session_state["key"]:
st.session_state["login"] = True
else:
st.markdown("Please Enter Correct Code!")
# After a successful login: clear both login forms and show the face-image
# upload form.
if st.session_state["login"]:
# The session key doubles as the run identifier used in S3 paths.
identifier = st.session_state["key"]
user_auth.empty()
user_email_input.empty()
face_images = st.empty()
with face_images.form("my_form"):
uploaded_files = st.file_uploader(
"Choose image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
)
# NOTE(review): this f-string has no placeholders.
submitted = st.form_submit_button(f"Upload")
if submitted:
with st.spinner('Uploading...'):
# Zip the uploaded images, upload them, and remember the S3 location for
# the theme forms below.
st.session_state["s3_face_file_path"] = zip_and_upload_images(
identifier, uploaded_files, "face"
)
st.success(f'Uploading {len(uploaded_files)} files done!')
# Preset theme picker: choose one of three themes, or switch to the custom
# theme form.
preset_theme_images = st.empty()
with preset_theme_images.form("choose-preset-theme"):
# return_value="index" is requested, and `img` is used below as a key of
# `dictionary`.
img = image_select(
"Choose a Theme!",
images=[
"https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/got.png",
"https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/ironman.png",
"https://ichef.bbci.co.uk/images/ic/640x360/p09t1hg0.jpg",
],
captions=["Game of Thrones", "Iron Man", "Thor"],
return_value="index",
)
col1, col2 = st.columns([0.17, 1])
with col1:
submitted_3 = st.form_submit_button("Submit!")
if submitted_3:
with st.spinner():
# Theme index -> [zip URL of the theme images, prompt for the theme].
dictionary = {
0: [
"https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/data/game-of-thrones.zip",
"game-of-thrones",
],
1: ["https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/data/iron-man.zip", "iron-man"],
2: ["https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/data/thor.zip", "thor"],
}
# NOTE(review): "s3_face_file_path" is still None until the face upload has
# completed; generate_s3_get_url would then fail -- confirm this is guarded.
st.session_state["model_inputs"] = {
"superhero_file_path": dictionary[img][0],
# Use presigned url since backend does not have credentials
"person_file_path": generate_s3_get_url(st.session_state["s3_face_file_path"], expiration_seconds=3600),
"superhero_prompt": dictionary[img][1],
"num_images": 50,
}
st.success("Success!")
with col2:
submitted_4 = st.form_submit_button(
"If none of the themes interest you, click here!"
)
if submitted_4:
# Reveal the custom theme form on the next run.
st.session_state["view"] = True
# Custom theme form: upload theme images and name the theme.
if st.session_state["view"]:
custom_theme_images = st.empty()
with custom_theme_images.form("input_custom_themes"):
st.markdown("If none of the themes interest you, please input your own!")
uploaded_files_2 = st.file_uploader(
"Choose image files",
accept_multiple_files=True,
type=["png", "jpg", "jpeg"],
)
title = st.text_input("Theme Name")
submitted_3 = st.form_submit_button("Submit!")
if submitted_3:
with st.spinner('Uploading...'):
# Zip and upload the theme images under the same identifier as the face
# images, using the "theme" batch type.
st.session_state["s3_theme_file_path"] = zip_and_upload_images(
identifier, uploaded_files_2, "theme"
)
# NOTE(review): "s3_face_file_path" is still None until the face upload has
# completed -- confirm this is guarded.
st.session_state["model_inputs"] = {
# Use presigned urls since backend does not have credentials
"superhero_file_path": generate_s3_get_url(st.session_state["s3_theme_file_path"], expiration_seconds=3600),
"person_file_path": generate_s3_get_url(st.session_state["s3_face_file_path"], expiration_seconds=3600),
"superhero_prompt": title,
"num_images": 50,
}
st.success('Done!')
# Training form: collect an email address, then enable the captcha step.
train = st.empty()
with train.form("training"):
# col3 is not used below.
col1, col3, col2 = st.columns(3)
with col1:
# This address is added to the model inputs in the captcha step.
email = st.text_input("Enter Email")
with col2:
submitted = st.form_submit_button("Train Model!")
if submitted:
if not email:
st.markdown('Please input an email!')
else:
st.session_state["captcha_bool"] = True
# Captcha step: verify the submitted answer and, once solved, send the
# training request; otherwise fetch and show a new captcha.
if st.session_state["captcha_bool"]:
captcha_form = st.empty()
with captcha_form.form("captcha_form", clear_on_submit=True):
# Create container to create image/text input out of order from the
# format submit button. Needed since we need to know the status of the
# form submit to know what the captcha should do.
captcha_container = st.container()
display_captcha = True
# TODO: Submit button renders first, then drops down once the image is
# fetched leading to page reflow. Would be nice to not have reflow, but
# we need to know if the submit button was previously pressed and if the
# captcha was solved to generate and display a new captcha or not.
# Possible solution is use an on_click callback to set a session_state
# variable to access whether the button was pushed or not instead of the
# return value here.
submitted = st.form_submit_button("Submit Captcha!")
if submitted:
# Check the answer against the captcha stored when it was displayed.
result = verify_captcha(st.session_state['captcha']['uuid'], st.session_state["captcha_response"])
del st.session_state["captcha_response"]
if 'message' in result and result['message'] == 'CAPTCHA_SOLVED':
st.session_state['captcha'] = {}
display_captcha = False
with st.spinner("Model Fine Tuning..."):
# NOTE(review): "model_inputs" is None until one of the theme forms has
# been submitted; the assignments below would then fail -- confirm.
st.session_state["model_inputs"]["identifier"] = st.session_state["key"]
st.session_state["model_inputs"]["email"] = email
s3_output_path = _S3_PATH_OUTPUT.format(identifier=st.session_state["key"], image_type="generated")
# The backend does not have s3 credentials, so generate
# presigned urls for the backend to use to write and read
# the generated images.
st.session_state["model_inputs"]["output_s3_url_get"] = generate_s3_get_url(
s3_output_path, expiration_seconds=60 * 60 * 24,
)
st.session_state["model_inputs"]["output_s3_url_put"] = generate_s3_put_url(
s3_output_path, expiration_seconds=3600,
)
train_model(st.session_state["model_inputs"])
st.session_state["train_view"] = True
else:
# NOTE(review): verify_captcha only guarantees an 'error' entry on its own
# failure path; a 200 response without CAPTCHA_SOLVED may lack it -- confirm.
st.error(result['error'])
if display_captcha:
# Generate new captcha and display. Occurs on first load with the
# captcha_bool=True, or after previously failed captcha attempts.
# NOTE(review): when generate_captcha fails it returns only an 'error'
# entry, so the two lookups below would raise KeyError.
result = generate_captcha()
captcha_id = result['uuid']
captcha_image = result['captcha']
st.session_state['captcha']['uuid'] = captcha_id
st.session_state['captcha']['captcha'] = captcha_image
captcha_container.image(captcha_image, width=300)
# The answer is stored in session state under "captcha_response".
captcha_container.text_input("Enter the captcha response", key="captcha_response")
# Submit button already setup previously.