Spaces:
Runtime error
Runtime error
import base64 | |
import os | |
import shutil | |
import uuid | |
import zipfile | |
from argparse import Namespace | |
from glob import glob | |
from io import BytesIO | |
from itertools import cycle | |
import banana_dev as banana | |
import streamlit as st | |
from PIL import Image | |
from st_btn_select import st_btn_select | |
from streamlit_image_select import image_select | |
# Per-session defaults. Each key is initialised independently: the previous
# combined `"a" not in state and "b" not in state` check left a key missing
# (and later lookups raising KeyError) whenever only one of them existed.
if "key" not in st.session_state:
    # Random per-session id; it becomes the local working directory and the
    # S3 key prefix for this session's uploads.
    st.session_state["key"] = uuid.uuid4().hex
_SESSION_DEFAULTS = {
    "model_inputs": None,
    "s3_face_file_path": None,
    "s3_theme_file_path": None,
    "view": False,  # True once the user asks for the custom-theme form
}
for _state_name, _state_default in _SESSION_DEFAULTS.items():
    if _state_name not in st.session_state:
        st.session_state[_state_name] = _state_default
def callback():
    """Flag in session state that the associated button was clicked.

    NOTE(review): nothing in this file registers this function (no
    ``on_click=callback``), and ``button_clicked`` is never read here;
    presumably intended as a widget callback — confirm or remove.
    """
    st.session_state.button_clicked = True
def zip_and_upload_images(identifier, uploaded_files, image_type):
    """Convert uploaded images to RGB PNGs, zip them and copy the zip to S3.

    Args:
        identifier: Per-session id, used as the local working directory and
            as the S3 key prefix.
        uploaded_files: Iterable of file-like objects that ``PIL.Image.open``
            accepts (in this app, the list returned by ``st.file_uploader``).
        image_type: Label such as ``"face"`` or ``"theme"``; keeps the staging
            directory, archive name and S3 key of each call separate.

    Returns:
        The ``s3://`` URI of the uploaded archive.

    Raises:
        subprocess.CalledProcessError: if ``aws s3 cp`` exits non-zero.
        FileNotFoundError: if the ``aws`` CLI is not installed.
    """
    import subprocess

    # Stage each image type in its own, freshly emptied directory. A single
    # shared directory let theme images overwrite face images with the same
    # file names, and leftovers from earlier calls ended up in later zips.
    staging_dir = os.path.join(identifier, image_type)
    shutil.rmtree(staging_dir, ignore_errors=True)
    os.makedirs(staging_dir)
    for num, uploaded_file in enumerate(uploaded_files):
        with Image.open(uploaded_file) as image:
            image.convert("RGB").save(os.path.join(staging_dir, f"{num}_test.png"))

    archive_path = shutil.make_archive(
        f"{identifier}_{image_type}_images", "zip", staging_dir
    )
    s3_uri = f"s3://gretel-image-synthetics/data/{identifier}/{image_type}_images.zip"
    # Argument list rather than a shell string (no shell interpolation), and
    # check=True so a failed upload raises instead of silently returning a URI
    # that points at nothing.
    subprocess.run(
        ["aws", "s3", "cp", archive_path, s3_uri, "--no-sign-request"],
        check=True,
    )
    return s3_uri
def train_model(model_inputs):
    """Handle the "Train Model!" action for the collected ``model_inputs``.

    The remote Banana call is disabled in this version, so the function only
    renders ``model_inputs`` on the page.

    Args:
        model_inputs: Dict with ``superhero_file_path``, ``person_file_path``,
            ``superhero_prompt`` and ``num_images`` (as built by the theme
            forms in this file), or ``None`` when no theme was submitted yet.
    """
    if (
        not model_inputs
        or not model_inputs.get("person_file_path")
        or not model_inputs.get("superhero_file_path")
    ):
        # Rendering "None" (or a dict with missing paths) gives the user no
        # actionable feedback.
        st.warning("Please upload your face images and choose a theme before training.")
        return
    st.markdown(str(model_inputs))
    # NOTE(review): the Banana API key and model key used to be hard-coded in
    # this function. Treat those values as leaked and rotate them. When the
    # call below is re-enabled, read the credentials from the environment:
    #     api_key = os.environ["BANANA_API_KEY"]
    #     model_key = os.environ["BANANA_MODEL_KEY"]
    #     out = banana.run(api_key, model_key, model_inputs)
    #     os.makedirs("generated", exist_ok=True)
    #     for num, img in enumerate(out["modelOutputs"][0]["image_base64"]):
    #         image_bytes = BytesIO(base64.b64decode(img.encode("utf-8")))
    #         Image.open(image_bytes).save(f"generated/{num}_output.jpg")
