adamelliotfields committed on
Commit
04549f6
·
verified ·
1 Parent(s): 39d4ddb

Add txt2txt clients

Browse files
0_🏠_Home.py CHANGED
@@ -39,10 +39,9 @@ st.html("""
39
  <h1 style="padding: 0; margin-bottom: 0.5rem">API Inference</h1>
40
  <span class="pro-badge">PRO</span>
41
  </div>
42
- <p>Run inference on API endpoints. Hugging Face for now; more coming soon!</p>
43
  """)
44
 
45
- # TODO: categorize tasks by service (SAI, FAL, etc)
46
  # content
47
  st.markdown("## Tasks")
48
  st.page_link("pages/1_πŸ’¬_Text_Generation.py", label="Text Generation", icon="πŸ’¬")
 
39
  <h1 style="padding: 0; margin-bottom: 0.5rem">API Inference</h1>
40
  <span class="pro-badge">PRO</span>
41
  </div>
42
+ <p>Inference on Huggingface, Perplexity, and Fal ⚑</p>
43
  """)
44
 
 
45
  # content
46
  st.markdown("## Tasks")
47
  st.page_link("pages/1_πŸ’¬_Text_Generation.py", label="Text Generation", icon="πŸ’¬")
README.md CHANGED
@@ -46,6 +46,14 @@ STREAMLIT_SERVER_RUN_ON_SAVE=false
46
  STREAMLIT_BROWSER_SERVER_ADDRESS=adamelliotfields-api-inference.hf.space
47
  ```
48
 
 
 
 
 
 
 
 
 
49
  ## Installation
50
 
51
  ```sh
 
46
  STREAMLIT_BROWSER_SERVER_ADDRESS=adamelliotfields-api-inference.hf.space
47
  ```
48
 
49
+ ## Secrets
50
+
51
+ ```bash
52
+ FAL_KEY=...
53
+ HF_TOKEN=...
54
+ PPLX_API_KEY=...
55
+ ```
56
+
57
  ## Installation
58
 
59
  ```sh
lib/__init__.py CHANGED
@@ -1,5 +1,12 @@
1
- from .api import HuggingFaceTxt2ImgAPI
2
  from .config import Config
3
- from .presets import Presets
4
 
5
- __all__ = ["Config", "HuggingFaceTxt2ImgAPI", "Presets"]
 
 
 
 
 
 
 
 
1
+ from .api import HuggingFaceTxt2ImgAPI, HuggingFaceTxt2TxtAPI, PerplexityTxt2TxtAPI
2
  from .config import Config
3
+ from .presets import ModelPresets, ServicePresets
4
 
5
+ __all__ = [
6
+ "Config",
7
+ "HuggingFaceTxt2ImgAPI",
8
+ "HuggingFaceTxt2TxtAPI",
9
+ "ModelPresets",
10
+ "PerplexityTxt2TxtAPI",
11
+ "ServicePresets",
12
+ ]
lib/api.py CHANGED
@@ -2,15 +2,63 @@ import io
2
  from abc import ABC, abstractmethod
3
 
4
  import requests
 
 
5
  from PIL import Image
6
 
7
 
 
 
 
 
 
 
8
  class Txt2ImgAPI(ABC):
9
  @abstractmethod
10
  def generate_image(self, model, prompt, parameters, **kwargs):
11
  pass
12
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  # essentially the same as huggingface_hub's inference client
15
  class HuggingFaceTxt2ImgAPI(Txt2ImgAPI):
16
  def __init__(self, token):
 
2
  from abc import ABC, abstractmethod
3
 
4
  import requests
5
+ import streamlit as st
6
+ from openai import APIError, OpenAI
7
  from PIL import Image
8
 
9
 
10
class Txt2TxtAPI(ABC):
    """Abstract interface for text-generation (chat) API clients."""

    @abstractmethod
    def generate_text(self, model, parameters, **kwargs):
        """Generate text with `model`; concrete services supply the implementation."""
        ...
14
+
15
+
16
  class Txt2ImgAPI(ABC):
17
  @abstractmethod
18
  def generate_image(self, model, prompt, parameters, **kwargs):
19
  pass
20
 
21
 
22
# Chat client for the Hugging Face Inference API via its OpenAI-compatible endpoint.
class HuggingFaceTxt2TxtAPI(Txt2TxtAPI):
    def __init__(self, api_key):
        # Hugging Face access token; may be empty if the user has not supplied one.
        self.api_key = api_key

    def generate_text(self, model, parameters, **kwargs):
        """Stream a chat completion into the Streamlit UI and return the full text.

        Errors are returned as the "response" string so the UI renders them
        in place of a completion instead of crashing the page.
        """
        if not self.api_key:
            return "API Key is required."
        base_url = f"https://api-inference.huggingface.co/models/{model}/v1"
        client = OpenAI(api_key=self.api_key, base_url=base_url)
        try:
            stream = client.chat.completions.create(stream=True, model=model, **parameters, **kwargs)
            return st.write_stream(stream)
        except APIError as error:
            return error.message
        except Exception as error:  # best-effort: surface any failure as text
            return str(error)
40
+
41
+
42
# Chat client for Perplexity's OpenAI-compatible API.
class PerplexityTxt2TxtAPI(Txt2TxtAPI):
    def __init__(self, api_key):
        # Perplexity API key; may be empty if the user has not supplied one.
        self.api_key = api_key

    def generate_text(self, model, parameters, **kwargs):
        """Stream a chat completion into the Streamlit UI and return the full text.

        Errors are returned as the "response" string so the UI renders them
        in place of a completion instead of crashing the page.
        """
        if not self.api_key:
            return "API Key is required."
        client = OpenAI(api_key=self.api_key, base_url="https://api.perplexity.ai")
        try:
            stream = client.chat.completions.create(stream=True, model=model, **parameters, **kwargs)
            return st.write_stream(stream)
        except APIError as error:
            return error.message
        except Exception as error:  # best-effort: surface any failure as text
            return str(error)
60
+
61
+
62
  # essentially the same as huggingface_hub's inference client
63
  class HuggingFaceTxt2ImgAPI(Txt2ImgAPI):
64
  def __init__(self, token):
lib/config.py CHANGED
@@ -19,12 +19,24 @@ Config = SimpleNamespace(
19
  "7:9": (896, 1152),
20
  "4:7": (768, 1344),
21
  },
22
- TXT2TXT_DEFAULT_MODEL=4,
23
- TXT2TXT_MODELS=[
24
- "codellama/codellama-34b-instruct-hf",
25
- "meta-llama/llama-2-13b-chat-hf",
26
- "meta-llama/meta-llama-3.1-405b-instruct-fp8",
27
- "mistralai/mistral-7b-instruct-v0.2",
28
- "nousresearch/nous-hermes-2-mixtral-8x7b-dpo",
29
- ],
 
 
 
 
 
 
 
 
 
 
 
 
30
  )
 
19
  "7:9": (896, 1152),
20
  "4:7": (768, 1344),
21
  },
22
+ TXT2TXT_DEFAULT_SYSTEM="You are a helpful assistant. Be precise and concise.",
23
+ TXT2TXT_DEFAULT_MODEL={
24
+ "Huggingface": 4,
25
+ "Perplexity": 3,
26
+ },
27
+ TXT2TXT_MODELS={
28
+ "Huggingface": [
29
+ "codellama/codellama-34b-instruct-hf",
30
+ "meta-llama/llama-2-13b-chat-hf",
31
+ "meta-llama/meta-llama-3.1-405b-instruct-fp8",
32
+ "mistralai/mistral-7b-instruct-v0.2",
33
+ "nousresearch/nous-hermes-2-mixtral-8x7b-dpo",
34
+ ],
35
+ "Perplexity": [
36
+ "llama-3.1-sonar-small-128k-chat",
37
+ "llama-3.1-sonar-large-128k-chat",
38
+ "llama-3.1-sonar-small-128k-online",
39
+ "llama-3.1-sonar-large-128k-online",
40
+ ],
41
+ },
42
  )
lib/presets.py CHANGED
@@ -1,6 +1,22 @@
1
  from types import SimpleNamespace
2
 
3
- Presets = SimpleNamespace(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  FLUX_1_DEV={
5
  "name": "FLUX.1 Dev",
6
  "num_inference_steps": 30,
 
1
  from types import SimpleNamespace
2
 
3
# Per-service generation defaults and the parameter widgets each service renders.
# Model selection and the system message are handled separately for every service.
ServicePresets = SimpleNamespace(
    Huggingface={
        "frequency_penalty": 0.0,
        "frequency_penalty_min": -2.0,
        "frequency_penalty_max": 2.0,
        "parameters": ["max_tokens", "temperature", "frequency_penalty", "seed"],
    },
    Perplexity={
        # NOTE: Perplexity's preset restricts frequency_penalty to [1.0, 2.0] and omits seed.
        "frequency_penalty": 1.0,
        "frequency_penalty_min": 1.0,
        "frequency_penalty_max": 2.0,
        "parameters": ["max_tokens", "temperature", "frequency_penalty"],
    },
)
18
+
19
+ ModelPresets = SimpleNamespace(
20
  FLUX_1_DEV={
21
  "name": "FLUX.1 Dev",
22
  "num_inference_steps": 30,
pages/1_πŸ’¬_Text_Generation.py CHANGED
@@ -2,23 +2,20 @@ import os
2
  from datetime import datetime
3
 
4
  import streamlit as st
5
- from openai import APIError, OpenAI
6
 
7
- from lib import Config
8
 
9
- # TODO: key input and store in cache_data
10
- # api key
11
- HF_TOKEN = os.environ.get("HF_TOKEN")
12
 
13
 
14
- # TODO: eventually support different APIs like OpenAI and Perplexity
15
  @st.cache_resource
16
- def get_chat_client(api_key, model):
17
- client = OpenAI(
18
- api_key=api_key,
19
- base_url=f"https://api-inference.huggingface.co/models/{model}/v1",
20
- )
21
- return client
22
 
23
 
24
  # config
@@ -29,6 +26,9 @@ st.set_page_config(
29
  )
30
 
31
  # initialize state
 
 
 
32
  if "txt2txt_messages" not in st.session_state:
33
  st.session_state.txt2txt_messages = []
34
 
@@ -38,52 +38,83 @@ if "txt2txt_prompt" not in st.session_state:
38
  # sidebar
39
  st.logo("logo.svg")
40
  st.sidebar.header("Settings")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  model = st.sidebar.selectbox(
42
  "Model",
43
- placeholder="Select a model",
44
- format_func=lambda x: x.split("/")[1],
45
- index=Config.TXT2TXT_DEFAULT_MODEL,
46
- options=Config.TXT2TXT_MODELS,
47
- )
48
- max_tokens = st.sidebar.slider(
49
- "Max Tokens",
50
- min_value=512,
51
- max_value=4096,
52
- value=512,
53
- step=128,
54
- help="Maximum number of tokens to generate (default: 512)",
55
- )
56
- temperature = st.sidebar.slider(
57
- "Temperature",
58
- min_value=0.0,
59
- max_value=2.0,
60
- value=1.0,
61
- step=0.1,
62
- help="Used to modulate the next token probabilities (default: 1.0)",
63
- )
64
- frequency_penalty = st.sidebar.slider(
65
- "Frequency Penalty",
66
- min_value=-2.0,
67
- max_value=2.0,
68
- value=0.0,
69
- step=0.1,
70
- help="Penalize new tokens based on their existing frequency in the text (default: 0.0)",
71
- )
72
- seed = st.sidebar.number_input(
73
- "Seed",
74
- min_value=-1,
75
- max_value=(1 << 53) - 1,
76
- value=-1,
77
- help="Make a best effort to sample deterministically (default: -1)",
78
  )
79
  system = st.sidebar.text_area(
80
  "System Message",
81
- value="You are a helpful assistant. Be precise and concise.",
 
82
  )
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  # random seed
85
- if seed < 0:
86
- seed = int(datetime.now().timestamp() * 1e6) % (1 << 53)
87
 
88
  # heading
89
  st.html("""
