Santhosh Subramanian
files
097c210
raw
history blame
5.13 kB
import base64
import os
import shutil
import subprocess
import uuid
import zipfile
from argparse import Namespace
from glob import glob
from io import BytesIO
from itertools import cycle

import banana_dev as banana
import streamlit as st
from PIL import Image
from st_btn_select import st_btn_select
from streamlit_image_select import image_select
# Seed per-session defaults on the first script run. Each key is checked on
# its own: the previous combined `and` condition skipped BOTH S3 path keys if
# only one of them was missing.
_SESSION_DEFAULTS = {
    "model_inputs": None,
    "s3_face_file_path": None,
    "s3_theme_file_path": None,
    "view": False,
}
# The session identifier is generated lazily so a new UUID is only drawn once.
if "key" not in st.session_state:
    st.session_state["key"] = uuid.uuid4().hex
for _state_name, _state_default in _SESSION_DEFAULTS.items():
    if _state_name not in st.session_state:
        st.session_state[_state_name] = _state_default
def callback():
    """Set the ``button_clicked`` flag in Streamlit session state to True.

    NOTE(review): not referenced in the visible code; presumably meant to be
    passed as a widget ``on_click`` handler — confirm before relying on it.
    """
    st.session_state.button_clicked = True
def zip_and_upload_images(identifier, uploaded_files, image_type):
    """Convert uploaded images to PNG, zip them and copy the archive to S3.

    Args:
        identifier: Per-session identifier; used to name the local staging
            directory, the local archive and the S3 key prefix.
        uploaded_files: Iterable of uploaded image files (anything
            ``PIL.Image.open`` accepts). ``None`` is treated as empty.
        image_type: Label such as ``"face"`` or ``"theme"``; keeps archives
            of different kinds apart locally and on S3.

    Returns:
        The S3 URI the archive was uploaded to.

    Raises:
        subprocess.CalledProcessError: if ``aws s3 cp`` exits non-zero.
        FileNotFoundError: if the ``aws`` CLI is not installed.
    """
    # Stage each image type in its own sub-directory. Previously face and
    # theme uploads shared ``identifier/``, so the theme archive also
    # contained the face images (and vice versa).
    staging_dir = os.path.join(identifier, image_type)
    # Drop files from an earlier submission of the same type; otherwise a
    # smaller re-upload would leave stale images behind in the archive.
    shutil.rmtree(staging_dir, ignore_errors=True)
    os.makedirs(staging_dir, exist_ok=True)

    for num, uploaded_file in enumerate(uploaded_files or []):
        # Close the source image handle promptly instead of leaking it.
        with Image.open(uploaded_file) as image:
            image.convert("RGB").save(os.path.join(staging_dir, f"{num}_test.png"))

    archive_path = shutil.make_archive(
        f"{identifier}_{image_type}_images", "zip", staging_dir
    )
    s3_uri = f"s3://gretel-image-synthetics/data/{identifier}/{image_type}_images.zip"
    # Argument list (no shell) avoids quoting/injection issues, and
    # check=True surfaces a failed upload instead of returning a URI that
    # does not exist.
    subprocess.run(
        ["aws", "s3", "cp", archive_path, s3_uri, "--no-sign-request"],
        check=True,
    )
    return s3_uri
def train_model(model_inputs):
    """Start a training run for ``model_inputs``.

    The call to the Banana service is currently disabled, so this function
    only renders ``model_inputs`` on the page for inspection.

    Args:
        model_inputs: Dict with ``superhero_file_path``, ``person_file_path``,
            ``superhero_prompt`` and ``num_images`` keys, as built by the
            theme forms in this script, or ``None`` if no theme has been
            submitted yet.
    """
    if model_inputs is None:
        st.warning(
            "Nothing to train yet: upload your face images and choose a theme first."
        )
        return
    st.markdown(str(model_inputs))
    # TODO: re-enable the Banana call once training is wired up. The API and
    # model keys must be read from the environment (variable names below are
    # suggestions), never committed as literals; the keys previously
    # hard-coded here should be considered leaked and rotated.
    #   out = banana.run(
    #       os.environ["BANANA_API_KEY"],
    #       os.environ["BANANA_MODEL_KEY"],
    #       model_inputs,
    #   )
    #   os.makedirs("generated", exist_ok=True)
    #   for num, img in enumerate(out["modelOutputs"][0]["image_base64"]):
    #       image_bytes = BytesIO(base64.b64decode(img.encode("utf-8")))
    #       with Image.open(image_bytes) as image:
    #           # Save inside "generated" (the old draft created the
    #           # directory but wrote to the working directory instead).
    #           image.save(os.path.join("generated", f"{num}_output.jpg"))