# Per-session id; passed to zip_and_upload_images as the working directory
# and S3 prefix.
identifier = st.session_state["key"]

# Form 1: upload the user's own face photos.
face_images = st.empty()
with face_images.form("my_form"):
    uploaded_files = st.file_uploader(
        "Choose image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
    )
    submitted = st.form_submit_button("Submit")
    if submitted:
        if uploaded_files:
            st.session_state["s3_face_file_path"] = zip_and_upload_images(
                identifier, uploaded_files, "face"
            )
        else:
            # Don't upload an empty archive and record it as the face data set.
            st.warning("Please choose at least one image before submitting.")
# Preset themes as (caption, preview image URL, S3 training archive, prompt).
_PRESET_THEMES = [
    (
        "Game of Thrones",
        "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/got.png",
        "s3://gretel-image-synthetics/data/game-of-thrones.zip",
        "game-of-thrones",
    ),
    (
        "Iron Man",
        "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/ironman.png",
        "s3://gretel-image-synthetics/data/iron-man.zip",
        "iron-man",
    ),
    (
        "Thor",
        "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/thor.png",
        "s3://gretel-image-synthetics/data/thor.zip",
        "thor",
    ),
]

# Form 2: pick a preset theme, or ask for the custom-theme form instead.
preset_theme_images = st.empty()
with preset_theme_images.form("choose-preset-theme"):
    img = image_select(
        "Choose a Theme!",
        images=[preview_url for _, preview_url, _, _ in _PRESET_THEMES],
        captions=[caption for caption, _, _, _ in _PRESET_THEMES],
        return_value="index",
    )
    col1, col2 = st.columns([0.15, 1])
    with col1:
        submitted_3 = st.form_submit_button("Submit!")
        if submitted_3:
            # `img` is the index of the chosen preview (return_value="index").
            _, _, theme_archive, theme_prompt = _PRESET_THEMES[img]
            st.session_state["model_inputs"] = {
                "superhero_file_path": theme_archive,
                "person_file_path": st.session_state["s3_face_file_path"],
                "superhero_prompt": theme_prompt,
                "num_images": 50,
            }
    with col2:
        submitted_4 = st.form_submit_button(
            "If none of the themes interest you, click here!"
        )
        if submitted_4:
            # Reveals the custom-theme form on the next rerun.
            st.session_state["view"] = True
# Form 3 (shown only after the "none of the themes" button): upload a custom
# theme image set together with its name.
if st.session_state["view"]:
    custom_theme_images = st.empty()
    with custom_theme_images.form("input_custom_themes"):
        st.markdown("If none of the themes interest you, please input your own!")
        uploaded_files_2 = st.file_uploader(
            "Choose image files",
            accept_multiple_files=True,
            type=["png", "jpg", "jpeg"],
        )
        title = st.text_input("Theme Name")
        submitted_custom = st.form_submit_button("Submit!")
        if submitted_custom:
            theme_name = title.strip()
            if not uploaded_files_2 or not theme_name:
                # Reject incomplete input rather than uploading an empty
                # archive or training with an empty prompt.
                st.warning("Please choose at least one image and enter a theme name.")
            else:
                st.session_state["s3_theme_file_path"] = zip_and_upload_images(
                    identifier, uploaded_files_2, "theme"
                )
                st.session_state["model_inputs"] = {
                    "superhero_file_path": st.session_state["s3_theme_file_path"],
                    "person_file_path": st.session_state["s3_face_file_path"],
                    "superhero_prompt": theme_name,
                    "num_images": 50,
                }
# Form 4: start training with whatever inputs have been collected so far.
train = st.empty()
with train.form("training"):
    submitted = st.form_submit_button("Train Model!")
    if submitted:
        model_inputs = st.session_state["model_inputs"]
        if model_inputs is not None:
            # The theme forms copy s3_face_file_path when they are submitted,
            # so faces uploaded after choosing a theme would otherwise be
            # missing (None). Re-read the current value at train time.
            model_inputs = {
                **model_inputs,
                "person_file_path": st.session_state["s3_face_file_path"],
            }
        train_model(model_inputs)