Spaces:
Runtime error
Runtime error
import argparse
import base64
import hmac
import logging
import os
import shutil
import tempfile
import uuid
import zipfile
from glob import glob
from io import BytesIO
from itertools import cycle
from typing import List, Tuple

import banana_dev as banana
import boto3
import requests
import sendgrid
import smart_open
import streamlit as st
from botocore.config import Config
from dotenv import load_dotenv
from PIL import Image
from sendgrid.helpers.mail import Content, Email, Mail, To
from st_btn_select import st_btn_select
from streamlit_image_select import image_select
# Route this module's INFO-and-above messages through a default root handler.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel("INFO")

# Load variables from a .env file into the process environment (used for S3
# credentials). By default load_dotenv leaves already-set variables untouched.
load_dotenv()

# S3 location template for a zip of images belonging to one session
# (`identifier`) and one batch kind (`image_type`, e.g. "face", "theme",
# "generated").
_S3_PATH_OUTPUT = (
    "s3://gretel-image-synthetics-use2/data/"
    "{identifier}/{image_type}_images.zip"
)
# Command-line flags that make local testing easier.
# TODO: move the script body into functions behind an
# `if __name__ == "__main__":` guard instead of running everything inline.
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the parser for this app's command-line flags."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Skip sending train request to backend server.",
    )
    arg_parser.add_argument(
        "--train-endpoint-url",
        default=None,
        help="URL of backend server to send train request to. If None, use hardcoded banana setup.",
    )
    return arg_parser


parser = _build_arg_parser()
cli_args = parser.parse_args()
# Seed per-session state on the first script run; later reruns keep the values.

# Random per-session id. It is emailed as the one-time login code and used as
# the S3 path identifier for this session's uploads.
if "key" not in st.session_state:
    st.session_state["key"] = uuid.uuid4().hex

for _state_name, _initial_value in (
    ("captcha_bool", False),
    ("model_inputs", None),
):
    if _state_name not in st.session_state:
        st.session_state[_state_name] = _initial_value

# Both S3 paths are seeded together, and only when neither exists yet.
if not ("s3_face_file_path" in st.session_state or "s3_theme_file_path" in st.session_state):
    st.session_state["s3_face_file_path"] = st.session_state["s3_theme_file_path"] = None

for _state_name, _initial_value in (
    ("view", False),
    ("train_view", False),
    ("captcha_response", None),
    ("captcha", {}),
    ("login", None),
    ("user_auth_sess", False),
):
    if _state_name not in st.session_state:
        st.session_state[_state_name] = _initial_value

# NOTE(review): this checks "user_email" but writes "email_provided". Nothing
# ever sets "user_email", so "email_provided" is reset to True on every rerun.
# The email/one-time-code forms rely on that (the email-form placeholder must be
# created on every rerun before it is emptied), so fix both places together.
if "user_email" not in st.session_state:
    st.session_state["email_provided"] = True
def callback():
    """Set the ``button_clicked`` flag in session state.

    Shaped like a widget ``on_click`` callback; nothing in this view registers it.
    """
    st.session_state.button_clicked = True
def bucket_parts(s3_path: str) -> Tuple[str, str]:
    """Break an ``s3://bucket/key`` path into its bucket and object key.

    Args:
        s3_path: path starting with "s3:"

    Returns:
        ``(bucket, key)``. The key is everything after the bucket segment, and is
        "" when the path names only a bucket.
    """
    # At most three splits: ["s3:", "", bucket, everything-after-bucket].
    pieces = s3_path.split("/", 3)
    key = pieces[3] if len(pieces) > 3 else ""
    return pieces[2], key
def generate_s3_get_url(s3_path: str, expiration_seconds: int) -> str:
    """Return a presigned HTTPS URL that grants temporary read access to an S3 object.

    Whoever holds the URL can download the object without S3 credentials until
    it expires.

    Args:
        s3_path: path starting with "s3:"
        expiration_seconds: validity window of the URL itself; the S3 object's
            own lifetime is unaffected.

    Returns:
        The presigned URL.
    """
    bucket, key = bucket_parts(s3_path)
    client_config = Config(signature_version="s3v4", s3={"addressing_style": "path"})
    s3_client = boto3.client("s3", config=client_config)
    return s3_client.generate_presigned_url(
        ClientMethod="get_object",
        Params=dict(Bucket=bucket, Key=key),
        ExpiresIn=expiration_seconds,
    )
