File size: 3,150 Bytes
4d72a29
0ea4415
4d72a29
 
 
b442155
 
0ea4415
 
4d72a29
091b9da
 
4d72a29
 
091b9da
 
4d72a29
 
 
 
 
 
 
 
 
 
 
0ea4415
 
4d72a29
 
0ea4415
4d72a29
 
0ea4415
 
4d72a29
 
 
 
 
 
0ea4415
b442155
4d72a29
 
0ea4415
 
 
a200d93
405665c
b442155
4d72a29
 
 
b442155
0ea4415
 
749bafa
0ea4415
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d72a29
0ea4415
4d72a29
0ea4415
 
 
 
 
 
 
749bafa
c55fdff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import base64
import html
import os
import time
from io import BytesIO
from multiprocessing import Process

import requests
import streamlit as st
from PIL import Image


def start_server(port=8080, workers=2):
    """Run the model API server (``server:app``) under uvicorn.

    Blocks until the uvicorn process exits, which is why ``load_models``
    runs it inside a separate daemon ``multiprocessing.Process``.

    Args:
        port: TCP port uvicorn listens on (bound on all interfaces).
        workers: Number of uvicorn worker processes.
    """
    import subprocess

    # Pass an argument list rather than a shell string: no intermediate shell
    # is spawned and the values are never subject to shell parsing.
    subprocess.run(
        [
            "uvicorn",
            "server:app",
            "--port",
            str(port),
            "--host",
            "0.0.0.0",
            "--workers",
            str(workers),
        ],
        check=False,
    )


def load_models():
    """Ensure the model server is running, starting it in the background if needed.

    If nothing accepts connections on port 8080, ``start_server`` is launched
    in a daemon process. The function then polls once per second until the
    port opens. If that process exits before the port ever opens, an error is
    shown and the script run is stopped rather than polling forever.

    On success, sets ``st.session_state["models_loaded"]`` so later reruns in
    the same session skip this check.
    """
    if not is_port_in_use(8080):
        with st.spinner(text="Loading models, please wait..."):
            proc = Process(target=start_server, args=(), daemon=True)
            proc.start()
            while not is_port_in_use(8080):
                # start_server blocks for as long as uvicorn runs. A dead child
                # therefore means the server exited and the port will never open.
                if not proc.is_alive():
                    st.error("Model server failed to start.")
                    st.stop()
                time.sleep(1)
            st.success("Model server started.")
    else:
        st.success("Model server already running...")
    st.session_state["models_loaded"] = True


def is_port_in_use(port, host="127.0.0.1"):
    """Return True if a TCP connection to ``(host, port)`` succeeds.

    Used as a readiness probe for the local model server.

    Args:
        port: TCP port to probe.
        host: Address to connect to. Defaults to the IPv4 loopback address.
            ``0.0.0.0`` is a bind-only wildcard and is not a valid connect
            destination on every platform (Windows rejects it).

    Returns:
        True if something is accepting connections on the port, else False.
    """
    import socket

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # Bound the probe so an unresponsive address cannot stall the caller.
        s.settimeout(1.0)
        return s.connect_ex((host, port)) == 0


def generate(prompt):
    """Request images for *prompt* from the local model server.

    Calls ``GET /correct`` on port 8080 (where ``load_models`` starts the
    server). It then decodes each base64 string in the JSON response's
    ``"images"`` list into a PIL image.

    Args:
        prompt: Free-text prompt. It is sent as a query parameter so that
            characters such as ``&``, ``#``, ``?`` or spaces are URL-encoded
            instead of corrupting the query string.

    Returns:
        A list of ``PIL.Image.Image`` objects, in the order the server
        returned them.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    response = requests.get(
        "http://127.0.0.1:8080/correct", params={"prompt": prompt}
    )
    # Surface the HTTP status instead of an opaque JSON/KeyError further down.
    response.raise_for_status()
    encoded_images = response.json()["images"]
    return [Image.open(BytesIO(base64.b64decode(img))) for img in encoded_images]


# Per-session flag; load_models() flips it to True once the server is reachable.
st.session_state.setdefault("models_loaded", False)


st.header("minDALL-E")
st.write(
    "Generate images from text: Interactive demo for "
    "[minDALL-E](https://github.com/kakaobrain/minDALL-E)"
)

# Start (or detect) the model server only once per session.
if not st.session_state["models_loaded"]:
    load_models()

prompt = st.text_input("What do you want to see?")

DEBUG = False
# UI code taken from https://huggingface.co/spaces/flax-community/dalle-mini/blob/main/app/streamlit/app.py
if prompt != "":
    # The prompt is user input rendered with unsafe_allow_html=True. Escape it
    # so characters such as "<" or "&" are shown literally, not parsed as markup.
    safe_prompt = html.escape(prompt)
    # Placeholder: shows a loading banner now, replaced by a caption below.
    container = st.empty()
    container.markdown(
        f"""
        <style> p {{ margin:0 }} div {{ margin:0 }} </style>
        <div data-stale="false" class="element-container css-1e5imcs e1tzin5v1">
        <div class="stAlert">
        <div role="alert" data-baseweb="notification" class="st-ae st-af st-ag st-ah st-ai st-aj st-ak st-g3 st-am st-b8 st-ao st-ap st-aq st-ar st-as st-at st-au st-av st-aw st-ax st-ay st-az st-b9 st-b1 st-b2 st-b3 st-b4 st-b5 st-b6">
        <div class="st-b7">
        <div class="css-whx05o e13vu3m50">
        <div data-testid="stMarkdownContainer" class="css-1ekf893 e16nr0p30">
                <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/app/streamlit/img/loading.gif" width="30"/>
                Generating predictions for: <b>{safe_prompt}</b>
        </div>
        </div>
        </div>
        </div>
        </div>
        </div>
    """,
        unsafe_allow_html=True,
    )

    print(f"Getting selections: {prompt}")
    try:
        selected = generate(prompt)
    except requests.RequestException as err:
        # Replace the loading banner so the page doesn't claim work is ongoing.
        container.error(f"Image generation failed: {err}")
        st.stop()

    margin = 0.1  # for better position of zoom in arrow
    n_columns = 3
    # Image columns (weight 1) alternate with narrow spacer columns (weight
    # `margin`), so images go only into the even-indexed columns.
    cols = st.columns([1] + [margin, 1] * (n_columns - 1))
    for i, img in enumerate(selected):
        cols[(i % n_columns) * 2].image(img)
    # Swap the loading banner for the prompt as a caption.
    container.markdown(f"**{prompt}**")

    # The return value is unused; interacting with the button triggers a
    # Streamlit rerun, which regenerates images for the same prompt.
    st.button("Again!", key="again_button")

    st.write("<b><i>UI credits: <a href='https://huggingface.co/spaces/flax-community/dalle-mini'>DALL-E mini Space</a></i></b>", unsafe_allow_html=True)