@@ -136,7 +167,10 @@ else:
136
  button_container = None
137
 
138
  # chat input
139
- if prompt := st.chat_input("What would you like to know?"):
 
 
 
140
  st.session_state.txt2txt_prompt = prompt
141
 
142
  if st.session_state.txt2txt_prompt:
@@ -149,24 +183,19 @@ if st.session_state.txt2txt_prompt:
149
  messages = [{"role": "system", "content": system}]
150
  messages.extend([{"role": m["role"], "content": m["content"]} for m in st.session_state.txt2txt_messages])
151
  messages.append({"role": "user", "content": st.session_state.txt2txt_prompt})
 
152
 
153
  with st.chat_message("assistant"):
154
- try:
155
- client = get_chat_client(HF_TOKEN, model)
156
- stream = client.chat.completions.create(
157
- frequency_penalty=frequency_penalty,
158
- temperature=temperature,
159
- max_tokens=max_tokens,
160
- messages=messages,
161
- model=model,
162
- stream=True,
163
- seed=seed,
164
- )
165
- response = st.write_stream(stream)
166
- except APIError as e:
167
- response = e.message
168
- except Exception as e:
169
- response = str(e)
170
 
171
  st.session_state.txt2txt_messages.append({"role": "user", "content": st.session_state.txt2txt_prompt})