# Per-session identifier; zip_and_upload_images uses it to name local staging
# files and the S3 key prefix.
identifier = st.session_state["key"]

# Form for uploading the user's face images.
face_images = st.empty()
with face_images.form("my_form"):
    uploaded_files = st.file_uploader(
        "Choose image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
    )
    submitted = st.form_submit_button("Submit")
if submitted:
    # Don't zip and upload an empty archive when no files were chosen.
    if uploaded_files:
        st.session_state["s3_face_file_path"] = zip_and_upload_images(
            identifier, uploaded_files, "face"
        )
    else:
        st.warning("Please choose at least one face image before submitting.")
# Preset themes as (preview image URL, caption, training-data S3 URI, prompt).
# Keeping each theme's fields in one tuple guarantees that the index returned
# by image_select maps to the matching caption, archive and prompt, instead of
# relying on three separately maintained lists staying in the same order.
PRESET_THEMES = [
    (
        "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/got.png",
        "Game of Thrones",
        "s3://gretel-image-synthetics/data/game-of-thrones.zip",
        "game-of-thrones",
    ),
    (
        "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/ironman.png",
        "Iron Man",
        "s3://gretel-image-synthetics/data/iron-man.zip",
        "iron-man",
    ),
    (
        "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/thor.png",
        "Thor",
        "s3://gretel-image-synthetics/data/thor.zip",
        "thor",
    ),
]

preset_theme_images = st.empty()
with preset_theme_images.form("choose-preset-theme"):
    img = image_select(
        "Choose a Theme!",
        images=[theme[0] for theme in PRESET_THEMES],
        captions=[theme[1] for theme in PRESET_THEMES],
        return_value="index",
    )
    col1, col2 = st.columns([0.15, 1])
    with col1:
        submitted_3 = st.form_submit_button("Submit!")
    with col2:
        submitted_4 = st.form_submit_button(
            "If none of the themes interest you, click here!"
        )

if submitted_3:
    _, _, theme_s3_path, theme_prompt = PRESET_THEMES[img]
    face_path = st.session_state.get("s3_face_file_path")
    if face_path is None:
        # model_inputs is a snapshot taken now; storing a None face path here
        # would send person_file_path=None to training even if faces are
        # uploaded later.
        st.warning("Please upload your face images before choosing a theme.")
    else:
        st.session_state["model_inputs"] = {
            "superhero_file_path": theme_s3_path,
            "person_file_path": face_path,
            "superhero_prompt": theme_prompt,
            "num_images": 50,
        }
if submitted_4:
    # Reveal the custom-theme form on this and subsequent reruns.
    st.session_state["view"] = True
# Custom-theme form, shown once the user opts out of the preset themes.
if st.session_state.get("view", False):
    custom_theme_images = st.empty()
    with custom_theme_images.form("input_custom_themes"):
        st.markdown("If none of the themes interest you, please input your own!")
        uploaded_files_2 = st.file_uploader(
            "Choose image files",
            accept_multiple_files=True,
            type=["png", "jpg", "jpeg"],
        )
        title = st.text_input("Theme Name")
        submitted_3 = st.form_submit_button("Submit!")
    if submitted_3:
        face_path = st.session_state.get("s3_face_file_path")
        if not uploaded_files_2 or not title.strip():
            # Avoid uploading an empty archive or training with a blank prompt.
            st.warning("Please upload theme images and enter a theme name.")
        elif face_path is None:
            # model_inputs is a snapshot; a None face path would reach training.
            st.warning("Please upload your face images before submitting a theme.")
        else:
            st.session_state["s3_theme_file_path"] = zip_and_upload_images(
                identifier, uploaded_files_2, "theme"
            )
            st.session_state["model_inputs"] = {
                "superhero_file_path": st.session_state["s3_theme_file_path"],
                "person_file_path": face_path,
                "superhero_prompt": title.strip(),
                "num_images": 50,
            }
# Form that kicks off training with whatever model_inputs the theme forms stored.
train = st.empty()
with train.form("training"):
    submitted = st.form_submit_button("Train Model!")
if submitted:
    # model_inputs stays None until a theme form has been submitted successfully.
    if st.session_state.get("model_inputs") is None:
        st.warning("Choose a theme (after uploading your face images) before training.")
    else:
        train_model(st.session_state["model_inputs"])