def generate_s3_put_url(s3_path: str, expiration_seconds: int) -> str:
    """Return a presigned HTTPS URL that grants temporary write access to an S3 path.

    Whoever holds the URL can upload to that path without S3 credentials until
    it expires.

    Args:
        s3_path: path starting with "s3:"
        expiration_seconds: validity window of the URL itself; the S3 object's
            own lifetime is unaffected.

    Returns:
        The presigned URL.
    """
    bucket, key = bucket_parts(s3_path)
    client_config = Config(signature_version="s3v4", s3={"addressing_style": "path"})
    s3_client = boto3.client("s3", config=client_config)
    return s3_client.generate_presigned_url(
        ClientMethod="put_object",
        Params=dict(Bucket=bucket, Key=key),
        ExpiresIn=expiration_seconds,
    )
def zip_and_upload_images(identifier: str, uploaded_files: List, image_type: str) -> str:
    """Convert images to PNG, zip them, and upload the zip to S3 for the backend.

    Blocks until the images are converted, archived, and uploaded.

    Each call works in its own temporary directory, which is removed afterwards.
    A "theme" batch therefore cannot pick up leftover "face" images from an
    earlier call with the same identifier, and no files are left in the
    working directory.

    Args:
        identifier: unique identifier for the run, used in the S3 path.
        uploaded_files: images to include. Each item may be anything
            ``PIL.Image.open`` accepts, such as a file path or a binary
            file-like object like a Streamlit upload.
        image_type: label for this batch of images, used in the S3 path.
            Currently used values: "face", "theme".

    Returns:
        S3 location of the zip file containing the PNG images.
    """
    s3_path = _S3_PATH_OUTPUT.format(identifier=identifier, image_type=image_type)
    with tempfile.TemporaryDirectory() as work_dir:
        image_dir = os.path.join(work_dir, "images")
        os.makedirs(image_dir)

        logger.info("Processing uploaded images")
        for num, uploaded_file in enumerate(uploaded_files):
            # Close the decoded image promptly. PIL does not close file objects it did not open.
            with Image.open(uploaded_file) as image:
                image.convert("RGB").save(os.path.join(image_dir, f"{num}_test.png"))

        logger.info("Making zip archive")
        # The archive lives beside (not inside) image_dir, so it cannot include itself.
        archive_path = shutil.make_archive(
            os.path.join(work_dir, f"{identifier}_{image_type}_images"), "zip", image_dir
        )

        logger.info("Uploading zip file to s3")
        # TODO: can we define expiration when making the s3 path?
        # Probably if we use the boto3 library instead of smart open.
        with open(archive_path, "rb") as fin, smart_open.open(s3_path, "wb") as fout:
            shutil.copyfileobj(fin, fout)
    logger.info(f"Completed upload to {s3_path}")
    return s3_path
def send_email(to_email, user_code, from_email="santhosh@gretel.ai"):
    """Email a one-time login code through SendGrid.

    Args:
        to_email: recipient address.
        user_code: the one-time code placed in the plain-text body.
        from_email: sender address (defaults to the previously hardcoded sender).

    Returns:
        The SendGrid API response, which callers are free to ignore.
    """
    api_key = os.environ.get('SENDGRID_API_KEY')
    if not api_key:
        logger.warning("SENDGRID_API_KEY is not set; the email request will likely be rejected.")
    sg = sendgrid.SendGridAPIClient(api_key=api_key)
    mail = Mail(
        Email(from_email),
        To(to_email),
        "One Time Code",
        Content("text/plain", f"Here is your one-time code: {user_code}"),
    )
    response = sg.client.mail.send.post(request_body=mail.get())
    # Deliberately do not log the code itself.
    logger.info(f"SendGrid responded with status {response.status_code}")
    return response