172
  st.session_state.txt2txt_messages.append({"role": "assistant", "content": response})
 
2
  from datetime import datetime
3
 
4
  import streamlit as st
 
5
 
6
+ from lib import Config, HuggingFaceTxt2TxtAPI, PerplexityTxt2TxtAPI, ServicePresets
7
 
8
+ HF_TOKEN = os.environ.get("HF_TOKEN") or None
9
+ PPLX_API_KEY = os.environ.get("PPLX_API_KEY") or None
 
10
 
11
 
 
12
@st.cache_resource
def get_txt2txt_api(service="Huggingface", api_key=None):
    """Return a cached text-generation client for `service`, or None if unknown."""
    factories = {
        "Huggingface": HuggingFaceTxt2TxtAPI,
        "Perplexity": PerplexityTxt2TxtAPI,
    }
    factory = factories.get(service)
    return factory(api_key) if factory is not None else None
19
 
20
 
21
  # config
 
26
  )
27
 
28
  # initialize state
29
+ if "txt2txt_running" not in st.session_state:
30
+ st.session_state.txt2txt_running = False
31
+
32
  if "txt2txt_messages" not in st.session_state:
33
  st.session_state.txt2txt_messages = []
34
 
 
38
  # sidebar
39
  st.logo("logo.svg")
