"""Streamlit app that fine-tunes an image model on a user's face photos.

Flow (each step is gated on the previous one via ``st.session_state``):

1. The user enters an email address and receives a one-time code by email.
2. After entering the correct code, the user uploads face images, then picks a
   preset theme or uploads custom theme images.
3. Submitting the training form shows a captcha; once it is solved, a training
   request is sent to the backend.

Uploaded images are zipped and written to S3. The backend does not have S3
credentials, so it is handed presigned URLs instead.
"""
import argparse
import base64
import logging
import os
import shutil
import tempfile
import uuid
import zipfile
from glob import glob
from io import BytesIO
from itertools import cycle
from typing import List, Tuple

import banana_dev as banana
import boto3
import requests
import sendgrid
import smart_open
import streamlit as st
from botocore.config import Config
from dotenv import load_dotenv
from PIL import Image
from sendgrid.helpers.mail import Content, Email, Mail, To
from st_btn_select import st_btn_select
from streamlit_image_select import image_select

logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Looks for .env file in current directory to pull environment variables. Should
# not overwrite already set environment variables. Used for S3 credentials.
load_dotenv()

_S3_PATH_OUTPUT = "s3://gretel-image-synthetics-use2/data/{identifier}/{image_type}_images.zip"

CAPTCHA_ENDPOINT = "https://captcha-api.akshit.me/v2/generate"
VERIFY_ENDPOINT = "https://captcha-api.akshit.me/v2/verify"

# Upper bound (seconds) on how long a captcha API call may block the page.
_CAPTCHA_TIMEOUT_SECONDS = 10

# Lifetime (seconds) of the presigned GET urls handed to the backend for the
# input image zips.
_INPUT_URL_EXPIRATION_SECONDS = 3600

# Preset themes offered in the picker, indexed by the picker's return value:
# (caption, preview image url, training-images zip url, superhero prompt).
_PRESET_THEMES = [
    (
        "Game of Thrones",
        "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/got.png",
        "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/data/game-of-thrones.zip",
        "game-of-thrones",
    ),
    (
        "Iron Man",
        "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/ironman.png",
        "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/data/iron-man.zip",
        "iron-man",
    ),
    (
        "Thor",
        "https://ichef.bbci.co.uk/images/ic/640x360/p09t1hg0.jpg",
        "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/data/thor.zip",
        "thor",
    ),
]

# Command-line arguments to control some stuff for easier local testing.
# Eventually may want to move everything into functions and have a
# if __name__ == "__main__" setup instead of everything inline.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--dry-run",
    action="store_true",
    help="Skip sending train request to backend server.",
)
parser.add_argument(
    "--train-endpoint-url",
    default=None,
    help="URL of backend server to send train request to. If None, use hardcoded banana setup.",
)
cli_args = parser.parse_args()


def _init_session_state() -> None:
    """Seed ``st.session_state`` with defaults for every key not yet present.

    Keys that already exist are left untouched so state persists across reruns.
    """
    defaults = {
        # Random per-session id: emailed as the one-time code and also used as
        # the S3 prefix for this session's uploads and outputs.
        "key": uuid.uuid4().hex,
        "captcha_bool": False,
        "model_inputs": None,
        "s3_face_file_path": None,
        "s3_theme_file_path": None,
        "view": False,
        "train_view": False,
        "captcha_response": None,
        "captcha": {},
        "login": None,
        "user_auth_sess": False,
        # Whether the email form is still shown; cleared once a code is sent.
        "email_provided": True,
    }
    for name, value in defaults.items():
        if name not in st.session_state:
            st.session_state[name] = value


def callback():
    """Record that a button was clicked by setting ``button_clicked``."""
    st.session_state["button_clicked"] = True


def bucket_parts(s3_path: str) -> Tuple[str, str]:
    """Split an S3 path into bucket and key.

    Args:
        s3_path: path starting with "s3:"

    Returns:
        Tuple of bucket and key for the path
    """
    parts = s3_path.split("/")
    bucket = parts[2]
    key = "/".join(parts[3:])
    return bucket, key


def _s3_client():
    """Return an S3 client that signs with SigV4 using path-style addressing."""
    return boto3.client(
        "s3",
        config=Config(signature_version="s3v4", s3={"addressing_style": "path"}),
    )


