adamelliotfields commited on
Commit
c659a88
·
verified ·
1 Parent(s): 7f5047d

Refactor preset

Browse files
0_🏠_Home.py CHANGED
@@ -4,12 +4,12 @@ from lib import config
4
 
5
  st.set_page_config(
6
  page_title=config.title,
7
- page_icon=config.icon,
8
  layout=config.layout,
9
  )
10
 
11
  # sidebar
12
- st.logo("logo.png")
13
 
14
  # title
15
  st.html("""
@@ -39,6 +39,7 @@ st.html("""
39
  <h1>API Inference</h1>
40
  <span class="pro-badge">PRO</span>
41
  </div>
 
42
  """)
43
 
44
  st.markdown("## Tasks")
@@ -58,7 +59,7 @@ st.markdown("""
58
  st.markdown("""
59
  ## Usage
60
 
61
- Choose a task from the sidebar. Enter your API key for the service you want to use. Refresh your browser to remove it.
62
 
63
  I recommend [duplicating this space](https://huggingface.co/spaces/adamelliotfields/api-inference?duplicate=true) **privately** and persisting your keys as secrets. See [`README.md`](https://huggingface.co/spaces/adamelliotfields/api-inference/blob/main/README.md).
64
  """)
 
4
 
5
  st.set_page_config(
6
  page_title=config.title,
7
+ page_icon=config.logo,
8
  layout=config.layout,
9
  )
10
 
11
  # sidebar
12
+ st.logo(config.logo)
13
 
14
  # title
15
  st.html("""
 
39
  <h1>API Inference</h1>
40
  <span class="pro-badge">PRO</span>
41
  </div>
42
+ <p>Explore popular AI endpoints in one place.</p>
43
  """)
44
 
45
  st.markdown("## Tasks")
 
59
  st.markdown("""
60
  ## Usage
61
 
62
+ Choose a task. Select a service. Enter your API key (refresh browser to clear).
63
 
64
  I recommend [duplicating this space](https://huggingface.co/spaces/adamelliotfields/api-inference?duplicate=true) **privately** and persisting your keys as secrets. See [`README.md`](https://huggingface.co/spaces/adamelliotfields/api-inference/blob/main/README.md).
65
  """)
lib/__init__.py CHANGED
@@ -1,8 +1,10 @@
1
  from .api import txt2img_generate, txt2txt_generate
2
  from .config import config
3
- from .preset import preset
4
 
5
  __all__ = [
 
 
6
  "config",
7
  "preset",
8
  "txt2img_generate",
 
1
  from .api import txt2img_generate, txt2txt_generate
2
  from .config import config
3
+ from .preset import Txt2ImgPreset, Txt2TxtPreset, preset
4
 
5
  __all__ = [
6
+ "Txt2ImgPreset",
7
+ "Txt2TxtPreset",
8
  "config",
9
  "preset",
10
  "txt2img_generate",
lib/api.py CHANGED
@@ -11,8 +11,8 @@ from .config import config
11
 
12
 
13
  def txt2txt_generate(api_key, service, model, parameters, **kwargs):
14
- base_url = config.services[service]
15
- if service == "Hugging Face":
16
  base_url = f"{base_url}/{model}/v1"
17
  client = OpenAI(api_key=api_key, base_url=base_url)
18
 
@@ -29,42 +29,32 @@ def txt2txt_generate(api_key, service, model, parameters, **kwargs):
29
 
30
  def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
31
  headers = {}
32
- if service == "Black Forest Labs":
 
 
33
  headers["x-key"] = api_key
 
34
 
35
- if service == "Fal":
36
  headers["Authorization"] = f"Key {api_key}"
 
37
 
38
- if service == "Hugging Face":
39
  headers["Authorization"] = f"Bearer {api_key}"
40
  headers["X-Wait-For-Model"] = "true"
41
  headers["X-Use-Cache"] = "false"
42
-
43
- if service == "Together":
44
- headers["Authorization"] = f"Bearer {api_key}"
45
-
46
- json = {}
47
- if service == "Black Forest Labs":
48
- json = {**parameters, **kwargs}
49
- json["prompt"] = inputs
50
-
51
- if service == "Fal":
52
- json = {**parameters, **kwargs}
53
- json["prompt"] = inputs
54
-
55
- if service == "Hugging Face":
56
  json = {
57
  "inputs": inputs,
58
  "parameters": {**parameters, **kwargs},
59
  }
60
 
61
- if service == "Together":
62
- json = {**parameters, **kwargs}
63
  json["prompt"] = inputs
64
 
65
- base_url = config.services[service]
66
 
67
- if service not in ["Together"]:
68
  base_url = f"{base_url}/{model}"
69
 
70
  try:
@@ -72,9 +62,9 @@ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
72
  if response.status_code // 100 == 2: # 2xx
73
  # BFL is async so we need to poll for result
74
  # https://api.bfl.ml/docs
75
- if service == "Black Forest Labs":
76
  id = response.json()["id"]
77
- url = f"{config.services[service]}/get_result?id={id}"
78
 
79
  retries = 0
80
  while retries < config.txt2img.timeout:
@@ -95,7 +85,7 @@ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
95
 
96
  return "Error: API timeout"
97
 
98
- if service == "Fal":
99
  # Sync mode means wait for image base64 string instead of CDN link
100
  if parameters.get("sync_mode", True):
101
  bytes = base64.b64decode(response.json()["images"][0]["url"].split(",")[-1])
@@ -105,10 +95,10 @@ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
105
  image = httpx.get(url, headers=headers, timeout=config.txt2img.timeout)
106
  return Image.open(io.BytesIO(image.content))
107
 
108
- if service == "Hugging Face":
109
  return Image.open(io.BytesIO(response.content))
110
 
111
- if service == "Together":
112
  url = response.json()["data"][0]["url"]
113
  image = httpx.get(url, headers=headers, timeout=config.txt2img.timeout)
114
  return Image.open(io.BytesIO(image.content))
 
11
 
12
 
13
  def txt2txt_generate(api_key, service, model, parameters, **kwargs):
14
+ base_url = config.service[service].url
15
+ if service == "hf":
16
  base_url = f"{base_url}/{model}/v1"
17
  client = OpenAI(api_key=api_key, base_url=base_url)
18
 
 
29
 
30
  def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
31
  headers = {}
32
+ json = {**parameters, **kwargs}
33
+
34
+ if service == "bfl":
35
  headers["x-key"] = api_key
36
+ json["prompt"] = inputs
37
 
38
+ if service == "fal":
39
  headers["Authorization"] = f"Key {api_key}"
40
+ json["prompt"] = inputs
41
 
42
+ if service == "hf":
43
  headers["Authorization"] = f"Bearer {api_key}"
44
  headers["X-Wait-For-Model"] = "true"
45
  headers["X-Use-Cache"] = "false"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  json = {
47
  "inputs": inputs,
48
  "parameters": {**parameters, **kwargs},
49
  }
50
 
51
+ if service == "together":
52
+ headers["Authorization"] = f"Bearer {api_key}"
53
  json["prompt"] = inputs
54
 
55
+ base_url = config.service[service].url
56
 
57
+ if service not in ["together"]:
58
  base_url = f"{base_url}/{model}"
59
 
60
  try:
 
62
  if response.status_code // 100 == 2: # 2xx
63
  # BFL is async so we need to poll for result
64
  # https://api.bfl.ml/docs
65
+ if service == "bfl":
66
  id = response.json()["id"]
67
+ url = f"{config.service[service].url}/get_result?id={id}"
68
 
69
  retries = 0
70
  while retries < config.txt2img.timeout:
 
85
 
86
  return "Error: API timeout"
87
 
88
+ if service == "fal":
89
  # Sync mode means wait for image base64 string instead of CDN link
90
  if parameters.get("sync_mode", True):
91
  bytes = base64.b64decode(response.json()["images"][0]["url"].split(",")[-1])
 
95
  image = httpx.get(url, headers=headers, timeout=config.txt2img.timeout)
96
  return Image.open(io.BytesIO(image.content))
97
 
98
+ if service == "hf":
99
  return Image.open(io.BytesIO(response.content))
100
 
101
+ if service == "together":
102
  url = response.json()["data"][0]["url"]
103
  image = httpx.get(url, headers=headers, timeout=config.txt2img.timeout)
104
  return Image.open(io.BytesIO(image.content))
lib/config.py CHANGED
@@ -1,31 +1,22 @@
 
1
  from dataclasses import dataclass
2
- from typing import Dict, List
3
 
4
- from .preset import preset
5
 
6
-
7
- def txt2img_models_from_presets(presets):
8
- models = {}
9
- for p in presets:
10
- service = p.service
11
- model_id = p.model_id
12
- if service not in models:
13
- models[service] = []
14
- models[service].append(model_id)
15
- return models
16
 
17
 
18
  @dataclass
19
  class Txt2TxtConfig:
20
  default_system: str
21
- default_model: Dict[str, int]
22
- models: Dict[str, List[str]]
23
 
24
 
25
  @dataclass
26
  class Txt2ImgConfig:
27
- default_model: Dict[str, int]
28
- models: Dict[str, List[str]]
29
  hidden_parameters: List[str]
30
  negative_prompt: str
31
  default_image_size: str
@@ -38,37 +29,50 @@ class Txt2ImgConfig:
38
  @dataclass
39
  class Config:
40
  title: str
41
- icon: str
42
  layout: str
43
- services: Dict[str, str]
 
44
  txt2img: Txt2ImgConfig
45
  txt2txt: Txt2TxtConfig
46
 
47
 
48
- # TODO: API keys should be with services (make a dataclass)
49
  config = Config(
50
  title="API Inference",
51
- icon="⚡",
52
  layout="wide",
53
- services={
54
- "Black Forest Labs": "https://api.bfl.ml/v1",
55
- "Fal": "https://fal.run",
56
- "Hugging Face": "https://api-inference.huggingface.co/models",
57
- "Perplexity": "https://api.perplexity.ai",
58
- "Together": "https://api.together.xyz/v1/images/generations",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  },
60
  txt2img=Txt2ImgConfig(
61
- default_model={
62
- "Black Forest Labs": 2,
63
- "Fal": 0,
64
- "Hugging Face": 2,
65
- "Together": 0,
66
- },
67
- models=txt2img_models_from_presets(preset.txt2img.presets),
68
  hidden_parameters=[
69
  # Sent to API but not shown in generation parameters accordion
70
  "enable_safety_checker",
71
  "max_sequence_length",
 
72
  "num_images",
73
  "output_format",
74
  "performance",
@@ -115,25 +119,5 @@ config = Config(
115
  ),
116
  txt2txt=Txt2TxtConfig(
117
  default_system="You are a helpful assistant. Be precise and concise.",
118
- default_model={
119
- "Hugging Face": 4,
120
- "Perplexity": 3,
121
- },
122
- models={
123
- "Hugging Face": [
124
- "codellama/codellama-34b-instruct-hf",
125
- "meta-llama/llama-2-13b-chat-hf",
126
- "meta-llama/meta-llama-3.1-405b-instruct-fp8",
127
- "mistralai/mistral-7b-instruct-v0.2",
128
- "nousresearch/nous-hermes-2-mixtral-8x7b-dpo",
129
- ],
130
- "Perplexity": [
131
- "llama-3.1-sonar-small-128k-chat",
132
- "llama-3.1-sonar-large-128k-chat",
133
- "llama-3.1-sonar-small-128k-online",
134
- "llama-3.1-sonar-large-128k-online",
135
- "llama-3.1-sonar-huge-128k-online",
136
- ],
137
- },
138
  ),
139
  )
 
1
+ import os
2
  from dataclasses import dataclass
3
+ from typing import Dict, List, Optional
4
 
 
5
 
6
+ @dataclass
7
+ class ServiceConfig:
8
+ name: str
9
+ url: str
10
+ api_key: Optional[str] = None
 
 
 
 
 
11
 
12
 
13
  @dataclass
14
  class Txt2TxtConfig:
15
  default_system: str
 
 
16
 
17
 
18
  @dataclass
19
  class Txt2ImgConfig:
 
 
20
  hidden_parameters: List[str]
21
  negative_prompt: str
22
  default_image_size: str
 
29
  @dataclass
30
  class Config:
31
  title: str
 
32
  layout: str
33
+ logo: str
34
+ service: Dict[str, ServiceConfig]
35
  txt2img: Txt2ImgConfig
36
  txt2txt: Txt2TxtConfig
37
 
38
 
 
39
  config = Config(
40
  title="API Inference",
 
41
  layout="wide",
42
+ logo="logo.png",
43
+ service={
44
+ "bfl": ServiceConfig(
45
+ "Black Forest Labs",
46
+ "https://api.bfl.ml/v1",
47
+ os.environ.get("BFL_API_KEY"),
48
+ ),
49
+ "fal": ServiceConfig(
50
+ "Fal",
51
+ "https://fal.run",
52
+ os.environ.get("FAL_KEY"),
53
+ ),
54
+ "hf": ServiceConfig(
55
+ "Hugging Face",
56
+ "https://api-inference.huggingface.co/models",
57
+ os.environ.get("HF_TOKEN"),
58
+ ),
59
+ "pplx": ServiceConfig(
60
+ "Perplexity",
61
+ "https://api.perplexity.ai",
62
+ os.environ.get("PPLX_API_KEY"),
63
+ ),
64
+ "together": ServiceConfig(
65
+ "Together",
66
+ "https://api.together.xyz/v1/images/generations",
67
+ os.environ.get("TOGETHER_API_KEY"),
68
+ ),
69
  },
70
  txt2img=Txt2ImgConfig(
 
 
 
 
 
 
 
71
  hidden_parameters=[
72
  # Sent to API but not shown in generation parameters accordion
73
  "enable_safety_checker",
74
  "max_sequence_length",
75
+ "n",
76
  "num_images",
77
  "output_format",
78
  "performance",
 
119
  ),
120
  txt2txt=Txt2TxtConfig(
121
  default_system="You are a helpful assistant. Be precise and concise.",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  ),
123
  )
lib/preset.py CHANGED
@@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Union
4
 
5
  @dataclass
6
  class Txt2TxtPreset:
 
7
  frequency_penalty: float
8
  frequency_penalty_min: float
9
  frequency_penalty_max: float
@@ -12,10 +13,7 @@ class Txt2TxtPreset:
12
 
13
  @dataclass
14
  class Txt2ImgPreset:
15
- # FLUX1.1 has no scale or steps
16
  name: str
17
- service: str
18
- model_id: str
19
  guidance_scale: Optional[float] = None
20
  guidance_scale_min: Optional[float] = None
21
  guidance_scale_max: Optional[float] = None
@@ -27,66 +25,55 @@ class Txt2ImgPreset:
27
 
28
 
29
  @dataclass
30
- class Txt2TxtPresets:
31
- hugging_face: Txt2TxtPreset
32
- perplexity: Txt2TxtPreset
33
-
34
-
35
- @dataclass
36
- class Txt2ImgPresets:
37
- presets: List[Txt2ImgPreset] = field(default_factory=list)
38
 
39
- def __iter__(self):
40
- return iter(self.presets)
41
 
 
 
 
 
 
 
42
 
43
- @dataclass
44
- class Preset:
45
- txt2txt: Txt2TxtPresets
46
- txt2img: Txt2ImgPresets
 
 
47
 
48
 
49
  preset = Preset(
50
- txt2txt=Txt2TxtPresets(
51
- hugging_face=Txt2TxtPreset(
52
- frequency_penalty=0.0,
53
- frequency_penalty_min=-2.0,
54
- frequency_penalty_max=2.0,
55
- parameters=["max_tokens", "temperature", "frequency_penalty", "seed"],
56
- ),
57
- perplexity=Txt2TxtPreset(
58
- frequency_penalty=1.0,
59
- frequency_penalty_min=1.0,
60
- frequency_penalty_max=2.0,
61
- parameters=["max_tokens", "temperature", "frequency_penalty"],
62
- ),
63
- ),
64
- txt2img=Txt2ImgPresets(
65
- presets=[
66
- Txt2ImgPreset(
67
- "AuraFlow",
68
- "Fal",
69
- "fal-ai/aura-flow",
70
- guidance_scale=3.5,
71
- guidance_scale_min=1.0,
72
- guidance_scale_max=10.0,
73
- num_inference_steps=28,
74
- num_inference_steps_min=10,
75
- num_inference_steps_max=50,
76
- parameters=["seed", "num_inference_steps", "guidance_scale", "expand_prompt"],
77
- kwargs={"num_images": 1, "sync_mode": False},
78
  ),
79
- Txt2ImgPreset(
 
 
 
 
 
 
 
 
 
 
 
80
  "FLUX1.1 Pro",
81
- "Black Forest Labs",
82
- "flux-pro-1.1",
83
  parameters=["seed", "width", "height", "prompt_upsampling"],
84
  kwargs={"safety_tolerance": 6},
85
  ),
86
- Txt2ImgPreset(
87
  "FLUX.1 Pro",
88
- "Black Forest Labs",
89
- "flux-pro",
90
  guidance_scale=2.5,
91
  guidance_scale_min=1.5,
92
  guidance_scale_max=5.0,
@@ -96,10 +83,8 @@ preset = Preset(
96
  parameters=["seed", "width", "height", "steps", "guidance", "prompt_upsampling"],
97
  kwargs={"safety_tolerance": 6, "interval": 1},
98
  ),
99
- Txt2ImgPreset(
100
  "FLUX.1 Dev",
101
- "Black Forest Labs",
102
- "flux-dev",
103
  num_inference_steps=28,
104
  num_inference_steps_min=10,
105
  num_inference_steps_max=50,
@@ -109,10 +94,21 @@ preset = Preset(
109
  parameters=["seed", "width", "height", "steps", "guidance", "prompt_upsampling"],
110
  kwargs={"safety_tolerance": 6},
111
  ),
112
- Txt2ImgPreset(
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  "FLUX1.1 Pro",
114
- "Fal",
115
- "fal-ai/flux-pro/v1.1",
116
  parameters=["seed", "image_size"],
117
  kwargs={
118
  "num_images": 1,
@@ -121,10 +117,8 @@ preset = Preset(
121
  "enable_safety_checker": False,
122
  },
123
  ),
124
- Txt2ImgPreset(
125
  "FLUX.1 Pro",
126
- "Fal",
127
- "fal-ai/flux-pro",
128
  guidance_scale=2.5,
129
  guidance_scale_min=1.5,
130
  guidance_scale_max=5.0,
@@ -134,10 +128,8 @@ preset = Preset(
134
  parameters=["seed", "image_size", "num_inference_steps", "guidance_scale"],
135
  kwargs={"num_images": 1, "sync_mode": False, "safety_tolerance": 6},
136
  ),
137
- Txt2ImgPreset(
138
  "FLUX.1 Dev",
139
- "Fal",
140
- "fal-ai/flux/dev",
141
  num_inference_steps=28,
142
  num_inference_steps_min=10,
143
  num_inference_steps_max=50,
@@ -147,53 +139,16 @@ preset = Preset(
147
  parameters=["seed", "image_size", "num_inference_steps", "guidance_scale"],
148
  kwargs={"num_images": 1, "sync_mode": False, "safety_tolerance": 6},
149
  ),
150
- Txt2ImgPreset(
151
  "FLUX.1 Schnell",
152
- "Fal",
153
- "fal-ai/flux/schnell",
154
  num_inference_steps=4,
155
  num_inference_steps_min=1,
156
  num_inference_steps_max=12,
157
  parameters=["seed", "image_size", "num_inference_steps"],
158
  kwargs={"num_images": 1, "sync_mode": False, "enable_safety_checker": False},
159
  ),
160
- Txt2ImgPreset(
161
- "FLUX.1 Dev",
162
- "Hugging Face",
163
- "black-forest-labs/flux.1-dev",
164
- num_inference_steps=28,
165
- num_inference_steps_min=10,
166
- num_inference_steps_max=50,
167
- guidance_scale=3.0,
168
- guidance_scale_min=1.5,
169
- guidance_scale_max=5.0,
170
- parameters=["width", "height", "guidance_scale", "num_inference_steps"],
171
- kwargs={"max_sequence_length": 512},
172
- ),
173
- Txt2ImgPreset(
174
- "FLUX.1 Schnell",
175
- "Hugging Face",
176
- "black-forest-labs/flux.1-schnell",
177
- num_inference_steps=4,
178
- num_inference_steps_min=1,
179
- num_inference_steps_max=12,
180
- parameters=["width", "height", "num_inference_steps"],
181
- kwargs={"guidance_scale": 0.0, "max_sequence_length": 256},
182
- ),
183
- Txt2ImgPreset(
184
- "FLUX.1 Schnell Free",
185
- "Together",
186
- "black-forest-labs/FLUX.1-schnell-Free",
187
- num_inference_steps=4,
188
- num_inference_steps_min=1,
189
- num_inference_steps_max=12,
190
- parameters=["model", "seed", "width", "height", "steps"],
191
- kwargs={"n": 1},
192
- ),
193
- Txt2ImgPreset(
194
  "Fooocus",
195
- "Fal",
196
- "fal-ai/fooocus",
197
  guidance_scale=4.0,
198
  guidance_scale_min=1.0,
199
  guidance_scale_max=10.0,
@@ -208,10 +163,8 @@ preset = Preset(
208
  "performance": "Quality",
209
  },
210
  ),
211
- Txt2ImgPreset(
212
  "Kolors",
213
- "Fal",
214
- "fal-ai/kolors",
215
  guidance_scale=5.0,
216
  guidance_scale_min=1.0,
217
  guidance_scale_max=10.0,
@@ -226,10 +179,8 @@ preset = Preset(
226
  "scheduler": "EulerDiscreteScheduler",
227
  },
228
  ),
229
- Txt2ImgPreset(
230
  "SD3",
231
- "Fal",
232
- "fal-ai/stable-diffusion-v3-medium",
233
  guidance_scale=5.0,
234
  guidance_scale_min=1.0,
235
  guidance_scale_max=10.0,
@@ -246,10 +197,29 @@ preset = Preset(
246
  ],
247
  kwargs={"num_images": 1, "sync_mode": True, "enable_safety_checker": False},
248
  ),
249
- Txt2ImgPreset(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  "SDXL",
251
- "Hugging Face",
252
- "stabilityai/stable-diffusion-xl-base-1.0",
253
  guidance_scale=7.0,
254
  guidance_scale_min=1.0,
255
  guidance_scale_max=10.0,
@@ -265,6 +235,16 @@ preset = Preset(
265
  "num_inference_steps",
266
  ],
267
  ),
268
- ]
269
- ),
 
 
 
 
 
 
 
 
 
 
270
  )
 
4
 
5
  @dataclass
6
  class Txt2TxtPreset:
7
+ name: str
8
  frequency_penalty: float
9
  frequency_penalty_min: float
10
  frequency_penalty_max: float
 
13
 
14
  @dataclass
15
  class Txt2ImgPreset:
 
16
  name: str
 
 
17
  guidance_scale: Optional[float] = None
18
  guidance_scale_min: Optional[float] = None
19
  guidance_scale_max: Optional[float] = None
 
25
 
26
 
27
  @dataclass
28
+ class Preset:
29
+ txt2txt: Dict[str, Txt2TxtPreset]
30
+ txt2img: Dict[str, Txt2ImgPreset]
 
 
 
 
 
31
 
 
 
32
 
33
+ hf_txt2txt_kwargs = {
34
+ "frequency_penalty": 0.0,
35
+ "frequency_penalty_min": -2.0,
36
+ "frequency_penalty_max": 2.0,
37
+ "parameters": ["max_tokens", "temperature", "frequency_penalty", "seed"],
38
+ }
39
 
40
+ pplx_txt2txt_kwargs = {
41
+ "frequency_penalty": 1.0,
42
+ "frequency_penalty_min": 1.0,
43
+ "frequency_penalty_max": 2.0,
44
+ "parameters": ["max_tokens", "temperature", "frequency_penalty"],
45
+ }
46
 
47
 
48
  preset = Preset(
49
+ txt2txt={
50
+ "hf": {
51
+ # TODO: update models
52
+ "codellama/codellama-34b-instruct-hf": Txt2TxtPreset("Code Llama 34B", **hf_txt2txt_kwargs),
53
+ "meta-llama/llama-2-13b-chat-hf": Txt2TxtPreset("Llama 2 13B", **hf_txt2txt_kwargs),
54
+ "mistralai/mistral-7b-instruct-v0.2": Txt2TxtPreset("Mistral v0.2 7B", **hf_txt2txt_kwargs),
55
+ "nousresearch/nous-hermes-2-mixtral-8x7b-dpo": Txt2TxtPreset(
56
+ "Nous Hermes 2 Mixtral 8x7B",
57
+ **hf_txt2txt_kwargs,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  ),
59
+ },
60
+ "pplx": {
61
+ "llama-3.1-sonar-small-128k-chat": Txt2TxtPreset("Sonar Small (Offline)", **pplx_txt2txt_kwargs),
62
+ "llama-3.1-sonar-large-128k-chat": Txt2TxtPreset("Sonar Large (Offline)", **pplx_txt2txt_kwargs),
63
+ "llama-3.1-sonar-small-128k-online": Txt2TxtPreset("Sonar Small (Online)", **pplx_txt2txt_kwargs),
64
+ "llama-3.1-sonar-large-128k-online": Txt2TxtPreset("Sonar Large (Online)", **pplx_txt2txt_kwargs),
65
+ "llama-3.1-sonar-huge-128k-online": Txt2TxtPreset("Sonar Huge (Online)", **pplx_txt2txt_kwargs),
66
+ },
67
+ },
68
+ txt2img={
69
+ "bfl": {
70
+ "flux-pro-1.1": Txt2ImgPreset(
71
  "FLUX1.1 Pro",
 
 
72
  parameters=["seed", "width", "height", "prompt_upsampling"],
73
  kwargs={"safety_tolerance": 6},
74
  ),
75
+ "flux-pro": Txt2ImgPreset(
76
  "FLUX.1 Pro",
 
 
77
  guidance_scale=2.5,
78
  guidance_scale_min=1.5,
79
  guidance_scale_max=5.0,
 
83
  parameters=["seed", "width", "height", "steps", "guidance", "prompt_upsampling"],
84
  kwargs={"safety_tolerance": 6, "interval": 1},
85
  ),
86
+ "flux-dev": Txt2ImgPreset(
87
  "FLUX.1 Dev",
 
 
88
  num_inference_steps=28,
89
  num_inference_steps_min=10,
90
  num_inference_steps_max=50,
 
94
  parameters=["seed", "width", "height", "steps", "guidance", "prompt_upsampling"],
95
  kwargs={"safety_tolerance": 6},
96
  ),
97
+ },
98
+ "fal": {
99
+ "fal-ai/aura-flow": Txt2ImgPreset(
100
+ "AuraFlow",
101
+ guidance_scale=3.5,
102
+ guidance_scale_min=1.0,
103
+ guidance_scale_max=10.0,
104
+ num_inference_steps=28,
105
+ num_inference_steps_min=10,
106
+ num_inference_steps_max=50,
107
+ parameters=["seed", "num_inference_steps", "guidance_scale", "expand_prompt"],
108
+ kwargs={"num_images": 1, "sync_mode": False},
109
+ ),
110
+ "fal-ai/flux-pro/v1.1": Txt2ImgPreset(
111
  "FLUX1.1 Pro",
 
 
112
  parameters=["seed", "image_size"],
113
  kwargs={
114
  "num_images": 1,
 
117
  "enable_safety_checker": False,
118
  },
119
  ),
120
+ "fal-ai/flux-pro": Txt2ImgPreset(
121
  "FLUX.1 Pro",
 
 
122
  guidance_scale=2.5,
123
  guidance_scale_min=1.5,
124
  guidance_scale_max=5.0,
 
128
  parameters=["seed", "image_size", "num_inference_steps", "guidance_scale"],
129
  kwargs={"num_images": 1, "sync_mode": False, "safety_tolerance": 6},
130
  ),
131
+ "fal-ai/flux/dev": Txt2ImgPreset(
132
  "FLUX.1 Dev",
 
 
133
  num_inference_steps=28,
134
  num_inference_steps_min=10,
135
  num_inference_steps_max=50,
 
139
  parameters=["seed", "image_size", "num_inference_steps", "guidance_scale"],
140
  kwargs={"num_images": 1, "sync_mode": False, "safety_tolerance": 6},
141
  ),
142
+ "fal-ai/flux/schnell": Txt2ImgPreset(
143
  "FLUX.1 Schnell",
 
 
144
  num_inference_steps=4,
145
  num_inference_steps_min=1,
146
  num_inference_steps_max=12,
147
  parameters=["seed", "image_size", "num_inference_steps"],
148
  kwargs={"num_images": 1, "sync_mode": False, "enable_safety_checker": False},
149
  ),
150
+ "fal-ai/fooocus": Txt2ImgPreset(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  "Fooocus",
 
 
152
  guidance_scale=4.0,
153
  guidance_scale_min=1.0,
154
  guidance_scale_max=10.0,
 
163
  "performance": "Quality",
164
  },
165
  ),
166
+ "fal-ai/kolors": Txt2ImgPreset(
167
  "Kolors",
 
 
168
  guidance_scale=5.0,
169
  guidance_scale_min=1.0,
170
  guidance_scale_max=10.0,
 
179
  "scheduler": "EulerDiscreteScheduler",
180
  },
181
  ),
182
+ "fal-ai/stable-diffusion-v3-medium": Txt2ImgPreset(
183
  "SD3",
 
 
184
  guidance_scale=5.0,
185
  guidance_scale_min=1.0,
186
  guidance_scale_max=10.0,
 
197
  ],
198
  kwargs={"num_images": 1, "sync_mode": True, "enable_safety_checker": False},
199
  ),
200
+ },
201
+ "hf": {
202
+ "black-forest-labs/flux.1-dev": Txt2ImgPreset(
203
+ "FLUX.1 Dev",
204
+ num_inference_steps=28,
205
+ num_inference_steps_min=10,
206
+ num_inference_steps_max=50,
207
+ guidance_scale=3.0,
208
+ guidance_scale_min=1.5,
209
+ guidance_scale_max=5.0,
210
+ parameters=["width", "height", "guidance_scale", "num_inference_steps"],
211
+ kwargs={"max_sequence_length": 512},
212
+ ),
213
+ "black-forest-labs/flux.1-schnell": Txt2ImgPreset(
214
+ "FLUX.1 Schnell",
215
+ num_inference_steps=4,
216
+ num_inference_steps_min=1,
217
+ num_inference_steps_max=12,
218
+ parameters=["width", "height", "num_inference_steps"],
219
+ kwargs={"guidance_scale": 0.0, "max_sequence_length": 256},
220
+ ),
221
+ "stabilityai/stable-diffusion-xl-base-1.0": Txt2ImgPreset(
222
  "SDXL",
 
 
223
  guidance_scale=7.0,
224
  guidance_scale_min=1.0,
225
  guidance_scale_max=10.0,
 
235
  "num_inference_steps",
236
  ],
237
  ),
238
+ },
239
+ "together": {
240
+ "black-forest-labs/FLUX.1-schnell-Free": Txt2ImgPreset(
241
+ "FLUX.1 Schnell Free",
242
+ num_inference_steps=4,
243
+ num_inference_steps_min=1,
244
+ num_inference_steps_max=12,
245
+ parameters=["model", "seed", "width", "height", "steps"],
246
+ kwargs={"n": 1},
247
+ ),
248
+ },
249
+ },
250
  )
pages/1_💬_Text_Generation.py CHANGED
@@ -1,33 +1,23 @@
1
- import os
2
  from datetime import datetime
 
3
 
4
  import streamlit as st
5
 
6
- from lib import config, preset, txt2txt_generate
7
-
8
- SERVICE_SESSION = {
9
- "Hugging Face": "api_key_hugging_face",
10
- "Perplexity": "api_key_perplexity",
11
- }
12
-
13
- SESSION_TOKEN = {
14
- "api_key_hugging_face": os.environ.get("HF_TOKEN") or None,
15
- "api_key_perplexity": os.environ.get("PPLX_API_KEY") or None,
16
- }
17
 
18
  # config
19
  st.set_page_config(
20
  page_title=f"{config.title} | Text Generation",
21
- page_icon=config.icon,
22
  layout=config.layout,
23
  )
24
 
25
  # initialize state
26
- if "api_key_hugging_face" not in st.session_state:
27
- st.session_state.api_key_hugging_face = ""
28
 
29
- if "api_key_perplexity" not in st.session_state:
30
- st.session_state.api_key_perplexity = ""
31
 
32
  if "running" not in st.session_state:
33
  st.session_state.running = False
@@ -39,33 +29,41 @@ if "txt2txt_seed" not in st.session_state:
39
  st.session_state.txt2txt_seed = 0
40
 
41
  # sidebar
42
- st.logo("logo.png")
43
  st.sidebar.header("Settings")
 
44
  service = st.sidebar.selectbox(
45
  "Service",
46
- options=SERVICE_SESSION.keys(),
47
- index=0,
48
  disabled=st.session_state.running,
49
  )
50
 
51
  # disable API key input and hide value if set by environment variable (handle empty string value later)
52
- for display_name, session_key in SERVICE_SESSION.items():
53
- if service == display_name:
 
 
 
54
  st.session_state[session_key] = st.sidebar.text_input(
55
  "API Key",
56
  type="password",
57
- value="" if SESSION_TOKEN[session_key] else st.session_state[session_key],
58
- disabled=bool(SESSION_TOKEN[session_key]) or st.session_state.running,
59
- help="Set by environment variable" if SESSION_TOKEN[session_key] else "Cleared on page refresh",
60
  )
61
 
 
 
62
  model = st.sidebar.selectbox(
63
  "Model",
64
- options=config.txt2txt.models[service],
65
- index=config.txt2txt.default_model[service],
66
  disabled=st.session_state.running,
67
- format_func=lambda x: x.split("/")[1] if service == "Hugging Face" else x,
68
  )
 
 
 
69
  system = st.sidebar.text_area(
70
  "System Message",
71
  value=config.txt2txt.default_system,
@@ -74,9 +72,7 @@ system = st.sidebar.text_area(
74
 
75
  # build parameters from preset
76
  parameters = {}
77
- service_key = service.lower().replace(" ", "_")
78
- service_preset = getattr(preset.txt2txt, service_key)
79
- for param in service_preset.parameters:
80
  if param == "max_tokens":
81
  parameters[param] = st.sidebar.slider(
82
  "Max Tokens",
@@ -101,9 +97,9 @@ for param in service_preset.parameters:
101
  parameters[param] = st.sidebar.slider(
102
  "Frequency Penalty",
103
  step=0.1,
104
- value=service_preset.frequency_penalty,
105
- min_value=service_preset.frequency_penalty_min,
106
- max_value=service_preset.frequency_penalty_max,
107
  disabled=st.session_state.running,
108
  help="Penalize new tokens based on their existing frequency in the text (default: 0.0)",
109
  )
@@ -180,8 +176,8 @@ if prompt := st.chat_input(
180
  st.markdown(prompt)
181
 
182
  with st.chat_message("assistant"):
183
- session_key = f"api_key_{service.lower().replace(' ', '_')}"
184
- api_key = st.session_state[session_key] or SESSION_TOKEN[session_key]
185
  response = txt2txt_generate(api_key, service, model, parameters)
186
  st.session_state.running = False
187
 
 
 
1
  from datetime import datetime
2
+ from typing import Dict
3
 
4
  import streamlit as st
5
 
6
+ from lib import Txt2TxtPreset, config, preset, txt2txt_generate
 
 
 
 
 
 
 
 
 
 
7
 
8
  # config
9
  st.set_page_config(
10
  page_title=f"{config.title} | Text Generation",
11
+ page_icon=config.logo,
12
  layout=config.layout,
13
  )
14
 
15
  # initialize state
16
+ if "api_key_hf" not in st.session_state:
17
+ st.session_state.api_key_hf = ""
18
 
19
+ if "api_key_pplx" not in st.session_state:
20
+ st.session_state.api_key_pplx = ""
21
 
22
  if "running" not in st.session_state:
23
  st.session_state.running = False
 
29
  st.session_state.txt2txt_seed = 0
30
 
31
  # sidebar
32
+ st.logo(config.logo)
33
  st.sidebar.header("Settings")
34
+
35
  service = st.sidebar.selectbox(
36
  "Service",
37
+ options=preset.txt2txt.keys(),
38
+ format_func=lambda x: config.service[x].name,
39
  disabled=st.session_state.running,
40
  )
41
 
42
  # disable API key input and hide value if set by environment variable (handle empty string value later)
43
+ # for display_name, session_key in SERVICE_SESSION.items():
44
+ for service_id, service_config in config.service.items():
45
+ if service == service_id:
46
+ session_key = f"api_key_{service}"
47
+ api_key = config.service[service].api_key
48
  st.session_state[session_key] = st.sidebar.text_input(
49
  "API Key",
50
  type="password",
51
+ value="" if api_key else st.session_state[session_key],
52
+ disabled=bool(api_key) or st.session_state.running,
53
+ help="Set by environment variable" if api_key else "Cleared on page refresh",
54
  )
55
 
56
+ service_preset: Dict[str, Txt2TxtPreset] = preset.txt2txt[service]
57
+
58
  model = st.sidebar.selectbox(
59
  "Model",
60
+ options=service_preset.keys(),
61
+ format_func=lambda x: service_preset[x].name,
62
  disabled=st.session_state.running,
 
63
  )
64
+
65
+ model_preset = service_preset[model]
66
+
67
  system = st.sidebar.text_area(
68
  "System Message",
69
  value=config.txt2txt.default_system,
 
72
 
73
  # build parameters from preset
74
  parameters = {}
75
+ for param in model_preset.parameters:
 
 
76
  if param == "max_tokens":
77
  parameters[param] = st.sidebar.slider(
78
  "Max Tokens",
 
97
  parameters[param] = st.sidebar.slider(
98
  "Frequency Penalty",
99
  step=0.1,
100
+ value=model_preset.frequency_penalty,
101
+ min_value=model_preset.frequency_penalty_min,
102
+ max_value=model_preset.frequency_penalty_max,
103
  disabled=st.session_state.running,
104
  help="Penalize new tokens based on their existing frequency in the text (default: 0.0)",
105
  )
 
176
  st.markdown(prompt)
177
 
178
  with st.chat_message("assistant"):
179
+ session_key = f"api_key_{service}"
180
+ api_key = st.session_state[session_key] or config.service[service].api_key
181
  response = txt2txt_generate(api_key, service, model, parameters)
182
  st.session_state.running = False
183
 
pages/2_🎨_Text_to_Image.py CHANGED
@@ -1,44 +1,25 @@
1
- import os
2
  from datetime import datetime
 
3
 
4
  import streamlit as st
5
 
6
- from lib import config, preset, txt2img_generate
7
-
8
- # The token name is the service in lower_snake_case
9
- SERVICE_SESSION = {
10
- "Black Forest Labs": "api_key_black_forest_labs",
11
- "Fal": "api_key_fal",
12
- "Hugging Face": "api_key_hugging_face",
13
- "Together": "api_key_together",
14
- }
15
-
16
- SESSION_TOKEN = {
17
- "api_key_black_forest_labs": os.environ.get("BFL_API_KEY") or None,
18
- "api_key_fal": os.environ.get("FAL_KEY") or None,
19
- "api_key_hugging_face": os.environ.get("HF_TOKEN") or None,
20
- "api_key_together": os.environ.get("TOGETHER_API_KEY") or None,
21
- }
22
-
23
- PRESET_MODEL = {}
24
- for p in preset.txt2img.presets:
25
- PRESET_MODEL[p.model_id] = p
26
 
27
  st.set_page_config(
28
  page_title=f"{config.title} | Text to Image",
29
- page_icon=config.icon,
30
  layout=config.layout,
31
  )
32
 
33
  # Initialize Streamlit session state
34
- if "api_key_black_forest_labs" not in st.session_state:
35
- st.session_state.api_key_black_forest_labs = ""
36
 
37
  if "api_key_fal" not in st.session_state:
38
  st.session_state.api_key_fal = ""
39
 
40
- if "api_key_hugging_face" not in st.session_state:
41
- st.session_state.api_key_hugging_face = ""
42
 
43
  if "api_key_together" not in st.session_state:
44
  st.session_state.api_key_together = ""
@@ -52,34 +33,42 @@ if "txt2img_messages" not in st.session_state:
52
  if "txt2img_seed" not in st.session_state:
53
  st.session_state.txt2img_seed = 0
54
 
55
- st.logo("logo.png")
56
  st.sidebar.header("Settings")
 
57
  service = st.sidebar.selectbox(
58
  "Service",
59
- options=list(SERVICE_SESSION.keys()),
 
60
  disabled=st.session_state.running,
61
- index=2, # Hugging Face
62
  )
63
 
64
  # Show the API key input for the selected service.
65
  # Disable and hide value if set by environment variable; handle empty string value later.
66
- for display_name, session_key in SERVICE_SESSION.items():
67
- if service == display_name:
 
 
 
68
  st.session_state[session_key] = st.sidebar.text_input(
69
  "API Key",
70
  type="password",
71
- value="" if SESSION_TOKEN[session_key] else st.session_state[session_key],
72
- disabled=bool(SESSION_TOKEN[session_key]) or st.session_state.running,
73
- help="Set by environment variable" if SESSION_TOKEN[session_key] else "Cleared on page refresh",
74
  )
75
 
 
 
76
  model = st.sidebar.selectbox(
77
  "Model",
78
- options=config.txt2img.models[service],
79
- index=config.txt2img.default_model[service],
80
  disabled=st.session_state.running,
81
  )
82
 
 
 
83
  # heading
84
  st.html("""
85
  <h1>Text to Image</h1>
@@ -88,7 +77,6 @@ st.html("""
88
 
89
  # Build parameters from preset by rendering the appropriate input widgets
90
  parameters = {}
91
- model_preset = PRESET_MODEL[model]
92
  for param in model_preset.parameters:
93
  if param == "model":
94
  parameters[param] = model
@@ -262,13 +250,13 @@ if prompt := st.chat_input(
262
  with st.spinner("Running..."):
263
  if model_preset.kwargs:
264
  parameters.update(model_preset.kwargs)
265
- session_key = f"api_key_{service.lower().replace(' ', '_')}"
266
- api_key = st.session_state[session_key] or SESSION_TOKEN[session_key]
267
  image = txt2img_generate(api_key, service, model, prompt, parameters)
268
  st.session_state.running = False
269
 
270
  st.session_state.txt2img_messages.append(
271
- {"role": "user", "content": prompt, "parameters": parameters, "model": PRESET_MODEL[model].name}
272
  )
273
  st.session_state.txt2img_messages.append({"role": "assistant", "content": image})
274
  st.rerun()
 
 
1
  from datetime import datetime
2
+ from typing import Dict
3
 
4
  import streamlit as st
5
 
6
+ from lib import Txt2ImgPreset, config, preset, txt2img_generate
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  st.set_page_config(
9
  page_title=f"{config.title} | Text to Image",
10
+ page_icon=config.logo,
11
  layout=config.layout,
12
  )
13
 
14
  # Initialize Streamlit session state
15
+ if "api_key_bfl" not in st.session_state:
16
+ st.session_state.api_key_bfl = ""
17
 
18
  if "api_key_fal" not in st.session_state:
19
  st.session_state.api_key_fal = ""
20
 
21
+ if "api_key_hf" not in st.session_state:
22
+ st.session_state.api_key_hf = ""
23
 
24
  if "api_key_together" not in st.session_state:
25
  st.session_state.api_key_together = ""
 
33
  if "txt2img_seed" not in st.session_state:
34
  st.session_state.txt2img_seed = 0
35
 
36
+ st.logo(config.logo)
37
  st.sidebar.header("Settings")
38
+
39
  service = st.sidebar.selectbox(
40
  "Service",
41
+ options=preset.txt2img.keys(),
42
+ format_func=lambda x: config.service[x].name,
43
  disabled=st.session_state.running,
 
44
  )
45
 
46
  # Show the API key input for the selected service.
47
  # Disable and hide value if set by environment variable; handle empty string value later.
48
+ # for display_name, session_key in SERVICE_SESSION.items():
49
+ for service_id in config.service.keys():
50
+ if service == service_id:
51
+ session_key = f"api_key_{service}"
52
+ api_key = config.service[service].api_key
53
  st.session_state[session_key] = st.sidebar.text_input(
54
  "API Key",
55
  type="password",
56
+ value="" if api_key else st.session_state[session_key],
57
+ disabled=bool(api_key) or st.session_state.running,
58
+ help="Set by environment variable" if api_key else "Cleared on page refresh",
59
  )
60
 
61
+ service_preset: Dict[str, Txt2ImgPreset] = preset.txt2img[service]
62
+
63
  model = st.sidebar.selectbox(
64
  "Model",
65
+ options=service_preset.keys(),
66
+ format_func=lambda x: service_preset[x].name,
67
  disabled=st.session_state.running,
68
  )
69
 
70
+ model_preset = service_preset[model]
71
+
72
  # heading
73
  st.html("""
74
  <h1>Text to Image</h1>
 
77
 
78
  # Build parameters from preset by rendering the appropriate input widgets
79
  parameters = {}
 
80
  for param in model_preset.parameters:
81
  if param == "model":
82
  parameters[param] = model
 
250
  with st.spinner("Running..."):
251
  if model_preset.kwargs:
252
  parameters.update(model_preset.kwargs)
253
+ session_key = f"api_key_{service}"
254
+ api_key = st.session_state[session_key] or config.service[service].api_key
255
  image = txt2img_generate(api_key, service, model, prompt, parameters)
256
  st.session_state.running = False
257
 
258
  st.session_state.txt2img_messages.append(
259
+ {"role": "user", "content": prompt, "parameters": parameters, "model": model_preset.name}
260
  )
261
  st.session_state.txt2img_messages.append({"role": "assistant", "content": image})
262
  st.rerun()