minDALLE / app.py
valhalla's picture
Update app.py
4d72a29
raw
history blame
No virus
2.86 kB
import base64
import html
import os
import time
from io import BytesIO
from multiprocessing import Process

import requests
import streamlit as st
from PIL import Image
def start_server():
    """Run the uvicorn model server (``server:app``) on port 8080.

    ``os.system`` blocks until the command terminates, so this call lasts as
    long as the server does. load_models runs it in a separate process.
    """
    command = "uvicorn server:app --port 8080 --host 0.0.0.0 --workers 2"
    os.system(command)
def load_models():
    """Ensure the model server is listening on port 8080.

    If nothing is listening yet, start_server is launched in a daemon child
    process and this call blocks, behind a spinner, until the port accepts
    connections. On success ``st.session_state["models_loaded"]`` is set to
    True so later reruns skip this step.

    Raises:
        RuntimeError: if the server process exits before the port opens
            (e.g. uvicorn cannot be launched or the server crashes on start-up).
    """
    if not is_port_in_use(8080):
        with st.spinner(text="Loading models, please wait..."):
            proc = Process(target=start_server, args=(), daemon=True)
            proc.start()
            # start_server blocks for the lifetime of uvicorn, so a child that
            # has already exited means the server is gone. Without this check
            # the loop would spin (and the spinner show) forever.
            while not is_port_in_use(8080):
                if not proc.is_alive():
                    raise RuntimeError(
                        "Model server process exited before it started "
                        "listening on port 8080; check the server logs."
                    )
                time.sleep(1)
            st.success("Model server started.")
    else:
        st.success("Model server already running...")
    st.session_state["models_loaded"] = True
def is_port_in_use(port, host="127.0.0.1"):
    """Return True if a TCP connection to ``(host, port)`` succeeds.

    Used to detect whether the local model server is already listening.

    Args:
        port: TCP port to probe.
        host: IPv4 address to connect to. Defaults to the loopback address.
            The wildcard ``"0.0.0.0"`` is only meaningful for binding and is
            not a valid connect destination on every platform (Windows
            rejects it). A server bound to 0.0.0.0 is reachable via loopback.

    Returns:
        bool: True when something accepts the connection, else False.
    """
    import socket

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # connect_ex returns an errno instead of raising; 0 means connected.
        return sock.connect_ex((host, port)) == 0
def generate(prompt):
    """Request images for ``prompt`` from the local model server.

    Sends ``GET /correct`` with the prompt as a query parameter. It then
    base64-decodes each entry of the ``"images"`` list in the JSON response.

    Args:
        prompt: Free-form text typed by the user.

    Returns:
        list of ``PIL.Image.Image`` objects, in the order the server sent them.

    Raises:
        requests.HTTPError: if the server responds with an error status.
    """
    # Let requests URL-encode the prompt. Interpolating it into the URL
    # corrupts the query when it contains "&", "#", "?" or "+".
    response = requests.get(
        "http://127.0.0.1:8080/correct", params={"prompt": prompt}
    )
    # Fail with a clear HTTP error instead of a KeyError/JSON decode error
    # on an error page.
    response.raise_for_status()
    encoded_images = response.json()["images"]
    return [Image.open(BytesIO(base64.b64decode(img))) for img in encoded_images]
# Page title.
st.header("minDALL-E")
st.subheader("Generate images from text")

# Streamlit re-executes this script on every interaction, but st.session_state
# survives those reruns. Seed the flag once, then keep calling load_models
# (which sets it to True) only while it is still False.
if "models_loaded" not in st.session_state:
    st.session_state["models_loaded"] = False
if not st.session_state["models_loaded"]:
    load_models()
# Prompt input and result grid.
prompt = st.text_input("What do you want to see?")
DEBUG = False
if prompt:
    # The loading banner is rendered with unsafe_allow_html=True, so the
    # user-typed prompt must be HTML-escaped before interpolation. Otherwise
    # any markup in it (tags, scripts, a stray "<") is injected into the page.
    safe_prompt = html.escape(prompt)
    container = st.empty()
    container.markdown(
        f"""
<style> p {{ margin:0 }} div {{ margin:0 }} </style>
<div data-stale="false" class="element-container css-1e5imcs e1tzin5v1">
<div class="stAlert">
<div role="alert" data-baseweb="notification" class="st-ae st-af st-ag st-ah st-ai st-aj st-ak st-g3 st-am st-b8 st-ao st-ap st-aq st-ar st-as st-at st-au st-av st-aw st-ax st-ay st-az st-b9 st-b1 st-b2 st-b3 st-b4 st-b5 st-b6">
<div class="st-b7">
<div class="css-whx05o e13vu3m50">
<div data-testid="stMarkdownContainer" class="css-1ekf893 e16nr0p30">
<img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/app/streamlit/img/loading.gif" width="30"/>
Generating predictions for: <b>{safe_prompt}</b>
</div>
</div>
</div>
</div>
</div>
</div>
<small><i>Predictions may take up to 40s under high load. Please stand by.</i></small>
""",
        unsafe_allow_html=True,
    )
    print(f"Getting selections: {prompt}")
    selected = generate(prompt)

    margin = 0.1  # for better position of zoom in arrow
    n_columns = 3
    # Widths alternate image/spacer, e.g. [1, margin, 1, margin, 1] for three
    # columns, so images are placed only in the even-indexed columns.
    cols = st.columns([1] + [margin, 1] * (n_columns - 1))
    for i, img in enumerate(selected):
        cols[(i % n_columns) * 2].image(img)
    # Replace the loading banner with the prompt as a caption.
    container.markdown(f"**{prompt}**")

    st.button("Again!", key="again_button")