def _generate_presigned_url(client_method: str, s3_path: str, expiration_seconds: int) -> str:
    """Presign ``client_method`` ("get_object"/"put_object") for ``s3_path``."""
    bucket, key = bucket_parts(s3_path)
    return _s3_client().generate_presigned_url(
        client_method,
        Params={"Bucket": bucket, "Key": key},
        ExpiresIn=expiration_seconds,
    )


def generate_s3_get_url(s3_path: str, expiration_seconds: int) -> str:
    """Generate a presigned S3 url to read from an S3 path.

    A presigned url allows anyone accessing that url to read the s3 path
    without needing s3 credentials until the url expires.

    Args:
        s3_path: path starting with "s3:"
        expiration_seconds: how long the url will be valid (does not influence
            lifetime of the underlying s3 object, only the presigned url)

    Returns:
        The presigned url
    """
    return _generate_presigned_url("get_object", s3_path, expiration_seconds)


def generate_s3_put_url(s3_path: str, expiration_seconds: int) -> str:
    """Generate a presigned S3 url to write to an S3 path.

    A presigned url allows anyone accessing that url to write to the s3 path
    without needing s3 credentials until the url expires.

    Args:
        s3_path: path starting with "s3:"
        expiration_seconds: how long the url will be valid (does not influence
            lifetime of the underlying s3 object, only the presigned url)

    Returns:
        The presigned url
    """
    return _generate_presigned_url("put_object", s3_path, expiration_seconds)


def zip_and_upload_images(identifier: str, uploaded_files: List[str], image_type: str) -> str:
    """Save images as zip file to s3 for use in backend.

    Blocks until images are processed, added to zip file, and uploaded to S3.
    Every image is converted to RGB and stored as ``{index}_test.png`` at the
    root of the zip.

    Args:
        identifier: unique identifier for the run, used in s3 link
        uploaded_files: images to include; anything ``PIL.Image.open``
            accepts (file paths or file-like objects such as Streamlit uploads)
        image_type: string to identify different batches of images used in
            the backend model/training. Currently used values: "face", "theme"

    Returns:
        S3 location of zip file containing png images.
    """
    s3_path = _S3_PATH_OUTPUT.format(identifier=identifier, image_type=image_type)

    # Use a fresh scratch directory per call. Writing every batch into one
    # shared `identifier/` folder let a theme zip pick up leftover face images,
    # and re-uploads kept stale files from earlier attempts.
    with tempfile.TemporaryDirectory(prefix=f"{identifier}_{image_type}_") as workdir:
        image_dir = os.path.join(workdir, "images")
        os.makedirs(image_dir)

        logger.info("Processing uploaded images")
        for num, uploaded_file in enumerate(uploaded_files):
            with Image.open(uploaded_file) as image:
                image.convert("RGB").save(os.path.join(image_dir, f"{num}_test.png"))

        logger.info("Making zip archive")
        local_zip_filename = shutil.make_archive(
            os.path.join(workdir, f"{identifier}_{image_type}_images"), "zip", image_dir
        )

        logger.info("Uploading zip file to s3")
        # TODO: can we define expiration when making the s3 path?
        # Probably if we use the boto3 library instead of smart open
        with open(local_zip_filename, "rb") as fin, smart_open.open(s3_path, "wb") as fout:
            # Stream instead of reading the whole archive into memory.
            shutil.copyfileobj(fin, fout)

    logger.info(f"Completed upload to {s3_path}")
    return s3_path


def send_email(to_email, user_code):
    """Email ``user_code`` to ``to_email`` through SendGrid.

    Reads the API key from the ``SENDGRID_API_KEY`` environment variable.
    Errors raised by the SendGrid client propagate to the caller.

    Returns:
        The SendGrid client response.
    """
    sg = sendgrid.SendGridAPIClient(api_key=os.environ.get("SENDGRID_API_KEY"))
    mail = Mail(
        Email("santhosh@gretel.ai"),
        To(to_email),
        "One Time Code",
        Content("text/plain", f"Here is your one-time code: {user_code}"),
    )
    response = sg.client.mail.send.post(request_body=mail.get())
    logger.info(f"SendGrid responded with status {response.status_code}")
    return response