40
  st.sidebar.header("Settings")
41
+ service = st.sidebar.selectbox(
42
+ "Service",
43
+ options=["Huggingface", "Perplexity"],
44
+ index=0,
45
+ disabled=st.session_state.txt2txt_running,
46
+ )
47
+
48
+ # hide key input if environment variables are set
49
+ if (service == "Huggingface" and HF_TOKEN is None) or (service == "Perplexity" and PPLX_API_KEY is None):
50
+ api_key = st.sidebar.text_input(
51
+ "API Key",
52
+ value="",
53
+ type="password",
54
+ help="Cleared on page refresh",
55
+ disabled=st.session_state.txt2txt_running,
56
+ )
57
+
58
  model = st.sidebar.selectbox(
59
  "Model",
60
+ format_func=lambda x: x.split("/")[1] if service == "Huggingface" else x,
61
+ index=Config.TXT2TXT_DEFAULT_MODEL[service],
62
+ options=Config.TXT2TXT_MODELS[service],
63
+ disabled=st.session_state.txt2txt_running,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  )
65
  system = st.sidebar.text_area(
66
  "System Message",
67
+ value=Config.TXT2TXT_DEFAULT_SYSTEM,
68
+ disabled=st.session_state.txt2txt_running,
69
  )
70
 
71
# build parameters from preset
# Each service preset lists which parameter widgets to render; collected values
# are passed straight through to the chat completion call.
parameters = {}
preset = getattr(ServicePresets, service)
for param in preset["parameters"]:
    if param == "max_tokens":
        parameters[param] = st.sidebar.slider(
            "Max Tokens",
            min_value=512,
            max_value=4096,
            value=512,
            step=128,
            help="Maximum number of tokens to generate (default: 512)",
            disabled=st.session_state.txt2txt_running,
        )
    if param == "temperature":
        parameters[param] = st.sidebar.slider(
            "Temperature",
            min_value=0.0,
            max_value=2.0,
            value=1.0,
            step=0.1,
            help="Used to modulate the next token probabilities (default: 1.0)",
            disabled=st.session_state.txt2txt_running,
        )
    if param == "frequency_penalty":
        parameters[param] = st.sidebar.slider(
            "Frequency Penalty",
            min_value=preset["frequency_penalty_min"],
            max_value=preset["frequency_penalty_max"],
            value=preset["frequency_penalty"],
            step=0.1,
            # the default differs per service, so derive the help text from the preset
            # (a hardcoded "0.0" would be wrong for Perplexity, whose default is 1.0)
            help=f"Penalize new tokens based on their existing frequency in the text (default: {preset['frequency_penalty']})",
            disabled=st.session_state.txt2txt_running,
        )
    if param == "seed":
        parameters[param] = st.sidebar.number_input(
            "Seed",
            min_value=-1,
            max_value=(1 << 53) - 1,
            value=-1,
            help="Make a best effort to sample deterministically (default: -1)",
            disabled=st.session_state.txt2txt_running,
        )