# Third-party captcha service: one endpoint issues captchas, the other checks answers.
_CAPTCHA_API_BASE = "https://captcha-api.akshit.me/v2"
CAPTCHA_ENDPOINT = f"{_CAPTCHA_API_BASE}/generate"
VERIFY_ENDPOINT = f"{_CAPTCHA_API_BASE}/verify"
def generate_captcha(timeout: float = 10.0):
    """Request a new captcha from the captcha API.

    Args:
        timeout: seconds to wait for the API before giving up. Without a
            timeout, ``requests`` can block indefinitely.

    Returns:
        The decoded JSON body on HTTP 200. The UI reads its "uuid" and
        "captcha" fields. On any failure it returns
        ``{"error": "Failed to generate captcha"}``.
    """
    failure = {"error": "Failed to generate captcha"}
    try:
        response = requests.get(CAPTCHA_ENDPOINT, timeout=timeout)
    except requests.RequestException as err:
        logger.warning(f"Generate captcha request failed: {err}")
        return failure
    if response.status_code != 200:
        # An error body is not guaranteed to be JSON, so log the raw text.
        logger.warning(f"Error from generate captcha request: {response.text}")
        return failure
    try:
        return response.json()
    except ValueError:
        logger.warning(f"Generate captcha returned non-JSON body: {response.text}")
        return failure
def verify_captcha(captcha_id, captcha_response, timeout: float = 10.0):
    """Ask the captcha API whether ``captcha_response`` solves captcha ``captcha_id``.

    Args:
        captcha_id: the "uuid" returned when the captcha was generated.
        captcha_response: the user's typed answer.
        timeout: seconds to wait for the API before giving up.

    Returns:
        The decoded JSON body on HTTP 200. The UI treats
        ``{"message": "CAPTCHA_SOLVED"}`` as success. On any failure it
        returns ``{"error": "Failed to verify captcha"}``.
    """
    failure = {"error": "Failed to verify captcha"}
    verify_json = {"uuid": captcha_id, "captcha": captcha_response}
    try:
        response = requests.post(VERIFY_ENDPOINT, json=verify_json, timeout=timeout)
    except requests.RequestException as err:
        logger.warning(f"Verify captcha request failed: {err}")
        return failure
    logger.info(f"Response from captcha verify: {response}")
    if response.status_code != 200:
        return failure
    try:
        return response.json()
    except ValueError:
        logger.warning(f"Verify captcha returned non-JSON body: {response.text}")
        return failure
def train_model(model_inputs):
    """Send a fine-tuning request for ``model_inputs`` to the training backend.

    The backend is selected by the command-line flags:
      * ``--dry-run``: log the inputs and send nothing.
      * ``--train-endpoint-url URL``: POST ``model_inputs`` as JSON to URL.
      * otherwise: call the banana backend via ``banana.run``.

    Args:
        model_inputs: JSON-serialisable dict of training inputs (S3 URLs,
            prompt, number of images, identifier, email).

    Returns:
        None for a dry run, the ``requests.Response`` for a custom endpoint,
        or whatever ``banana.run`` returns. Callers may ignore it.
    """
    if cli_args.dry_run:
        logger.info("Skipping model training since --dry-run is enabled.")
        logger.info(f"model_inputs: {model_inputs}")
        return None

    if cli_args.train_endpoint_url is not None:
        # Send request directly to specified url
        response = requests.post(cli_args.train_endpoint_url, json=model_inputs)
        logger.info(
            f"Train request to {cli_args.train_endpoint_url} returned HTTP {response.status_code}"
        )
        return response

    # Use banana backend.
    # SECURITY(review): the fallback credentials below are committed in source.
    # Rotate them and supply BANANA_API_KEY / BANANA_MODEL_KEY via the environment.
    api_key = os.environ.get("BANANA_API_KEY", "03cdd72e-5c04-4207-bd6a-fd5712c1740e")
    model_key = os.environ.get("BANANA_MODEL_KEY", "1a3b4ce5-164f-4efb-9f4a-c2ad3a930d0b")
    # Do not echo model_inputs onto the page: it holds the user's email and a
    # presigned S3 upload URL.
    logger.info("Sending train request to banana backend")
    return banana.run(api_key, model_key, model_inputs)
# Step 0a: ask for an email address and send the one-time login code to it.
if st.session_state["email_provided"]:
    user_email_input = st.empty()
    with user_email_input.form(key='user_auth'):
        email_entry = st.text_input(label='Please Enter Your Email')
        email_submitted = st.form_submit_button(label='Submit')
    if email_submitted:
        login_email = email_entry.strip()
        if "@" not in login_email:
            st.error("Please enter a valid email address.")
        else:
            # Send first, so a SendGrid failure leaves the state unchanged and the user can retry.
            send_email(login_email, str(st.session_state["key"]))
            st.session_state["user_auth_sess"] = True
            st.session_state["email_provided"] = False
