adamelliotfields committed
Commit • 3c7025e
1 Parent(s): 7c829e7
App

Files changed:
- 0_🏠_Home.py +49 -0
- app.py +10 -0
- lib/__init__.py +4 -0
- lib/config.py +22 -0
- lib/presets.py +33 -0
- logo.svg +5 -0
- pages/1_💬_Text_Generation.py +174 -0
- pages/2_🎨_Text_to_Image.py +212 -0
0_🏠_Home.py
ADDED
@@ -0,0 +1,49 @@
+import streamlit as st
+
+from lib import Config
+
+st.set_page_config(
+    page_title=Config.TITLE,
+    page_icon=Config.ICON,
+    layout=Config.LAYOUT,
+)
+
+# sidebar
+st.logo("logo.svg")
+
+# title
+st.html("""
+<style>
+    .pro-badge {
+        display: inline-block;
+        transform: skew(-12deg);
+        font-size: 0.875rem;
+        line-height: 1.25rem;
+        font-weight: 700;
+        padding: 0.125rem 0.625rem;
+        border-radius: 0.5rem;
+        color: rgb(0 0 0 / 1);
+        box-shadow: 0 0 #0000, 0 0 #0000, 0 10px 15px -3px rgb(16 185 129 / .1), 0 4px 6px -4px rgb(16 185 129 / .1);
+        background-image: linear-gradient(to bottom right, #f9a8d4, #a7f3d0, #fde68a);
+        border: 1px solid rgb(229 231 235 / 1);
+    }
+    @media (prefers-color-scheme: dark) {
+        .pro-badge {
+            box-shadow: 0 0 #0000, 0 0 #0000, 0 10px 15px -3px rgb(16 185 129 / .2), 0 4px 6px -4px rgb(16 185 129 / .2);
+            background-image: linear-gradient(to bottom right, #ec4899, #10b981, #f59e0b);
+            border: 1px solid rgb(20 28 46 / 1);
+        }
+    }
+</style>
+<div style="display: flex; align-items: center; gap: 0.75rem">
+    <h1 style="padding: 0; margin-bottom: 0.5rem">API Inference</h1>
+    <span class="pro-badge">PRO</span>
+</div>
+<p>Run inference on API endpoints. Hugging Face for now; more coming soon!</p>
+""")
+
+# TODO: categorize tasks by service (SAI, FAL, etc)
+# content
+st.markdown("## Tasks")
+st.page_link("pages/1_💬_Text_Generation.py", label="Text Generation", icon="💬")
+st.page_link("pages/2_🎨_Text_to_Image.py", label="Text to Image", icon="🎨")
app.py
ADDED
@@ -0,0 +1,10 @@
+#!/usr/bin/env python
+import subprocess
+import sys
+
+try:
+    subprocess.run([sys.executable, "-m", "streamlit", "run", "0_🏠_Home.py"], check=True)
+except KeyboardInterrupt:
+    sys.exit(0)
+except subprocess.CalledProcessError as e:
+    sys.exit(e.returncode)
lib/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .config import Config
+from .presets import Presets
+
+__all__ = ["Config", "Presets"]
lib/config.py
ADDED
@@ -0,0 +1,22 @@
+from types import SimpleNamespace
+
+Config = SimpleNamespace(
+    TITLE="API Inference",
+    ICON="⚡",
+    LAYOUT="wide",
+    TXT2IMG_NEGATIVE_PROMPT="ugly, bad, asymmetrical, malformed, mutated, disgusting, blurry, grainy, oversaturated, undersaturated, overexposed, underexposed, worst quality, low details, lowres, watermark, signature, autograph, trademark",
+    TXT2IMG_DEFAULT_INDEX=2,
+    TXT2IMG_MODELS=[
+        "black-forest-labs/flux.1-dev",
+        "black-forest-labs/flux.1-schnell",
+        "stabilityai/stable-diffusion-xl-base-1.0",
+    ],
+    TXT2TXT_DEFAULT_INDEX=4,
+    TXT2TXT_MODELS=[
+        "codellama/codellama-34b-instruct-hf",
+        "meta-llama/llama-2-13b-chat-hf",
+        "meta-llama/meta-llama-3.1-405b-instruct-fp8",
+        "mistralai/mistral-7b-instruct-v0.2",
+        "nousresearch/nous-hermes-2-mixtral-8x7b-dpo",
+    ],
+)
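Note: the *_DEFAULT_INDEX values are zero-based positions into the model lists above. A minimal sketch of what they resolve to (illustrative only, not part of the commit):

import lib  # the package added in this commit
from lib import Config

# index 2 of TXT2IMG_MODELS is the SDXL checkpoint
print(Config.TXT2IMG_MODELS[Config.TXT2IMG_DEFAULT_INDEX])
# -> stabilityai/stable-diffusion-xl-base-1.0

# index 4 of TXT2TXT_MODELS is the Nous Hermes 2 Mixtral checkpoint
print(Config.TXT2TXT_MODELS[Config.TXT2TXT_DEFAULT_INDEX])
# -> nousresearch/nous-hermes-2-mixtral-8x7b-dpo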
lib/presets.py
ADDED
@@ -0,0 +1,33 @@
+from types import SimpleNamespace
+
+Presets = SimpleNamespace(
+    FLUX_1_DEV={
+        "name": "FLUX.1 Dev",
+        "num_inference_steps": 30,
+        "num_inference_steps_min": 10,
+        "num_inference_steps_max": 40,
+        "guidance_scale": 3.5,
+        "guidance_scale_min": 1.0,
+        "guidance_scale_max": 7.0,
+        "parameters": ["width", "height", "guidance_scale", "num_inference_steps"],
+        "kwargs": {"max_sequence_length": 512},
+    },
+    FLUX_1_SCHNELL={
+        "name": "FLUX.1 Schnell",
+        "num_inference_steps": 4,
+        "num_inference_steps_min": 1,
+        "num_inference_steps_max": 8,
+        "parameters": ["width", "height", "num_inference_steps"],
+        "kwargs": {"guidance_scale": 0.0, "max_sequence_length": 256},
+    },
+    STABLE_DIFFUSION_XL={
+        "name": "SDXL",
+        "guidance_scale": 7.0,
+        "guidance_scale_min": 1.0,
+        "guidance_scale_max": 15.0,
+        "num_inference_steps": 40,
+        "num_inference_steps_min": 10,
+        "num_inference_steps_max": 50,
+        "parameters": ["width", "height", "guidance_scale", "num_inference_steps", "seed", "negative_prompt"],
+    },
+)
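Note: a preset's "kwargs" are fixed values merged after the user-adjustable parameters (see generate_image in pages/2_🎨_Text_to_Image.py below), so they win on key collisions; this is how FLUX.1 Schnell pins guidance_scale to 0.0 regardless of the UI. A minimal sketch of that merge, with illustrative values:

# parameters as collected from the sidebar widgets
parameters = {"width": 1344, "height": 768, "num_inference_steps": 4}
# Presets.FLUX_1_SCHNELL["kwargs"]
kwargs = {"guidance_scale": 0.0, "max_sequence_length": 256}

# in a dict literal, later keys override earlier ones,
# so the preset kwargs take precedence
payload = {**parameters, **kwargs}
print(payload["guidance_scale"])  # -> 0.0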
logo.svg
ADDED
pages/1_💬_Text_Generation.py
ADDED
@@ -0,0 +1,174 @@
+import os
+from datetime import datetime
+
+import streamlit as st
+from openai import APIError, OpenAI
+
+from lib import Config
+
+# TODO: key input and store in cache_data
+# api key
+HF_TOKEN = os.environ.get("HF_TOKEN")
+
+
+# TODO: eventually support different APIs like OpenAI and Perplexity
+@st.cache_resource
+def get_chat_client(api_key, model):
+    client = OpenAI(
+        api_key=api_key,
+        base_url=f"https://api-inference.huggingface.co/models/{model}/v1",
+    )
+    return client
+
+
+# config
+st.set_page_config(
+    page_title=f"{Config.TITLE} | Text Generation",
+    page_icon=Config.ICON,
+    layout=Config.LAYOUT,
+)
+
+# initialize state
+if "txt2txt_messages" not in st.session_state:
+    st.session_state.txt2txt_messages = []
+
+if "txt2txt_prompt" not in st.session_state:
+    st.session_state.txt2txt_prompt = ""
+
+# sidebar
+st.logo("logo.svg")
+st.sidebar.header("Settings")
+model = st.sidebar.selectbox(
+    "Model",
+    placeholder="Select a model",
+    format_func=lambda x: x.split("/")[1],
+    index=Config.TXT2TXT_DEFAULT_INDEX,
+    options=Config.TXT2TXT_MODELS,
+)
+max_tokens = st.sidebar.slider(
+    "Max Tokens",
+    min_value=512,
+    max_value=4096,
+    value=512,
+    step=128,
+    help="Maximum number of tokens to generate (default: 512)",
+)
+temperature = st.sidebar.slider(
+    "Temperature",
+    min_value=0.0,
+    max_value=2.0,
+    value=1.0,
+    step=0.1,
+    help="Used to modulate the next token probabilities (default: 1.0)",
+)
+frequency_penalty = st.sidebar.slider(
+    "Frequency Penalty",
+    min_value=-2.0,
+    max_value=2.0,
+    value=0.0,
+    step=0.1,
+    help="Penalize new tokens based on their existing frequency in the text (default: 0.0)",
+)
+seed = st.sidebar.number_input(
+    "Seed",
+    min_value=-1,
+    max_value=(1 << 53) - 1,
+    value=-1,
+    help="Make a best effort to sample deterministically (default: -1)",
+)
+system = st.sidebar.text_area(
+    "System Message",
+    value="You are a helpful assistant. Be precise and concise.",
+)
+
+# random seed
+if seed < 0:
+    seed = int(datetime.now().timestamp() * 1e6) % (1 << 53)
+
+# heading
+st.html("""
+<h1 style="padding: 0; margin-bottom: 0.5rem">Text Generation</h1>
+<p>Chat with large language models.</p>
+""")
+
+# chat messages
+for message in st.session_state.txt2txt_messages:
+    with st.chat_message(message["role"]):
+        st.markdown(message["content"])
+
+# button row
+if st.session_state.txt2txt_messages:
+    button_container = st.empty()
+    with button_container.container():
+        # https://discuss.streamlit.io/t/st-button-in-one-line/25966/6
+        st.html("""
+        <style>
+        div[data-testid="column"] {
+            width: fit-content;
+            min-width: 0;
+            flex: none;
+        }
+        </style>
+        """)
+
+        # remove last assistant message and resend prompt
+        col1, col2, col3 = st.columns(3)
+        with col1:
+            if st.button("🔄", help="Retry last message") and len(st.session_state.txt2txt_messages) >= 2:
+                st.session_state.txt2txt_messages.pop()
+                st.session_state.txt2txt_prompt = st.session_state.txt2txt_messages.pop()["content"]
+                st.rerun()
+
+        # delete last message pair
+        with col2:
+            if st.button("❌", help="Delete last message") and len(st.session_state.txt2txt_messages) >= 2:
+                st.session_state.txt2txt_messages.pop()
+                st.session_state.txt2txt_messages.pop()
+                st.rerun()
+
+        # reset app state
+        with col3:
+            if st.button("🗑️", help="Clear all messages"):
+                st.session_state.txt2txt_messages = []
+                st.session_state.txt2txt_prompt = ""
+                st.rerun()
+else:
+    button_container = None
+
+# chat input
+if prompt := st.chat_input("What would you like to know?"):
+    st.session_state.txt2txt_prompt = prompt
+
+if st.session_state.txt2txt_prompt:
+    with st.chat_message("user"):
+        st.markdown(st.session_state.txt2txt_prompt)
+
+    if button_container:
+        button_container.empty()
+
+    messages = [{"role": "system", "content": system}]
+    messages.extend([{"role": m["role"], "content": m["content"]} for m in st.session_state.txt2txt_messages])
+    messages.append({"role": "user", "content": st.session_state.txt2txt_prompt})
+
+    with st.chat_message("assistant"):
+        try:
+            client = get_chat_client(HF_TOKEN, model)
+            stream = client.chat.completions.create(
+                frequency_penalty=frequency_penalty,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                messages=messages,
+                model=model,
+                stream=True,
+                seed=seed,
+            )
+            response = st.write_stream(stream)
+        except APIError as e:
+            response = e.message
+        except Exception as e:
+            response = str(e)
+
+    st.session_state.txt2txt_messages.append({"role": "user", "content": st.session_state.txt2txt_prompt})
+    st.session_state.txt2txt_messages.append({"role": "assistant", "content": response})
+    st.session_state.txt2txt_prompt = ""
+    st.rerun()
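For reference, this page talks to Hugging Face's OpenAI-compatible chat endpoint, so the same call works outside Streamlit. A minimal non-streaming sketch under the same assumptions as the page (HF_TOKEN exported, model taken from Config.TXT2TXT_MODELS; the prompt is illustrative):

import os

from openai import OpenAI

model = "mistralai/mistral-7b-instruct-v0.2"

# same base URL pattern as get_chat_client above
client = OpenAI(
    api_key=os.environ["HF_TOKEN"],
    base_url=f"https://api-inference.huggingface.co/models/{model}/v1",
)
completion = client.chat.completions.create(
    model=model,
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    max_tokens=64,
)
print(completion.choices[0].message.content)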
pages/2_🎨_Text_to_Image.py
ADDED
@@ -0,0 +1,212 @@
+import io
+import os
+from datetime import datetime
+
+import requests
+import streamlit as st
+from PIL import Image
+
+from lib import Config, Presets
+
+# TODO: key input and store in cache_data
+# TODO: API dropdown; changes available models
+HF_TOKEN = os.environ.get("HF_TOKEN")
+API_URL = "https://api-inference.huggingface.co/models"
+HEADERS = {"Authorization": f"Bearer {HF_TOKEN}", "X-Wait-For-Model": "true", "X-Use-Cache": "false"}
+SIZE_AR = {
+    "9:7": (1152, 896),
+    "7:4": (1344, 768),
+    "1:1": (1024, 1024),
+    "4:7": (768, 1344),
+    "7:9": (896, 1152),
+}
+PRESET_MODEL = {
+    "black-forest-labs/flux.1-dev": Presets.FLUX_1_DEV,
+    "black-forest-labs/flux.1-schnell": Presets.FLUX_1_SCHNELL,
+    "stabilityai/stable-diffusion-xl-base-1.0": Presets.STABLE_DIFFUSION_XL,
+}
+
+
+def generate_image(model, prompt, parameters, **kwargs):
+    response = requests.post(
+        f"{API_URL}/{model}",
+        headers=HEADERS,
+        json={
+            "inputs": prompt,
+            "parameters": {**parameters, **kwargs},
+        },
+    )
+
+    if response.status_code == 200:
+        image = Image.open(io.BytesIO(response.content))
+        return image
+    else:
+        st.error(f"Error: {response.status_code} - {response.text}")
+        return None
+
+
+# config
+st.set_page_config(
+    page_title=f"{Config.TITLE} | Text to Image",
+    page_icon=Config.ICON,
+    layout=Config.LAYOUT,
+)
+
+# initialize state
+if "txt2img_messages" not in st.session_state:
+    st.session_state.txt2img_messages = []
+
+if "txt2img_seed" not in st.session_state:
+    st.session_state.txt2img_seed = 0
+
+# sidebar
+st.logo("logo.svg")
+st.sidebar.header("Settings")
+model = st.sidebar.selectbox(
+    "Model",
+    format_func=lambda x: x.split("/")[1],
+    options=Config.TXT2IMG_MODELS,
+    index=Config.TXT2IMG_DEFAULT_INDEX,
+)
+aspect_ratio = st.sidebar.select_slider(
+    "Aspect Ratio",
+    options=list(SIZE_AR.keys()),
+    value=list(SIZE_AR.keys())[1],
+)
+
+# heading
+st.html("""
+<h1 style="padding: 0; margin-bottom: 0.5rem">Text to Image</h1>
+<p>Generate an image from a text prompt.</p>
+""")
+
+# build parameters from preset
+parameters = {}
+preset = PRESET_MODEL[model]
+for param in preset["parameters"]:
+    if param == "width":
+        parameters[param] = SIZE_AR[aspect_ratio][0]
+    if param == "height":
+        parameters[param] = SIZE_AR[aspect_ratio][1]
+    if param == "guidance_scale":
+        parameters[param] = st.sidebar.slider(
+            "Guidance Scale",
+            preset["guidance_scale_min"],
+            preset["guidance_scale_max"],
+            preset["guidance_scale"],
+            0.1,
+        )
+    if param == "num_inference_steps":
+        parameters[param] = st.sidebar.slider(
+            "Inference Steps",
+            preset["num_inference_steps_min"],
+            preset["num_inference_steps_max"],
+            preset["num_inference_steps"],
+            1,
+        )
+    if param == "seed":
+        parameters[param] = st.sidebar.number_input(
+            "Seed",
+            min_value=-1,
+            max_value=(1 << 53) - 1,
+            value=-1,
+        )
+    if param == "negative_prompt":
+        parameters[param] = st.sidebar.text_area(
+            label="Negative Prompt",
+            value=Config.TXT2IMG_NEGATIVE_PROMPT,
+        )
+
+# wrap the prompt in an expander to display additional parameters
+for message in st.session_state.txt2img_messages:
+    role = message["role"]
+    with st.chat_message(role):
+        image_container = st.empty()
+
+        with image_container.container():
+            if role == "user":
+                with st.expander(message["content"]):
+                    # build a markdown string for additional parameters
+                    st.html("""
+                    <style>
+                    div[data-testid="stMarkdownContainer"] p:not(:last-of-type) { margin-bottom: 0 }
+                    </style>
+                    """)
+                    md = f"`model`: {message['model']}\n\n"
+                    md += "\n\n".join([f"`{k}`: {v}" for k, v in message["parameters"].items()])
+                    st.markdown(md)
+
+            if role == "assistant":
+                # image is full width when _not_ in full-screen mode
+                st.html("""
+                <style>
+                div[data-testid="stImage"]:has(img[style*="max-width: 100%"]) {
+                    height: auto;
+                    max-width: 512px;
+                }
+                div[data-testid="stImage"] img[style*="max-width: 100%"] {
+                    border-radius: 8px;
+                }
+                </style>
+                """)
+                st.image(message["content"])
+
+# button row
+if st.session_state.txt2img_messages:
+    button_container = st.empty()
+    with button_container.container():
+        # https://discuss.streamlit.io/t/st-button-in-one-line/25966/6
+        st.html("""
+        <style>
+        div[data-testid="column"] {
+            width: fit-content;
+            min-width: 0;
+            flex: none;
+        }
+        </style>
+        """)
+
+        # retry
+        col1, col2 = st.columns(2)
+        with col1:
+            if st.button("❌", help="Delete last generation") and len(st.session_state.txt2img_messages) >= 2:
+                st.session_state.txt2img_messages.pop()
+                st.session_state.txt2img_messages.pop()
+                st.rerun()
+
+        with col2:
+            if st.button("🗑️", help="Clear all generations"):
+                st.session_state.txt2img_messages = []
+                st.session_state.txt2img_seed = 0
+                st.rerun()
+else:
+    button_container = None
+
+# show the prompt and spinner while loading then update state and re-render
+if prompt := st.chat_input("What do you want to see?"):
+    if "seed" in parameters and parameters["seed"] >= 0:
+        st.session_state.txt2img_seed = parameters["seed"]
+    else:
+        st.session_state.txt2img_seed = int(datetime.now().timestamp() * 1e6) % (1 << 53)
+    if "seed" in parameters:
+        parameters["seed"] = st.session_state.txt2img_seed
+
+    if button_container:
+        button_container.empty()
+
+    with st.chat_message("user"):
+        st.markdown(prompt)
+
+    with st.chat_message("assistant"):
+        with st.spinner("Running..."):
+            generate_kwargs = {"model": model, "prompt": prompt, "parameters": parameters}
+            if preset.get("kwargs") is not None:
+                generate_kwargs.update(preset["kwargs"])
+            image = generate_image(**generate_kwargs)
+
+    model_name = PRESET_MODEL[model]["name"]
+    st.session_state.txt2img_messages.append(
+        {"role": "user", "content": prompt, "parameters": parameters, "model": model_name}
+    )
+    st.session_state.txt2img_messages.append({"role": "assistant", "content": image})
+    st.rerun()
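For reference, generate_image is a plain POST to the serverless Inference API. A minimal sketch without Streamlit that saves the result to disk, using the same URL, headers, and payload shape as above (the prompt and the timeout are added assumptions):

import io
import os

import requests
from PIL import Image

model = "stabilityai/stable-diffusion-xl-base-1.0"
response = requests.post(
    f"https://api-inference.huggingface.co/models/{model}",
    headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}", "X-Wait-For-Model": "true"},
    json={"inputs": "a cozy cabin in the woods", "parameters": {"width": 1024, "height": 1024}},
    timeout=120,  # assumption: generous timeout while the model loads
)
response.raise_for_status()
Image.open(io.BytesIO(response.content)).save("result.png")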