adamelliotfields commited on
Commit
2401ce9
·
verified ·
1 Parent(s): 1367e6b

Convert presets to dataclasses

Browse files
lib/__init__.py CHANGED
@@ -1,11 +1,10 @@
1
  from .api import txt2img_generate, txt2txt_generate
2
  from .config import config
3
- from .presets import ModelPresets, ServicePresets
4
 
5
  __all__ = [
6
  "config",
7
- "ModelPresets",
8
- "ServicePresets",
9
  "txt2img_generate",
10
  "txt2txt_generate",
11
  ]
 
1
  from .api import txt2img_generate, txt2txt_generate
2
  from .config import config
3
+ from .preset import preset
4
 
5
  __all__ = [
6
  "config",
7
+ "preset",
 
8
  "txt2img_generate",
9
  "txt2txt_generate",
10
  ]
lib/preset.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Dict, List, Optional, Union
3
+
4
+
5
+ @dataclass
6
+ class Txt2TxtPreset:
7
+ frequency_penalty: float
8
+ frequency_penalty_min: float
9
+ frequency_penalty_max: float
10
+ parameters: Optional[List[str]] = field(default_factory=list)
11
+
12
+
13
+ @dataclass
14
+ class Txt2ImgPreset:
15
+ # FLUX1.1 has no scale or steps
16
+ name: str
17
+ guidance_scale: Optional[float] = None
18
+ guidance_scale_min: Optional[float] = None
19
+ guidance_scale_max: Optional[float] = None
20
+ num_inference_steps: Optional[int] = None
21
+ num_inference_steps_min: Optional[int] = None
22
+ num_inference_steps_max: Optional[int] = None
23
+ parameters: Optional[List[str]] = field(default_factory=list)
24
+ kwargs: Optional[Dict[str, Union[str, int, float, bool]]] = field(default_factory=dict)
25
+
26
+
27
+ @dataclass
28
+ class Txt2TxtPresets:
29
+ hugging_face: Txt2TxtPreset
30
+ perplexity: Txt2TxtPreset
31
+
32
+
33
+ @dataclass
34
+ class Txt2ImgPresets:
35
+ # bfl
36
+ flux_1_1_pro_bfl: Txt2ImgPreset
37
+ flux_dev_bfl: Txt2ImgPreset
38
+ flux_pro_bfl: Txt2ImgPreset
39
+ # fal
40
+ aura_flow: Txt2ImgPreset
41
+ flux_1_1_pro_fal: Txt2ImgPreset
42
+ flux_dev_fal: Txt2ImgPreset
43
+ flux_pro_fal: Txt2ImgPreset
44
+ flux_schnell_fal: Txt2ImgPreset
45
+ fooocus: Txt2ImgPreset
46
+ kolors: Txt2ImgPreset
47
+ stable_diffusion_3: Txt2ImgPreset
48
+ # hf
49
+ flux_dev_hf: Txt2ImgPreset
50
+ flux_schnell_hf: Txt2ImgPreset
51
+ stable_diffusion_xl: Txt2ImgPreset
52
+ # together
53
+ flux_schnell_free_together: Txt2ImgPreset
54
+
55
+
56
+ @dataclass
57
+ class Preset:
58
+ txt2txt: Txt2TxtPresets
59
+ txt2img: Txt2ImgPresets
60
+
61
+
62
+ preset = Preset(
63
+ txt2txt=Txt2TxtPresets(
64
+ # Every service has model and system messages
65
+ hugging_face=Txt2TxtPreset(
66
+ frequency_penalty=0.0,
67
+ frequency_penalty_min=-2.0,
68
+ frequency_penalty_max=2.0,
69
+ parameters=["max_tokens", "temperature", "frequency_penalty", "seed"],
70
+ ),
71
+ perplexity=Txt2TxtPreset(
72
+ frequency_penalty=1.0,
73
+ frequency_penalty_min=1.0,
74
+ frequency_penalty_max=2.0,
75
+ parameters=["max_tokens", "temperature", "frequency_penalty"],
76
+ ),
77
+ ),
78
+ txt2img=Txt2ImgPresets(
79
+ aura_flow=Txt2ImgPreset(
80
+ "AuraFlow",
81
+ guidance_scale=3.5,
82
+ guidance_scale_min=1.0,
83
+ guidance_scale_max=10.0,
84
+ num_inference_steps=28,
85
+ num_inference_steps_min=10,
86
+ num_inference_steps_max=50,
87
+ parameters=["seed", "num_inference_steps", "guidance_scale", "expand_prompt"],
88
+ kwargs={"num_images": 1, "sync_mode": False},
89
+ ),
90
+ flux_1_1_pro_bfl=Txt2ImgPreset(
91
+ "FLUX1.1 Pro",
92
+ parameters=["seed", "width", "height", "prompt_upsampling"],
93
+ kwargs={"safety_tolerance": 6},
94
+ ),
95
+ flux_pro_bfl=Txt2ImgPreset(
96
+ "FLUX.1 Pro",
97
+ guidance_scale=2.5,
98
+ guidance_scale_min=1.5,
99
+ guidance_scale_max=5.0,
100
+ num_inference_steps=40,
101
+ num_inference_steps_min=10,
102
+ num_inference_steps_max=50,
103
+ parameters=["seed", "width", "height", "steps", "guidance", "prompt_upsampling"],
104
+ kwargs={"safety_tolerance": 6, "interval": 1},
105
+ ),
106
+ flux_dev_bfl=Txt2ImgPreset(
107
+ "FLUX.1 Dev",
108
+ num_inference_steps=28,
109
+ num_inference_steps_min=10,
110
+ num_inference_steps_max=50,
111
+ guidance_scale=3.0,
112
+ guidance_scale_min=1.5,
113
+ guidance_scale_max=5.0,
114
+ parameters=["seed", "width", "height", "steps", "guidance", "prompt_upsampling"],
115
+ kwargs={"safety_tolerance": 6},
116
+ ),
117
+ flux_1_1_pro_fal=Txt2ImgPreset(
118
+ "FLUX1.1 Pro",
119
+ parameters=["seed", "image_size"],
120
+ kwargs={
121
+ "num_images": 1,
122
+ "sync_mode": False,
123
+ "safety_tolerance": 6,
124
+ "enable_safety_checker": False,
125
+ },
126
+ ),
127
+ flux_pro_fal=Txt2ImgPreset(
128
+ "FLUX.1 Pro",
129
+ guidance_scale=2.5,
130
+ guidance_scale_min=1.5,
131
+ guidance_scale_max=5.0,
132
+ num_inference_steps=40,
133
+ num_inference_steps_min=10,
134
+ num_inference_steps_max=50,
135
+ parameters=["seed", "image_size", "num_inference_steps", "guidance_scale"],
136
+ kwargs={"num_images": 1, "sync_mode": False, "safety_tolerance": 6},
137
+ ),
138
+ flux_dev_fal=Txt2ImgPreset(
139
+ "FLUX.1 Dev",
140
+ num_inference_steps=28,
141
+ num_inference_steps_min=10,
142
+ num_inference_steps_max=50,
143
+ guidance_scale=3.0,
144
+ guidance_scale_min=1.5,
145
+ guidance_scale_max=5.0,
146
+ parameters=["seed", "image_size", "num_inference_steps", "guidance_scale"],
147
+ kwargs={"num_images": 1, "sync_mode": False, "safety_tolerance": 6},
148
+ ),
149
+ flux_schnell_fal=Txt2ImgPreset(
150
+ "FLUX.1 Schnell",
151
+ num_inference_steps=4,
152
+ num_inference_steps_min=1,
153
+ num_inference_steps_max=12,
154
+ parameters=["seed", "image_size", "num_inference_steps"],
155
+ kwargs={"num_images": 1, "sync_mode": False, "enable_safety_checker": False},
156
+ ),
157
+ flux_dev_hf=Txt2ImgPreset(
158
+ "FLUX.1 Dev",
159
+ num_inference_steps=28,
160
+ num_inference_steps_min=10,
161
+ num_inference_steps_max=50,
162
+ guidance_scale=3.0,
163
+ guidance_scale_min=1.5,
164
+ guidance_scale_max=5.0,
165
+ parameters=["width", "height", "guidance_scale", "num_inference_steps"],
166
+ kwargs={"max_sequence_length": 512},
167
+ ),
168
+ flux_schnell_hf=Txt2ImgPreset(
169
+ "FLUX.1 Schnell",
170
+ num_inference_steps=4,
171
+ num_inference_steps_min=1,
172
+ num_inference_steps_max=12,
173
+ parameters=["width", "height", "num_inference_steps"],
174
+ kwargs={"guidance_scale": 0.0, "max_sequence_length": 256},
175
+ ),
176
+ flux_schnell_free_together=Txt2ImgPreset(
177
+ "FLUX.1 Schnell Free",
178
+ num_inference_steps=4,
179
+ num_inference_steps_min=1,
180
+ num_inference_steps_max=12,
181
+ parameters=["model", "seed", "width", "height", "steps"],
182
+ kwargs={"n": 1},
183
+ ),
184
+ fooocus=Txt2ImgPreset(
185
+ "Fooocus",
186
+ guidance_scale=4.0,
187
+ guidance_scale_min=1.0,
188
+ guidance_scale_max=10.0,
189
+ parameters=["seed", "negative_prompt", "aspect_ratio", "guidance_scale"],
190
+ kwargs={
191
+ "num_images": 1,
192
+ "sync_mode": True,
193
+ "enable_safety_checker": False,
194
+ "output_format": "png",
195
+ "sharpness": 2,
196
+ "styles": ["Fooocus Enhance", "Fooocus V2", "Fooocus Sharp"],
197
+ "performance": "Quality",
198
+ },
199
+ ),
200
+ kolors=Txt2ImgPreset(
201
+ "Kolors",
202
+ guidance_scale=5.0,
203
+ guidance_scale_min=1.0,
204
+ guidance_scale_max=10.0,
205
+ num_inference_steps=50,
206
+ num_inference_steps_min=10,
207
+ num_inference_steps_max=50,
208
+ parameters=["seed", "negative_prompt", "image_size", "guidance_scale", "num_inference_steps"],
209
+ kwargs={
210
+ "num_images": 1,
211
+ "sync_mode": True,
212
+ "enable_safety_checker": False,
213
+ "scheduler": "EulerDiscreteScheduler",
214
+ },
215
+ ),
216
+ stable_diffusion_3=Txt2ImgPreset(
217
+ "SD3",
218
+ guidance_scale=5.0,
219
+ guidance_scale_min=1.0,
220
+ guidance_scale_max=10.0,
221
+ num_inference_steps=28,
222
+ num_inference_steps_min=10,
223
+ num_inference_steps_max=50,
224
+ parameters=[
225
+ "seed",
226
+ "negative_prompt",
227
+ "image_size",
228
+ "guidance_scale",
229
+ "num_inference_steps",
230
+ "prompt_expansion",
231
+ ],
232
+ kwargs={"num_images": 1, "sync_mode": True, "enable_safety_checker": False},
233
+ ),
234
+ stable_diffusion_xl=Txt2ImgPreset(
235
+ "SDXL",
236
+ guidance_scale=7.0,
237
+ guidance_scale_min=1.0,
238
+ guidance_scale_max=10.0,
239
+ num_inference_steps=40,
240
+ num_inference_steps_min=10,
241
+ num_inference_steps_max=50,
242
+ parameters=[
243
+ "seed",
244
+ "negative_prompt",
245
+ "width",
246
+ "height",
247
+ "guidance_scale",
248
+ "num_inference_steps",
249
+ ],
250
+ ),
251
+ ),
252
+ )
lib/presets.py DELETED
@@ -1,193 +0,0 @@
1
- from types import SimpleNamespace
2
-
3
- # txt2txt
4
- ServicePresets = SimpleNamespace(
5
- # Every service has model and system messages
6
- HUGGING_FACE={
7
- "frequency_penalty": 0.0,
8
- "frequency_penalty_min": -2.0,
9
- "frequency_penalty_max": 2.0,
10
- "parameters": ["max_tokens", "temperature", "frequency_penalty", "seed"],
11
- },
12
- PERPLEXITY={
13
- "frequency_penalty": 1.0,
14
- "frequency_penalty_min": 1.0,
15
- "frequency_penalty_max": 2.0,
16
- "parameters": ["max_tokens", "temperature", "frequency_penalty"],
17
- },
18
- )
19
-
20
- # txt2img
21
- ModelPresets = SimpleNamespace(
22
- AURA_FLOW={
23
- "name": "AuraFlow",
24
- "guidance_scale": 3.5,
25
- "guidance_scale_min": 1.0,
26
- "guidance_scale_max": 10.0,
27
- "num_inference_steps": 28,
28
- "num_inference_steps_min": 10,
29
- "num_inference_steps_max": 50,
30
- "parameters": ["seed", "num_inference_steps", "guidance_scale", "expand_prompt"],
31
- "kwargs": {"num_images": 1, "sync_mode": False},
32
- },
33
- FLUX_1_1_PRO_BFL={
34
- "name": "FLUX1.1 Pro",
35
- "parameters": ["seed", "width", "height", "prompt_upsampling"],
36
- "kwargs": {"safety_tolerance": 6},
37
- },
38
- FLUX_PRO_BFL={
39
- "name": "FLUX.1 Pro",
40
- "guidance_scale": 2.5,
41
- "guidance_scale_min": 1.5,
42
- "guidance_scale_max": 5.0,
43
- "num_inference_steps": 40,
44
- "num_inference_steps_min": 10,
45
- "num_inference_steps_max": 50,
46
- "parameters": ["seed", "width", "height", "steps", "guidance", "prompt_upsampling"],
47
- "kwargs": {"safety_tolerance": 6, "interval": 1},
48
- },
49
- FLUX_DEV_BFL={
50
- "name": "FLUX.1 Dev",
51
- "num_inference_steps": 28,
52
- "num_inference_steps_min": 10,
53
- "num_inference_steps_max": 50,
54
- "guidance_scale": 3.0,
55
- "guidance_scale_min": 1.5,
56
- "guidance_scale_max": 5.0,
57
- "parameters": ["seed", "width", "height", "steps", "guidance", "prompt_upsampling"],
58
- "kwargs": {"safety_tolerance": 6},
59
- },
60
- FLUX_1_1_PRO_FAL={
61
- "name": "FLUX1.1 Pro",
62
- "parameters": ["seed", "image_size"],
63
- "kwargs": {
64
- "num_images": 1,
65
- "sync_mode": False,
66
- "safety_tolerance": 6,
67
- "enable_safety_checker": False,
68
- },
69
- },
70
- FLUX_PRO_FAL={
71
- "name": "FLUX.1 Pro",
72
- "guidance_scale": 2.5,
73
- "guidance_scale_min": 1.5,
74
- "guidance_scale_max": 5.0,
75
- "num_inference_steps": 40,
76
- "num_inference_steps_min": 10,
77
- "num_inference_steps_max": 50,
78
- "parameters": ["seed", "image_size", "num_inference_steps", "guidance_scale"],
79
- "kwargs": {"num_images": 1, "sync_mode": False, "safety_tolerance": 6},
80
- },
81
- FLUX_DEV_FAL={
82
- "name": "FLUX.1 Dev",
83
- "num_inference_steps": 28,
84
- "num_inference_steps_min": 10,
85
- "num_inference_steps_max": 50,
86
- "guidance_scale": 3.0,
87
- "guidance_scale_min": 1.5,
88
- "guidance_scale_max": 5.0,
89
- "parameters": ["seed", "image_size", "num_inference_steps", "guidance_scale"],
90
- "kwargs": {"num_images": 1, "sync_mode": False, "safety_tolerance": 6},
91
- },
92
- FLUX_SCHNELL_FAL={
93
- "name": "FLUX.1 Schnell",
94
- "num_inference_steps": 4,
95
- "num_inference_steps_min": 1,
96
- "num_inference_steps_max": 12,
97
- "parameters": ["seed", "image_size", "num_inference_steps"],
98
- "kwargs": {"num_images": 1, "sync_mode": False, "enable_safety_checker": False},
99
- },
100
- FLUX_DEV_HF={
101
- "name": "FLUX.1 Dev",
102
- "num_inference_steps": 28,
103
- "num_inference_steps_min": 10,
104
- "num_inference_steps_max": 50,
105
- "guidance_scale": 3.0,
106
- "guidance_scale_min": 1.5,
107
- "guidance_scale_max": 5.0,
108
- "parameters": ["width", "height", "guidance_scale", "num_inference_steps"],
109
- "kwargs": {"max_sequence_length": 512},
110
- },
111
- FLUX_SCHNELL_HF={
112
- "name": "FLUX.1 Schnell",
113
- "num_inference_steps": 4,
114
- "num_inference_steps_min": 1,
115
- "num_inference_steps_max": 12,
116
- "parameters": ["width", "height", "num_inference_steps"],
117
- "kwargs": {"guidance_scale": 0.0, "max_sequence_length": 256},
118
- },
119
- FLUX_SCHNELL_FREE_TOGETHER={
120
- "name": "FLUX.1 Schnell Free",
121
- "num_inference_steps": 4,
122
- "num_inference_steps_min": 1,
123
- "num_inference_steps_max": 12,
124
- "parameters": ["model", "seed", "width", "height", "steps"],
125
- "kwargs": {"n": 1},
126
- },
127
- FOOOCUS={
128
- "name": "Fooocus",
129
- "guidance_scale": 4.0,
130
- "guidance_scale_min": 1.0,
131
- "guidance_scale_max": 10.0,
132
- "parameters": ["seed", "negative_prompt", "aspect_ratio", "guidance_scale"],
133
- "kwargs": {
134
- "num_images": 1,
135
- "sync_mode": True,
136
- "enable_safety_checker": False,
137
- "output_format": "png",
138
- "sharpness": 2,
139
- "styles": ["Fooocus Enhance", "Fooocus V2", "Fooocus Sharp"],
140
- "performance": "Quality",
141
- },
142
- },
143
- KOLORS={
144
- "name": "Kolors",
145
- "guidance_scale": 5.0,
146
- "guidance_scale_min": 1.0,
147
- "guidance_scale_max": 10.0,
148
- "num_inference_steps": 50,
149
- "num_inference_steps_min": 10,
150
- "num_inference_steps_max": 50,
151
- "parameters": [
152
- "seed",
153
- "negative_prompt",
154
- "image_size",
155
- "guidance_scale",
156
- "num_inference_steps",
157
- ],
158
- "kwargs": {
159
- "num_images": 1,
160
- "sync_mode": True,
161
- "enable_safety_checker": False,
162
- "scheduler": "EulerDiscreteScheduler",
163
- },
164
- },
165
- STABLE_DIFFUSION_3={
166
- "name": "SD3",
167
- "guidance_scale": 5.0,
168
- "guidance_scale_min": 1.0,
169
- "guidance_scale_max": 10.0,
170
- "num_inference_steps": 28,
171
- "num_inference_steps_min": 10,
172
- "num_inference_steps_max": 50,
173
- "parameters": [
174
- "seed",
175
- "negative_prompt",
176
- "image_size",
177
- "guidance_scale",
178
- "num_inference_steps",
179
- "prompt_expansion",
180
- ],
181
- "kwargs": {"num_images": 1, "sync_mode": True, "enable_safety_checker": False},
182
- },
183
- STABLE_DIFFUSION_XL={
184
- "name": "SDXL",
185
- "guidance_scale": 7.0,
186
- "guidance_scale_min": 1.0,
187
- "guidance_scale_max": 10.0,
188
- "num_inference_steps": 40,
189
- "num_inference_steps_min": 10,
190
- "num_inference_steps_max": 50,
191
- "parameters": ["seed", "negative_prompt", "width", "height", "guidance_scale", "num_inference_steps"],
192
- },
193
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pages/1_💬_Text_Generation.py CHANGED
@@ -3,7 +3,7 @@ from datetime import datetime
3
 
4
  import streamlit as st
5
 
6
- from lib import ServicePresets, config, txt2txt_generate
7
 
8
  SERVICE_SESSION = {
9
  "Hugging Face": "api_key_hugging_face",
@@ -74,9 +74,9 @@ system = st.sidebar.text_area(
74
 
75
  # build parameters from preset
76
  parameters = {}
77
- service_key = service.upper().replace(" ", "_")
78
- preset = getattr(ServicePresets, service_key, {})
79
- for param in preset["parameters"]:
80
  if param == "max_tokens":
81
  parameters[param] = st.sidebar.slider(
82
  "Max Tokens",
@@ -101,9 +101,9 @@ for param in preset["parameters"]:
101
  parameters[param] = st.sidebar.slider(
102
  "Frequency Penalty",
103
  step=0.1,
104
- value=preset["frequency_penalty"],
105
- min_value=preset["frequency_penalty_min"],
106
- max_value=preset["frequency_penalty_max"],
107
  disabled=st.session_state.running,
108
  help="Penalize new tokens based on their existing frequency in the text (default: 0.0)",
109
  )
 
3
 
4
  import streamlit as st
5
 
6
+ from lib import config, preset, txt2txt_generate
7
 
8
  SERVICE_SESSION = {
9
  "Hugging Face": "api_key_hugging_face",
 
74
 
75
  # build parameters from preset
76
  parameters = {}
77
+ service_key = service.lower().replace(" ", "_")
78
+ service_preset = getattr(preset.txt2txt, service_key)
79
+ for param in service_preset.parameters:
80
  if param == "max_tokens":
81
  parameters[param] = st.sidebar.slider(
82
  "Max Tokens",
 
101
  parameters[param] = st.sidebar.slider(
102
  "Frequency Penalty",
103
  step=0.1,
104
+ value=service_preset.frequency_penalty,
105
+ min_value=service_preset.frequency_penalty_min,
106
+ max_value=service_preset.frequency_penalty_max,
107
  disabled=st.session_state.running,
108
  help="Penalize new tokens based on their existing frequency in the text (default: 0.0)",
109
  )
pages/2_🎨_Text_to_Image.py CHANGED
@@ -3,7 +3,7 @@ from datetime import datetime
3
 
4
  import streamlit as st
5
 
6
- from lib import ModelPresets, config, txt2img_generate
7
 
8
  # The token name is the service in lower_snake_case
9
  SERVICE_SESSION = {
@@ -24,24 +24,24 @@ SESSION_TOKEN = {
24
  # Model IDs in lib/config.py
25
  PRESET_MODEL = {
26
  # bfl
27
- "flux-pro-1.1": ModelPresets.FLUX_1_1_PRO_BFL,
28
- "flux-pro": ModelPresets.FLUX_PRO_BFL,
29
- "flux-dev": ModelPresets.FLUX_DEV_BFL,
30
  # fal
31
- "fal-ai/aura-flow": ModelPresets.AURA_FLOW,
32
- "fal-ai/flux/dev": ModelPresets.FLUX_DEV_FAL,
33
- "fal-ai/flux/schnell": ModelPresets.FLUX_SCHNELL_FAL,
34
- "fal-ai/flux-pro": ModelPresets.FLUX_PRO_FAL,
35
- "fal-ai/flux-pro/v1.1": ModelPresets.FLUX_1_1_PRO_FAL,
36
- "fal-ai/fooocus": ModelPresets.FOOOCUS,
37
- "fal-ai/kolors": ModelPresets.KOLORS,
38
- "fal-ai/stable-diffusion-v3-medium": ModelPresets.STABLE_DIFFUSION_3,
39
  # hf
40
- "black-forest-labs/flux.1-dev": ModelPresets.FLUX_DEV_HF,
41
- "black-forest-labs/flux.1-schnell": ModelPresets.FLUX_SCHNELL_HF,
42
- "stabilityai/stable-diffusion-xl-base-1.0": ModelPresets.STABLE_DIFFUSION_XL,
43
  # together
44
- "black-forest-labs/FLUX.1-schnell-Free": ModelPresets.FLUX_SCHNELL_FREE_TOGETHER,
45
  }
46
 
47
  st.set_page_config(
@@ -81,7 +81,8 @@ service = st.sidebar.selectbox(
81
  index=2, # Hugging Face
82
  )
83
 
84
- # Disable API key input and hide value if set by environment variable; handle empty string value later.
 
85
  for display_name, session_key in SERVICE_SESSION.items():
86
  if service == display_name:
87
  st.session_state[session_key] = st.sidebar.text_input(
@@ -108,7 +109,7 @@ st.html("""
108
  # Build parameters from preset by rendering the appropriate input widgets
109
  parameters = {}
110
  preset = PRESET_MODEL[model]
111
- for param in preset["parameters"]:
112
  if param == "model":
113
  parameters[param] = model
114
  if param == "seed":
@@ -160,18 +161,18 @@ for param in preset["parameters"]:
160
  if param in ["guidance_scale", "guidance"]:
161
  parameters[param] = st.sidebar.slider(
162
  "Guidance Scale",
163
- preset["guidance_scale_min"],
164
- preset["guidance_scale_max"],
165
- preset["guidance_scale"],
166
  0.1,
167
  disabled=st.session_state.running,
168
  )
169
  if param in ["num_inference_steps", "steps"]:
170
  parameters[param] = st.sidebar.slider(
171
  "Inference Steps",
172
- preset["num_inference_steps_min"],
173
- preset["num_inference_steps_max"],
174
- preset["num_inference_steps"],
175
  1,
176
  disabled=st.session_state.running,
177
  )
@@ -279,16 +280,15 @@ if prompt := st.chat_input(
279
 
280
  with st.chat_message("assistant"):
281
  with st.spinner("Running..."):
282
- if preset.get("kwargs") is not None:
283
- parameters.update(preset["kwargs"])
284
  session_key = f"api_key_{service.lower().replace(' ', '_')}"
285
  api_key = st.session_state[session_key] or SESSION_TOKEN[session_key]
286
  image = txt2img_generate(api_key, service, model, prompt, parameters)
287
  st.session_state.running = False
288
 
289
- model_name = PRESET_MODEL[model]["name"]
290
  st.session_state.txt2img_messages.append(
291
- {"role": "user", "content": prompt, "parameters": parameters, "model": model_name}
292
  )
293
  st.session_state.txt2img_messages.append({"role": "assistant", "content": image})
294
  st.rerun()
 
3
 
4
  import streamlit as st
5
 
6
+ from lib import config, preset, txt2img_generate
7
 
8
  # The token name is the service in lower_snake_case
9
  SERVICE_SESSION = {
 
24
  # Model IDs in lib/config.py
25
  PRESET_MODEL = {
26
  # bfl
27
+ "flux-pro-1.1": preset.txt2img.flux_1_1_pro_bfl,
28
+ "flux-pro": preset.txt2img.flux_pro_bfl,
29
+ "flux-dev": preset.txt2img.flux_dev_bfl,
30
  # fal
31
+ "fal-ai/aura-flow": preset.txt2img.aura_flow,
32
+ "fal-ai/flux/dev": preset.txt2img.flux_dev_fal,
33
+ "fal-ai/flux/schnell": preset.txt2img.flux_schnell_fal,
34
+ "fal-ai/flux-pro": preset.txt2img.flux_pro_fal,
35
+ "fal-ai/flux-pro/v1.1": preset.txt2img.flux_1_1_pro_fal,
36
+ "fal-ai/fooocus": preset.txt2img.fooocus,
37
+ "fal-ai/kolors": preset.txt2img.kolors,
38
+ "fal-ai/stable-diffusion-v3-medium": preset.txt2img.stable_diffusion_3,
39
  # hf
40
+ "black-forest-labs/flux.1-dev": preset.txt2img.flux_dev_hf,
41
+ "black-forest-labs/flux.1-schnell": preset.txt2img.flux_schnell_hf,
42
+ "stabilityai/stable-diffusion-xl-base-1.0": preset.txt2img.stable_diffusion_xl,
43
  # together
44
+ "black-forest-labs/FLUX.1-schnell-Free": preset.txt2img.flux_schnell_free_together,
45
  }
46
 
47
  st.set_page_config(
 
81
  index=2, # Hugging Face
82
  )
83
 
84
+ # Show the API key input for the selected service.
85
+ # Disable and hide value if set by environment variable; handle empty string value later.
86
  for display_name, session_key in SERVICE_SESSION.items():
87
  if service == display_name:
88
  st.session_state[session_key] = st.sidebar.text_input(
 
109
  # Build parameters from preset by rendering the appropriate input widgets
110
  parameters = {}
111
  preset = PRESET_MODEL[model]
112
+ for param in preset.parameters:
113
  if param == "model":
114
  parameters[param] = model
115
  if param == "seed":
 
161
  if param in ["guidance_scale", "guidance"]:
162
  parameters[param] = st.sidebar.slider(
163
  "Guidance Scale",
164
+ preset.guidance_scale_min,
165
+ preset.guidance_scale_max,
166
+ preset.guidance_scale,
167
  0.1,
168
  disabled=st.session_state.running,
169
  )
170
  if param in ["num_inference_steps", "steps"]:
171
  parameters[param] = st.sidebar.slider(
172
  "Inference Steps",
173
+ preset.num_inference_steps_min,
174
+ preset.num_inference_steps_max,
175
+ preset.num_inference_steps,
176
  1,
177
  disabled=st.session_state.running,
178
  )
 
280
 
281
  with st.chat_message("assistant"):
282
  with st.spinner("Running..."):
283
+ if preset.kwargs:
284
+ parameters.update(preset.kwargs)
285
  session_key = f"api_key_{service.lower().replace(' ', '_')}"
286
  api_key = st.session_state[session_key] or SESSION_TOKEN[session_key]
287
  image = txt2img_generate(api_key, service, model, prompt, parameters)
288
  st.session_state.running = False
289
 
 
290
  st.session_state.txt2img_messages.append(
291
+ {"role": "user", "content": prompt, "parameters": parameters, "model": PRESET_MODEL[model].name}
292
  )
293
  st.session_state.txt2img_messages.append({"role": "assistant", "content": image})
294
  st.rerun()