# Step 0b: check the one-time code that was emailed to the user.
if st.session_state["user_auth_sess"]:
    user_auth = st.empty()
    # NOTE: user_email_input is the email-form placeholder created earlier in
    # this same script run; hide it once a code has been sent.
    user_email_input.empty()
    with user_auth.form("one-code"):
        code_entry = st.text_input(label='Please Input One Time Code')
        code_submitted = st.form_submit_button(label='Submit')
    if code_submitted:
        # compare_digest gives a constant-time comparison for the secret code.
        # Compare bytes so that non-ASCII input cannot raise TypeError.
        expected_code = str(st.session_state["key"]).encode()
        if hmac.compare_digest(code_entry.strip().encode(), expected_code):
            st.session_state["login"] = True
        else:
            st.markdown("Please Enter Correct Code!")
# Main flow for a logged-in user: upload face images, choose or upload a theme,
# then pass a captcha to launch fine-tuning.
if st.session_state["login"]:
    identifier = st.session_state["key"]
    # Logged in: hide the email and one-time-code forms.
    user_auth.empty()
    user_email_input.empty()

    # --- Step 1: upload photos of the user's face ---------------------------
    face_images = st.empty()
    with face_images.form("my_form"):
        uploaded_files = st.file_uploader(
            "Choose image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
        )
        submitted = st.form_submit_button("Upload")
    if submitted:
        if not uploaded_files:
            st.error("Please choose at least one image to upload.")
        else:
            with st.spinner('Uploading...'):
                st.session_state["s3_face_file_path"] = zip_and_upload_images(
                    identifier, uploaded_files, "face"
                )
            st.success(f'Uploading {len(uploaded_files)} files done!')

    # --- Step 2a: pick one of the preset themes -----------------------------
    preset_theme_images = st.empty()
    with preset_theme_images.form("choose-preset-theme"):
        img = image_select(
            "Choose a Theme!",
            images=[
                "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/got.png",
                "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/ironman.png",
                "https://ichef.bbci.co.uk/images/ic/640x360/p09t1hg0.jpg",
            ],
            captions=["Game of Thrones", "Iron Man", "Thor"],
            return_value="index",
        )
        col1, col2 = st.columns([0.17, 1])
        with col1:
            submitted_3 = st.form_submit_button("Submit!")
        with col2:
            submitted_4 = st.form_submit_button(
                "If none of the themes interest you, click here!"
            )
    if submitted_3:
        # The model inputs need the face images' S3 path, so require them first.
        if st.session_state["s3_face_file_path"] is None:
            st.error("Please upload your face images first.")
        else:
            # image_select index -> (zip of theme training images, theme prompt).
            preset_themes = {
                0: (
                    "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/data/game-of-thrones.zip",
                    "game-of-thrones",
                ),
                1: ("https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/data/iron-man.zip", "iron-man"),
                2: ("https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/data/thor.zip", "thor"),
            }
            theme_zip_url, theme_prompt = preset_themes[img]
            with st.spinner():
                st.session_state["model_inputs"] = {
                    "superhero_file_path": theme_zip_url,
                    # Use presigned url since backend does not have credentials
                    "person_file_path": generate_s3_get_url(
                        st.session_state["s3_face_file_path"], expiration_seconds=3600
                    ),
                    "superhero_prompt": theme_prompt,
                    "num_images": 50,
                }
            st.success("Success!")
    if submitted_4:
        st.session_state["view"] = True

    # --- Step 2b: upload a custom theme instead -----------------------------
    if st.session_state["view"]:
        custom_theme_images = st.empty()
        with custom_theme_images.form("input_custom_themes"):
            st.markdown("If none of the themes interest you, please input your own!")
            uploaded_files_2 = st.file_uploader(
                "Choose image files",
                accept_multiple_files=True,
                type=["png", "jpg", "jpeg"],
            )
            title = st.text_input("Theme Name")
            submitted_custom_theme = st.form_submit_button("Submit!")
        if submitted_custom_theme:
            if st.session_state["s3_face_file_path"] is None:
                st.error("Please upload your face images first.")
            elif not uploaded_files_2 or not title:
                st.error("Please choose theme images and enter a theme name.")
            else:
                with st.spinner('Uploading...'):
                    st.session_state["s3_theme_file_path"] = zip_and_upload_images(
                        identifier, uploaded_files_2, "theme"
                    )
                    st.session_state["model_inputs"] = {
                        # Use presigned urls since backend does not have credentials
                        "superhero_file_path": generate_s3_get_url(
                            st.session_state["s3_theme_file_path"], expiration_seconds=3600
                        ),
                        "person_file_path": generate_s3_get_url(
                            st.session_state["s3_face_file_path"], expiration_seconds=3600
                        ),
                        "superhero_prompt": title,
                        "num_images": 50,
                    }
                st.success('Done!')

    # --- Step 3: request training (gated by a captcha below) ----------------
    train = st.empty()
    with train.form("training"):
        # The middle column is only a spacer.
        col1, col3, col2 = st.columns(3)
        with col1:
            email = st.text_input("Enter Email")
        with col2:
            submitted = st.form_submit_button("Train Model!")
    if submitted:
        if not email:
            st.markdown('Please input an email!')
        elif st.session_state["model_inputs"] is None:
            st.error("Please choose or upload a theme before training.")
        else:
            st.session_state["captcha_bool"] = True

    # --- Step 4: captcha, then launch fine-tuning ---------------------------
    if st.session_state["captcha_bool"]:
        captcha_form = st.empty()
        with captcha_form.form("captcha_form", clear_on_submit=True):
            # The container is created before the submit button so the captcha
            # image and answer box render above it. It is only filled in later,
            # once we know whether this run was a submission and whether it
            # solved the captcha.
            captcha_container = st.container()
            display_captcha = True
            # TODO: Submit button renders first, then drops down once the image is
            # fetched leading to page reflow. Possible fix: use an on_click
            # callback that sets a session_state flag instead of relying on the
            # button's return value here.
            submitted = st.form_submit_button("Submit Captcha!")
            if submitted:
                captcha_id = st.session_state["captcha"].get("uuid")
                captcha_answer = st.session_state.get("captcha_response")
                # Drop the old answer so the answer box is recreated empty below.
                if "captcha_response" in st.session_state:
                    del st.session_state["captcha_response"]
                if captcha_id is None:
                    # No captcha was issued (e.g. generation failed); a new one is shown below.
                    st.error("No captcha to verify, please solve the new one.")
                else:
                    result = verify_captcha(captcha_id, captcha_answer)
                    if result.get("message") == "CAPTCHA_SOLVED":
                        st.session_state['captcha'] = {}
                        display_captcha = False
                        with st.spinner("Model Fine Tuning..."):
                            model_inputs = st.session_state["model_inputs"]
                            model_inputs["identifier"] = identifier
                            model_inputs["email"] = email
                            s3_output_path = _S3_PATH_OUTPUT.format(
                                identifier=identifier, image_type="generated"
                            )
                            # The backend does not have s3 credentials, so generate
                            # presigned urls for the backend to use to write and read
                            # the generated images.
                            model_inputs["output_s3_url_get"] = generate_s3_get_url(
                                s3_output_path, expiration_seconds=60 * 60 * 24,
                            )
                            model_inputs["output_s3_url_put"] = generate_s3_put_url(
                                s3_output_path, expiration_seconds=3600,
                            )
                            train_model(model_inputs)
                        st.session_state["train_view"] = True
                    else:
                        # A 200 response for a wrong answer may carry no "error" key.
                        st.error(
                            result.get("error")
                            or result.get("message")
                            or "Incorrect captcha, please try again."
                        )
            if display_captcha:
                # Show a fresh captcha: on the first run with captcha_bool=True,
                # or after a failed attempt.
                new_captcha = generate_captcha()
                if "uuid" in new_captcha and "captcha" in new_captcha:
                    st.session_state['captcha']['uuid'] = new_captcha['uuid']
                    st.session_state['captcha']['captcha'] = new_captcha['captcha']
                    captcha_container.image(new_captcha['captcha'], width=300)
                    captcha_container.text_input("Enter the captcha response", key="captcha_response")
                else:
                    # Forget any stale captcha so it cannot be "verified" next time.
                    st.session_state['captcha'] = {}
                    captcha_container.error(new_captcha.get("error", "Failed to generate captcha"))