adamelliotfields commited on
Commit
3c7025e
β€’
1 Parent(s): 7c829e7
0_🏠_Home.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ from lib import Config
4
+
5
+ st.set_page_config(
6
+ page_title=Config.TITLE,
7
+ page_icon=Config.ICON,
8
+ layout=Config.LAYOUT,
9
+ )
10
+
11
+ # sidebar
12
+ st.logo("logo.svg")
13
+
14
+ # title
15
+ st.html("""
16
+ <style>
17
+ .pro-badge {
18
+ display: inline-block;
19
+ transform: skew(-12deg);
20
+ font-size: 0.875rem;
21
+ line-height: 1.25rem;
22
+ font-weight: 700;
23
+ padding: 0.125rem 0.625rem;
24
+ border-radius: 0.5rem;
25
+ color: rgb(0 0 0 / 1);
26
+ box-shadow: 0 0 #0000, 0 0 #0000, 0 10px 15px -3px rgb(16 185 129 / .1), 0 4px 6px -4px rgb(16 185 129 / .1);
27
+ background-image: linear-gradient(to bottom right, #f9a8d4, #a7f3d0, #fde68a);
28
+ border: 1px solid rgb(229 231 235 / 1);
29
+ }
30
+ @media (prefers-color-scheme: dark) {
31
+ .pro-badge {
32
+ box-shadow: 0 0 #0000, 0 0 #0000, 0 10px 15px -3px rgb(16 185 129 / .2), 0 4px 6px -4px rgb(16 185 129 / .2);
33
+ background-image: linear-gradient(to bottom right, #ec4899, #10b981, #f59e0b);
34
+ border: 1px solid rgb(20 28 46 / 1);
35
+ }
36
+ }
37
+ </style>
38
+ <div style="display: flex; align-items: center; gap: 0.75rem">
39
+ <h1 style="padding: 0; margin-bottom: 0.5rem">API Inference</h1>
40
+ <span class="pro-badge">PRO</span>
41
+ </div>
42
+ <p>Run inference on API endpoints. Hugging Face for now; more coming soon!</p>
43
+ """)
44
+
45
+ # TODO: categorize tasks by service (SAI, FAL, etc)
46
+ # content
47
+ st.markdown("## Tasks")
48
+ st.page_link("pages/1_πŸ’¬_Text_Generation.py", label="Text Generation", icon="πŸ’¬")
49
+ st.page_link("pages/2_🎨_Text_to_Image.py", label="Text to Image", icon="🎨")
app.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ import subprocess
3
+ import sys
4
+
5
+ try:
6
+ subprocess.run([sys.executable, "-m", "streamlit", "run", "0_🏠_Home.py"], check=True)
7
+ except KeyboardInterrupt:
8
+ sys.exit(0)
9
+ except subprocess.CalledProcessError as e:
10
+ sys.exit(e.returncode)
lib/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .config import Config
2
+ from .presets import Presets
3
+
4
+ __all__ = ["Config", "Presets"]
lib/config.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from types import SimpleNamespace
2
+
3
+ Config = SimpleNamespace(
4
+ TITLE="API Inference",
5
+ ICON="⚑",
6
+ LAYOUT="wide",
7
+ TXT2IMG_NEGATIVE_PROMPT="ugly, bad, asymmetrical, malformed, mutated, disgusting, blurry, grainy, oversaturated, undersaturated, overexposed, underexposed, worst quality, low details, lowres, watermark, signature, autograph, trademark",
8
+ TXT2IMG_DEFAULT_INDEX=2,
9
+ TXT2IMG_MODELS=[
10
+ "black-forest-labs/flux.1-dev",
11
+ "black-forest-labs/flux.1-schnell",
12
+ "stabilityai/stable-diffusion-xl-base-1.0",
13
+ ],
14
+ TXT2TXT_DEFAULT_INDEX=4,
15
+ TXT2TXT_MODELS=[
16
+ "codellama/codellama-34b-instruct-hf",
17
+ "meta-llama/llama-2-13b-chat-hf",
18
+ "meta-llama/meta-llama-3.1-405b-instruct-fp8",
19
+ "mistralai/mistral-7b-instruct-v0.2",
20
+ "nousresearch/nous-hermes-2-mixtral-8x7b-dpo",
21
+ ],
22
+ )
lib/presets.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from types import SimpleNamespace
2
+
3
+ Presets = SimpleNamespace(
4
+ FLUX_1_DEV={
5
+ "name": "FLUX.1 Dev",
6
+ "num_inference_steps": 30,
7
+ "num_inference_steps_min": 10,
8
+ "num_inference_steps_max": 40,
9
+ "guidance_scale": 3.5,
10
+ "guidance_scale_min": 1.0,
11
+ "guidance_scale_max": 7.0,
12
+ "parameters": ["width", "height", "guidance_scale", "num_inference_steps"],
13
+ "kwargs": {"max_sequence_length": 512},
14
+ },
15
+ FLUX_1_SCHNELL={
16
+ "name": "FLUX.1 Schnell",
17
+ "num_inference_steps": 4,
18
+ "num_inference_steps_min": 1,
19
+ "num_inference_steps_max": 8,
20
+ "parameters": ["width", "height", "num_inference_steps"],
21
+ "kwargs": {"guidance_scale": 0.0, "max_sequence_length": 256},
22
+ },
23
+ STABLE_DIFFUSION_XL={
24
+ "name": "SDXL",
25
+ "guidance_scale": 7.0,
26
+ "guidance_scale_min": 1.0,
27
+ "guidance_scale_max": 15.0,
28
+ "num_inference_steps": 40,
29
+ "num_inference_steps_min": 10,
30
+ "num_inference_steps_max": 50,
31
+ "parameters": ["width", "height", "guidance_scale", "num_inference_steps", "seed", "negative_prompt"],
32
+ },
33
+ )
logo.svg ADDED
pages/1_πŸ’¬_Text_Generation.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from datetime import datetime
3
+
4
+ import streamlit as st
5
+ from openai import APIError, OpenAI
6
+
7
+ from lib import Config
8
+
9
+ # TODO: key input and store in cache_data
10
+ # api key
11
+ HF_TOKEN = os.environ.get("HF_TOKEN")
12
+
13
+
14
+ # TODO: eventually support different APIs like OpenAI and Perplexity
15
+ @st.cache_resource
16
+ def get_chat_client(api_key, model):
17
+ client = OpenAI(
18
+ api_key=api_key,
19
+ base_url=f"https://api-inference.huggingface.co/models/{model}/v1",
20
+ )
21
+ return client
22
+
23
+
24
+ # config
25
+ st.set_page_config(
26
+ page_title=f"{Config.TITLE} | Text Generation",
27
+ page_icon=Config.ICON,
28
+ layout=Config.LAYOUT,
29
+ )
30
+
31
+ # initialize state
32
+ if "txt2txt_messages" not in st.session_state:
33
+ st.session_state.txt2txt_messages = []
34
+
35
+ if "txt2txt_prompt" not in st.session_state:
36
+ st.session_state.txt2txt_prompt = ""
37
+
38
+ # sidebar
39
+ st.logo("logo.svg")
40
+ st.sidebar.header("Settings")
41
+ model = st.sidebar.selectbox(
42
+ "Model",
43
+ placeholder="Select a model",
44
+ format_func=lambda x: x.split("/")[1],
45
+ index=Config.TXT2TXT_DEFAULT_INDEX,
46
+ options=Config.TXT2TXT_MODELS,
47
+ )
48
+ max_tokens = st.sidebar.slider(
49
+ "Max Tokens",
50
+ min_value=512,
51
+ max_value=4096,
52
+ value=512,
53
+ step=128,
54
+ help="Maximum number of tokens to generate (default: 512)",
55
+ )
56
+ temperature = st.sidebar.slider(
57
+ "Temperature",
58
+ min_value=0.0,
59
+ max_value=2.0,
60
+ value=1.0,
61
+ step=0.1,
62
+ help="Used to modulate the next token probabilities (default: 1.0)",
63
+ )
64
+ frequency_penalty = st.sidebar.slider(
65
+ "Frequency Penalty",
66
+ min_value=-2.0,
67
+ max_value=2.0,
68
+ value=0.0,
69
+ step=0.1,
70
+ help="Penalize new tokens based on their existing frequency in the text (default: 0.0)",
71
+ )
72
+ seed = st.sidebar.number_input(
73
+ "Seed",
74
+ min_value=-1,
75
+ max_value=(1 << 53) - 1,
76
+ value=-1,
77
+ help="Make a best effort to sample deterministically (default: -1)",
78
+ )
79
+ system = st.sidebar.text_area(
80
+ "System Message",
81
+ value="You are a helpful assistant. Be precise and concise.",
82
+ )
83
+
84
+ # random seed
85
+ if seed < 0:
86
+ seed = int(datetime.now().timestamp() * 1e6) % (1 << 53)
87
+
88
+ # heading
89
+ st.html("""
90
+ <h1 style="padding: 0; margin-bottom: 0.5rem">Text Generation</h1>
91
+ <p>Chat with large language models.</p>
92
+ """)
93
+
94
+ # chat messages
95
+ for message in st.session_state.txt2txt_messages:
96
+ with st.chat_message(message["role"]):
97
+ st.markdown(message["content"])
98
+
99
+ # button row
100
+ if st.session_state.txt2txt_messages:
101
+ button_container = st.empty()
102
+ with button_container.container():
103
+ # https://discuss.streamlit.io/t/st-button-in-one-line/25966/6
104
+ st.html("""
105
+ <style>
106
+ div[data-testid="column"] {
107
+ width: fit-content;
108
+ min-width: 0;
109
+ flex: none;
110
+ }
111
+ </style>
112
+ """)
113
+
114
+ # remove last assistant message and resend prompt
115
+ col1, col2, col3 = st.columns(3)
116
+ with col1:
117
+ if st.button("πŸ”„οΈ", help="Retry last message") and len(st.session_state.txt2txt_messages) >= 2:
118
+ st.session_state.txt2txt_messages.pop()
119
+ st.session_state.txt2txt_prompt = st.session_state.txt2txt_messages.pop()["content"]
120
+ st.rerun()
121
+
122
+ # delete last message pair
123
+ with col2:
124
+ if st.button("❌", help="Delete last message") and len(st.session_state.txt2txt_messages) >= 2:
125
+ st.session_state.txt2txt_messages.pop()
126
+ st.session_state.txt2txt_messages.pop()
127
+ st.rerun()
128
+
129
+ # reset app state
130
+ with col3:
131
+ if st.button("πŸ—‘οΈ", help="Clear all messages"):
132
+ st.session_state.txt2txt_messages = []
133
+ st.session_state.txt2txt_prompt = ""
134
+ st.rerun()
135
+ else:
136
+ button_container = None
137
+
138
+ # chat input
139
+ if prompt := st.chat_input("What would you like to know?"):
140
+ st.session_state.txt2txt_prompt = prompt
141
+
142
+ if st.session_state.txt2txt_prompt:
143
+ with st.chat_message("user"):
144
+ st.markdown(st.session_state.txt2txt_prompt)
145
+
146
+ if button_container:
147
+ button_container.empty()
148
+
149
+ messages = [{"role": "system", "content": system}]
150
+ messages.extend([{"role": m["role"], "content": m["content"]} for m in st.session_state.txt2txt_messages])
151
+ messages.append({"role": "user", "content": st.session_state.txt2txt_prompt})
152
+
153
+ with st.chat_message("assistant"):
154
+ try:
155
+ client = get_chat_client(HF_TOKEN, model)
156
+ stream = client.chat.completions.create(
157
+ frequency_penalty=frequency_penalty,
158
+ temperature=temperature,
159
+ max_tokens=max_tokens,
160
+ messages=messages,
161
+ model=model,
162
+ stream=True,
163
+ seed=seed,
164
+ )
165
+ response = st.write_stream(stream)
166
+ except APIError as e:
167
+ response = e.message
168
+ except Exception as e:
169
+ response = str(e)
170
+
171
+ st.session_state.txt2txt_messages.append({"role": "user", "content": st.session_state.txt2txt_prompt})
172
+ st.session_state.txt2txt_messages.append({"role": "assistant", "content": response})
173
+ st.session_state.txt2txt_prompt = ""
174
+ st.rerun()
pages/2_🎨_Text_to_Image.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ from datetime import datetime
4
+
5
+ import requests
6
+ import streamlit as st
7
+ from PIL import Image
8
+
9
+ from lib import Config, Presets
10
+
11
+ # TODO: key input and store in cache_data
12
+ # TODO: API dropdown; changes available models
13
+ HF_TOKEN = os.environ.get("HF_TOKEN")
14
+ API_URL = "https://api-inference.huggingface.co/models"
15
+ HEADERS = {"Authorization": f"Bearer {HF_TOKEN}", "X-Wait-For-Model": "true", "X-Use-Cache": "false"}
16
+ SIZE_AR = {
17
+ "9:7": (1152, 896),
18
+ "7:4": (1344, 768),
19
+ "1:1": (1024, 1024),
20
+ "4:7": (768, 1344),
21
+ "7:9": (896, 1152),
22
+ }
23
+ PRESET_MODEL = {
24
+ "black-forest-labs/flux.1-dev": Presets.FLUX_1_DEV,
25
+ "black-forest-labs/flux.1-schnell": Presets.FLUX_1_SCHNELL,
26
+ "stabilityai/stable-diffusion-xl-base-1.0": Presets.STABLE_DIFFUSION_XL,
27
+ }
28
+
29
+
30
+ def generate_image(model, prompt, parameters, **kwargs):
31
+ response = requests.post(
32
+ f"{API_URL}/{model}",
33
+ headers=HEADERS,
34
+ json={
35
+ "inputs": prompt,
36
+ "parameters": {**parameters, **kwargs},
37
+ },
38
+ )
39
+
40
+ if response.status_code == 200:
41
+ image = Image.open(io.BytesIO(response.content))
42
+ return image
43
+ else:
44
+ st.error(f"Error: {response.status_code} - {response.text}")
45
+ return None
46
+
47
+
48
+ # config
49
+ st.set_page_config(
50
+ page_title=f"{Config.TITLE} | Text to Image",
51
+ page_icon=Config.ICON,
52
+ layout=Config.LAYOUT,
53
+ )
54
+
55
+ # initialize state
56
+ if "txt2img_messages" not in st.session_state:
57
+ st.session_state.txt2img_messages = []
58
+
59
+ if "txt2img_seed" not in st.session_state:
60
+ st.session_state.txt2img_seed = 0
61
+
62
+ # sidebar
63
+ st.logo("logo.svg")
64
+ st.sidebar.header("Settings")
65
+ model = st.sidebar.selectbox(
66
+ "Model",
67
+ format_func=lambda x: x.split("/")[1],
68
+ options=Config.TXT2IMG_MODELS,
69
+ index=Config.TXT2IMG_DEFAULT_INDEX,
70
+ )
71
+ aspect_ratio = st.sidebar.select_slider(
72
+ "Aspect Ratio",
73
+ options=list(SIZE_AR.keys()),
74
+ value=list(SIZE_AR.keys())[1],
75
+ )
76
+
77
+ # heading
78
+ st.html("""
79
+ <h1 style="padding: 0; margin-bottom: 0.5rem">Text to Image</h1>
80
+ <p>Generate an image from a text prompt.</p>
81
+ """)
82
+
83
+ # build parameters from preset
84
+ parameters = {}
85
+ preset = PRESET_MODEL[model]
86
+ for param in preset["parameters"]:
87
+ if param == "width":
88
+ parameters[param] = SIZE_AR[aspect_ratio][0]
89
+ if param == "height":
90
+ parameters[param] = SIZE_AR[aspect_ratio][1]
91
+ if param == "guidance_scale":
92
+ parameters[param] = st.sidebar.slider(
93
+ "Guidance Scale",
94
+ preset["guidance_scale_min"],
95
+ preset["guidance_scale_max"],
96
+ preset["guidance_scale"],
97
+ 0.1,
98
+ )
99
+ if param == "num_inference_steps":
100
+ parameters[param] = st.sidebar.slider(
101
+ "Inference Steps",
102
+ preset["num_inference_steps_min"],
103
+ preset["num_inference_steps_max"],
104
+ preset["num_inference_steps"],
105
+ 1,
106
+ )
107
+ if param == "seed":
108
+ parameters[param] = st.sidebar.number_input(
109
+ "Seed",
110
+ min_value=-1,
111
+ max_value=(1 << 53) - 1,
112
+ value=-1,
113
+ )
114
+ if param == "negative_prompt":
115
+ parameters[param] = st.sidebar.text_area(
116
+ label="Negative Prompt",
117
+ value=Config.TXT2IMG_NEGATIVE_PROMPT,
118
+ )
119
+
120
+ # wrap the prompt in an expander to display additional parameters
121
+ for message in st.session_state.txt2img_messages:
122
+ role = message["role"]
123
+ with st.chat_message(role):
124
+ image_container = st.empty()
125
+
126
+ with image_container.container():
127
+ if role == "user":
128
+ with st.expander(message["content"]):
129
+ # build a markdown string for additional parameters
130
+ st.html("""
131
+ <style>
132
+ div[data-testid="stMarkdownContainer"] p:not(:last-of-type) { margin-bottom: 0 }
133
+ </style>
134
+ """)
135
+ md = f"`model`: {message['model']}\n\n"
136
+ md += "\n\n".join([f"`{k}`: {v}" for k, v in message["parameters"].items()])
137
+ st.markdown(md)
138
+
139
+ if role == "assistant":
140
+ # image is full width when _not_ in full-screen mode
141
+ st.html("""
142
+ <style>
143
+ div[data-testid="stImage"]:has(img[style*="max-width: 100%"]) {
144
+ height: auto;
145
+ max-width: 512px;
146
+ }
147
+ div[data-testid="stImage"] img[style*="max-width: 100%"] {
148
+ border-radius: 8px;
149
+ }
150
+ </style>
151
+ """)
152
+ st.image(message["content"])
153
+
154
+ # button row
155
+ if st.session_state.txt2img_messages:
156
+ button_container = st.empty()
157
+ with button_container.container():
158
+ # https://discuss.streamlit.io/t/st-button-in-one-line/25966/6
159
+ st.html("""
160
+ <style>
161
+ div[data-testid="column"] {
162
+ width: fit-content;
163
+ min-width: 0;
164
+ flex: none;
165
+ }
166
+ </style>
167
+ """)
168
+
169
+ # retry
170
+ col1, col2 = st.columns(2)
171
+ with col1:
172
+ if st.button("❌", help="Delete last generation") and len(st.session_state.txt2img_messages) >= 2:
173
+ st.session_state.txt2img_messages.pop()
174
+ st.session_state.txt2img_messages.pop()
175
+ st.rerun()
176
+
177
+ with col2:
178
+ if st.button("πŸ—‘οΈ", help="Clear all generations"):
179
+ st.session_state.txt2img_messages = []
180
+ st.session_state.txt2img_seed = 0
181
+ st.rerun()
182
+ else:
183
+ button_container = None
184
+
185
+ # show the prompt and spinner while loading then update state and re-render
186
+ if prompt := st.chat_input("What do you want to see?"):
187
+ if "seed" in parameters and parameters["seed"] >= 0:
188
+ st.session_state.txt2img_seed = parameters["seed"]
189
+ else:
190
+ st.session_state.txt2img_seed = int(datetime.now().timestamp() * 1e6) % (1 << 53)
191
+ if "seed" in parameters:
192
+ parameters["seed"] = st.session_state.txt2img_seed
193
+
194
+ if button_container:
195
+ button_container.empty()
196
+
197
+ with st.chat_message("user"):
198
+ st.markdown(prompt)
199
+
200
+ with st.chat_message("assistant"):
201
+ with st.spinner("Running..."):
202
+ generate_kwargs = {"model": model, "prompt": prompt, "parameters": parameters}
203
+ if preset.get("kwargs") is not None:
204
+ generate_kwargs.update(preset["kwargs"])
205
+ image = generate_image(**generate_kwargs)
206
+
207
+ model_name = PRESET_MODEL[model]["name"]
208
+ st.session_state.txt2img_messages.append(
209
+ {"role": "user", "content": prompt, "parameters": parameters, "model": model_name}
210
+ )
211
+ st.session_state.txt2img_messages.append({"role": "assistant", "content": image})
212
+ st.rerun()