# random seed: -1 means "pick one"; derive a value below 2**53 from the current time
if parameters.get("seed", 0) < 0:
    parameters["seed"] = int(datetime.now().timestamp() * 1e6) % (1 << 53)
118
 
119
  # heading
120
  st.html("""
 
167
  button_container = None
168
 
169
  # chat input
170
+ if prompt := st.chat_input(
171
+ "What would you like to know?",
172
+ on_submit=lambda: setattr(st.session_state, "txt2txt_running", True),
173
+ ):
174
  st.session_state.txt2txt_prompt = prompt
175
 
176
  if st.session_state.txt2txt_prompt:
 
183
  messages = [{"role": "system", "content": system}]
184
  messages.extend([{"role": m["role"], "content": m["content"]} for m in st.session_state.txt2txt_messages])
185
  messages.append({"role": "user", "content": st.session_state.txt2txt_prompt})
186
+ parameters["messages"] = messages
187
 
188
  with st.chat_message("assistant"):
189
+ # allow environment variables in development for convenience
190
+ if service == "Huggingface" and HF_TOKEN is not None:
191
+ key = HF_TOKEN
192
+ elif service == "Perplexity" and PPLX_API_KEY is not None:
193
+ key = PPLX_API_KEY
194
+ else:
195
+ key = api_key
196
+ api = get_txt2txt_api(service, key)
197
+ response = api.generate_text(model, parameters)
198
+ st.session_state.txt2txt_running = False
 
 
 
 
 
 
199
 
200
  st.session_state.txt2txt_messages.append({"role": "user", "content": st.session_state.txt2txt_prompt})
201
  st.session_state.txt2txt_messages.append({"role": "assistant", "content": response})
pages/2_🎨_Text_to_Image.py CHANGED
@@ -3,7 +3,7 @@ from datetime import datetime
3
 
4
  import streamlit as st
5
 
6
- from lib import Config, HuggingFaceTxt2ImgAPI, Presets
7
 
8
  # TODO: key input and store in cache_data
9
  # TODO: API dropdown; changes available models
@@ -11,9 +11,9 @@ HF_TOKEN = os.environ.get("HF_TOKEN")
11
  API_URL = "https://api-inference.huggingface.co/models"
12
  HEADERS = {"Authorization": f"Bearer {HF_TOKEN}", "X-Wait-For-Model": "true", "X-Use-Cache": "false"}
13
  PRESET_MODEL = {
14
- "black-forest-labs/flux.1-dev": Presets.FLUX_1_DEV,
15
- "black-forest-labs/flux.1-schnell": Presets.FLUX_1_SCHNELL,
16
- "stabilityai/stable-diffusion-xl-base-1.0": Presets.STABLE_DIFFUSION_XL,
17
  }
18
 
19
 
 
3
 
4
  import streamlit as st
5
 
6
+ from lib import Config, HuggingFaceTxt2ImgAPI, ModelPresets
7
 
8
  # TODO: key input and store in cache_data
9
  # TODO: API dropdown; changes available models
 
11
  API_URL = "https://api-inference.huggingface.co/models"
12
  HEADERS = {"Authorization": f"Bearer {HF_TOKEN}", "X-Wait-For-Model": "true", "X-Use-Cache": "false"}
13
  PRESET_MODEL = {
14
+ "black-forest-labs/flux.1-dev": ModelPresets.FLUX_1_DEV,
15
+ "black-forest-labs/flux.1-schnell": ModelPresets.FLUX_1_SCHNELL,
16
+ "stabilityai/stable-diffusion-xl-base-1.0": ModelPresets.STABLE_DIFFUSION_XL,
17
  }
18
 
19