def generate_captcha():
    """Request a new captcha from the captcha API.

    Returns:
        The API's JSON body on HTTP 200, otherwise
        ``{"error": "Failed to generate captcha"}`` (including on network
        errors and timeouts).
    """
    try:
        response = requests.get(CAPTCHA_ENDPOINT, timeout=_CAPTCHA_TIMEOUT_SECONDS)
    except requests.RequestException:
        logger.exception("Generate captcha request failed")
        return {"error": "Failed to generate captcha"}

    if response.status_code == 200:
        return response.json()

    # Log the raw body: error responses are not guaranteed to be JSON.
    logger.warning(
        f"Error from generate captcha request: {response.status_code} {response.text}"
    )
    return {"error": "Failed to generate captcha"}


def verify_captcha(captcha_id, captcha_response):
    """Ask the captcha API whether ``captcha_response`` solves ``captcha_id``.

    Returns:
        The API's JSON body on HTTP 200, otherwise
        ``{"error": "Failed to verify captcha"}`` (including on network errors
        and timeouts).
    """
    verify_json = {"uuid": captcha_id, "captcha": captcha_response}
    try:
        response = requests.post(
            VERIFY_ENDPOINT,
            json=verify_json,
            timeout=_CAPTCHA_TIMEOUT_SECONDS,
        )
    except requests.RequestException:
        logger.exception("Verify captcha request failed")
        return {"error": "Failed to verify captcha"}

    logger.info(f"Response from captcha verify: {response}")
    if response.status_code == 200:
        return response.json()
    return {"error": "Failed to verify captcha"}


def train_model(model_inputs):
    """Send a training request for ``model_inputs`` to the backend.

    With ``--dry-run`` the inputs are only logged. Otherwise they are posted to
    ``--train-endpoint-url`` when given, else sent to the banana.dev model.
    """
    if cli_args.dry_run:
        logger.info("Skipping model training since --dry-run is enabled.")
        logger.info(f"model_inputs: {model_inputs}")
        return

    if cli_args.train_endpoint_url is None:
        # Use banana backend.
        # NOTE(review): these credentials are committed to source. Prefer the
        # BANANA_API_KEY / BANANA_MODEL_KEY environment variables and rotate
        # the hard-coded fallbacks.
        api_key = os.environ.get("BANANA_API_KEY", "03cdd72e-5c04-4207-bd6a-fd5712c1740e")
        model_key = os.environ.get("BANANA_MODEL_KEY", "1a3b4ce5-164f-4efb-9f4a-c2ad3a930d0b")
        # model_inputs carries presigned S3 urls (including a write url) and
        # the user's email, so only its keys are logged and nothing is shown
        # on the page.
        logger.info(f"Sending banana training request with inputs: {sorted(model_inputs)}")
        banana.run(api_key, model_key, model_inputs)
    else:
        # Send request directly to specified url
        response = requests.post(cli_args.train_endpoint_url, json=model_inputs)
        logger.info(f"Train endpoint responded with status {response.status_code}")


def _render_face_upload(identifier: str) -> None:
    """Render the face-image upload form and store the uploaded zip's S3 path."""
    face_images = st.empty()
    with face_images.form("my_form"):
        uploaded_files = st.file_uploader(
            "Choose image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
        )
        submitted = st.form_submit_button("Upload")
    if submitted:
        if not uploaded_files:
            st.error("Please choose at least one image to upload.")
            return
        with st.spinner("Uploading..."):
            st.session_state["s3_face_file_path"] = zip_and_upload_images(
                identifier, uploaded_files, "face"
            )
        st.success(f"Uploading {len(uploaded_files)} files done!")


def _render_preset_themes() -> None:
    """Render the preset theme picker.

    "Submit!" stores the model inputs for the selected preset theme; the second
    button sets ``view`` so the custom-theme form is shown.
    """
    preset_theme_images = st.empty()
    with preset_theme_images.form("choose-preset-theme"):
        img = image_select(
            "Choose a Theme!",
            images=[preview_url for _, preview_url, _, _ in _PRESET_THEMES],
            captions=[caption for caption, _, _, _ in _PRESET_THEMES],
            return_value="index",
        )
        col1, col2 = st.columns([0.17, 1])
        with col1:
            submitted_3 = st.form_submit_button("Submit!")
            if submitted_3:
                _, _, theme_zip_url, theme_prompt = _PRESET_THEMES[img]
                face_path = st.session_state["s3_face_file_path"]
                if face_path is None:
                    # Without this guard bucket_parts(None) raises AttributeError.
                    st.error("Please upload your face images first!")
                else:
                    with st.spinner():
                        st.session_state["model_inputs"] = {
                            "superhero_file_path": theme_zip_url,
                            # Use presigned url since backend does not have credentials
                            "person_file_path": generate_s3_get_url(
                                face_path, expiration_seconds=_INPUT_URL_EXPIRATION_SECONDS
                            ),
                            "superhero_prompt": theme_prompt,
                            "num_images": 50,
                        }
                    st.success("Success!")
        with col2:
            submitted_4 = st.form_submit_button(
                "If none of the themes interest you, click here!"
            )
            if submitted_4:
                st.session_state["view"] = True


def _render_custom_theme(identifier: str) -> None:
    """Render the custom-theme form: upload theme images and name the theme."""
    custom_theme_images = st.empty()
    with custom_theme_images.form("input_custom_themes"):
        st.markdown("If none of the themes interest you, please input your own!")
        uploaded_files_2 = st.file_uploader(
            "Choose image files",
            accept_multiple_files=True,
            type=["png", "jpg", "jpeg"],
        )
        title = st.text_input("Theme Name")
        submitted_3 = st.form_submit_button("Submit!")
        if submitted_3:
            face_path = st.session_state["s3_face_file_path"]
            # Validate before uploading so nothing is sent to S3 for a
            # request that cannot produce usable model inputs.
            if face_path is None:
                st.error("Please upload your face images first!")
            elif not uploaded_files_2:
                st.error("Please choose at least one theme image.")
            elif not title.strip():
                st.error("Please enter a theme name.")
            else:
                with st.spinner("Uploading..."):
                    theme_path = zip_and_upload_images(identifier, uploaded_files_2, "theme")
                    st.session_state["s3_theme_file_path"] = theme_path
                    st.session_state["model_inputs"] = {
                        # Use presigned urls since backend does not have credentials
                        "superhero_file_path": generate_s3_get_url(
                            theme_path, expiration_seconds=_INPUT_URL_EXPIRATION_SECONDS
                        ),
                        "person_file_path": generate_s3_get_url(
                            face_path, expiration_seconds=_INPUT_URL_EXPIRATION_SECONDS
                        ),
                        "superhero_prompt": title,
                        "num_images": 50,
                    }
                st.success("Done!")


def _render_training_form() -> str:
    """Render the training form; set ``captcha_bool`` on a valid submission.

    Returns:
        The email currently entered in the form (may be empty).
    """
    train = st.empty()
    with train.form("training"):
        # The middle column is an unused spacer.
        col1, _spacer, col2 = st.columns(3)
        with col1:
            email = st.text_input("Enter Email")
        with col2:
            submitted = st.form_submit_button("Train Model!")
            if submitted:
                if not email:
                    st.markdown("Please input an email!")
                elif st.session_state["model_inputs"] is None:
                    # Training needs theme + face inputs; without this the
                    # captcha step would fail on None["identifier"].
                    st.markdown("Please choose a theme first!")
                else:
                    st.session_state["captcha_bool"] = True
    return email


def _submit_training(email: str) -> None:
    """Attach run metadata and presigned output urls, then start training."""
    model_inputs = st.session_state["model_inputs"]
    model_inputs["identifier"] = st.session_state["key"]
    model_inputs["email"] = email
    s3_output_path = _S3_PATH_OUTPUT.format(
        identifier=st.session_state["key"], image_type="generated"
    )
    # The backend does not have s3 credentials, so generate presigned urls for
    # the backend to use to write and read the generated images.
    model_inputs["output_s3_url_get"] = generate_s3_get_url(
        s3_output_path,
        expiration_seconds=60 * 60 * 24,
    )
    # NOTE(review): the write url expires after one hour; confirm the backend
    # always uploads its results within that window.
    model_inputs["output_s3_url_put"] = generate_s3_put_url(
        s3_output_path,
        expiration_seconds=3600,
    )
    train_model(model_inputs)


def _render_captcha(email: str) -> None:
    """Render the captcha form; start training once the captcha is solved."""
    captcha_form = st.empty()
    with captcha_form.form("captcha_form", clear_on_submit=True):
        # Create container to create image/text input out of order from the
        # form submit button. Needed since we need to know the status of the
        # form submit to know what the captcha should do.
        captcha_container = st.container()
        display_captcha = True

        # TODO: Submit button renders first, then drops down once the image is
        # fetched leading to page reflow. Would be nice to not have reflow, but
        # we need to know if the submit button was previously pressed and if the
        # captcha was solved to generate and display a new captcha or not.
        # Possible solution is use an on_click callback to set a session_state
        # variable to access whether the button was pushed or not instead of the
        # return value here.
        submitted = st.form_submit_button("Submit Captcha!")
        if submitted:
            # .get(): no uuid is stored if the previous generate call failed.
            result = verify_captcha(
                st.session_state["captcha"].get("uuid"),
                st.session_state["captcha_response"],
            )
            st.session_state.pop("captcha_response", None)
            if result.get("message") == "CAPTCHA_SOLVED":
                st.session_state["captcha"] = {}
                display_captcha = False
                with st.spinner("Model Fine Tuning..."):
                    _submit_training(email)
                st.session_state["train_view"] = True
            else:
                # A 200 response for a wrong answer carries no "error" key.
                st.error(result.get("error", "Incorrect captcha, please try again."))

        if display_captcha:
            # Generate new captcha and display. Occurs on first load with the
            # captcha_bool=True, or after previously failed captcha attempts.
            result = generate_captcha()
            captcha_id = result.get("uuid")
            captcha_image = result.get("captcha")
            if captcha_id is None or captcha_image is None:
                st.session_state["captcha"] = {}
                captcha_container.error(result.get("error", "Failed to generate captcha"))
            else:
                st.session_state["captcha"]["uuid"] = captcha_id
                st.session_state["captcha"]["captcha"] = captcha_image
                captcha_container.image(captcha_image, width=300)
                captcha_container.text_input("Enter the captcha response", key="captcha_response")
        # Submit button already setup previously.


_init_session_state()

# Placeholders are created up front (they render nothing until filled) so the
# later steps can clear the earlier forms on any rerun.
user_email_input = st.empty()
user_auth = st.empty()

if st.session_state["email_provided"]:
    with user_email_input.form(key="user_auth"):
        email_address = st.text_input(label="Please Enter Your Email")
        email_submitted = st.form_submit_button(label="Submit")
    if email_submitted:
        email_address = email_address.strip()
        if not email_address:
            st.markdown("Please input an email!")
        else:
            try:
                send_email(email_address, str(st.session_state["key"]))
            except Exception:
                # UI boundary: report the failure instead of crashing the page,
                # and only advance once the code was actually sent.
                logger.exception("Failed to send one-time code email")
                st.error(
                    "Could not send the one-time code. "
                    "Please check the email address and try again."
                )
            else:
                st.session_state["user_auth_sess"] = True
                st.session_state["email_provided"] = False

if st.session_state["user_auth_sess"]:
    user_email_input.empty()
    with user_auth.form("one-code"):
        code_input = st.text_input(label="Please Input One Time Code")
        code_submitted = st.form_submit_button(label="Submit")
    if code_submitted:
        # Ignore surrounding whitespace picked up when pasting from the email.
        if code_input.strip() == st.session_state["key"]:
            st.session_state["login"] = True
        else:
            st.markdown("Please Enter Correct Code!")

if st.session_state["login"]:
    identifier = st.session_state["key"]
    user_auth.empty()
    user_email_input.empty()

    _render_face_upload(identifier)
    _render_preset_themes()
    if st.session_state["view"]:
        _render_custom_theme(identifier)
    training_email = _render_training_form()
    if st.session_state["captcha_bool"]:
        _render_captcha(training_email)