Gregor Betz committed on
Commit 312035b
1 Parent(s): 7cf1ffa
Files changed (4)
  1. app.py +21 -74
  2. backend/config.py +78 -0
  3. config.yaml +10 -0
  4. requirements.txt +1 -0
app.py CHANGED
@@ -1,42 +1,30 @@
 from __future__ import annotations

 import asyncio
-import copy
 import logging
-import os
 import uuid
+import yaml

 import gradio as gr # type: ignore

 from logikon.backends.chat_models_with_grammar import create_logits_model, LogitsModel, LLMBackends
 from logikon.guides.proscons.recursive_balancing_guide import RecursiveBalancingGuide, RecursiveBalancingGuideConfig

+from backend.config import process_config
 from backend.messages_processing import add_details, history_to_langchain_format
 from backend.svg_processing import postprocess_svg

 logging.basicConfig(level=logging.DEBUG)

-# Default client
-INFERENCE_SERVER_URL = "https://api-inference.huggingface.co/models/{model_id}"
-MODEL_ID = "HuggingFaceH4/zephyr-7b-beta"
-CLIENT_MODEL_KWARGS = {
-    "max_tokens": 800,
-    "temperature": 0.6,
-}
-
-GUIDE_KWARGS = {
-    "expert_model": "HuggingFaceH4/zephyr-7b-beta",
-    # "meta-llama/Meta-Llama-3.1-70B-Instruct",
-    "inference_server_url": "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta",
-    # "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-70B-Instruct",
-    "llm_backend": "HFChat",
-    "classifier_kwargs": {
-        "model_id": "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
-        "inference_server_url": "https://api-inference.huggingface.co/models/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
-        "batch_size": 8,
-    },
-}
+with open("config.yaml") as stream:
+    try:
+        DEMO_CONFIG = yaml.safe_load(stream)
+        logging.debug(f"Config: {DEMO_CONFIG}")
+    except yaml.YAMLError as exc:
+        logging.error(f"Error loading config: {exc}")
+        raise exc
+

 EXAMPLES = [
     ("We're a nature-loving family with three kids, have some money left, and no plans "
@@ -106,31 +94,15 @@ CHATBOT_INSTRUCTIONS = (
 )


-logging.info(f"Reasoning guide expert model is {GUIDE_KWARGS['expert_model']}.")
-
-
 def new_conversation_id():
     conversation_id = str(uuid.uuid4())
     print(f"New conversation with conversation ID: {conversation_id}")
     return conversation_id


-def setup_client_llm(
-    client_model_id,
-    client_inference_url,
-    client_inference_token,
-    client_backend,
-    client_temperature,
-) -> LogitsModel | None:
+def setup_client_llm(**client_kwargs) -> LogitsModel | None:
     try:
-        llm = create_logits_model(
-            model_id=client_model_id,
-            inference_server_url=client_inference_url,
-            api_key=client_inference_token if client_inference_token else os.getenv("HF_TOKEN"),
-            llm_backend=client_backend,
-            max_tokens=CLIENT_MODEL_KWARGS["max_tokens"],
-            temperature=client_temperature,
-        )
+        llm = create_logits_model(**client_kwargs)
     except Exception as e:
         logging.error(f"When setting up client llm: Error: {e}")
         return False
@@ -155,26 +127,17 @@ def add_message(history, message, conversation_id):

 async def bot(
     history,
-    client_model_id,
-    client_inference_url,
-    client_inference_token,
-    client_backend,
-    client_temperature,
+    client_kwargs,
+    guide_kwargs,
     conversation_id,
     progress=gr.Progress(),
 ):

-    client_llm = setup_client_llm(
-        client_model_id,
-        client_inference_url,
-        client_inference_token,
-        client_backend,
-        client_temperature,
-    )
+    client_llm = setup_client_llm(**client_kwargs)

     if not client_llm:
         raise gr.Error(
-            "Failed to set up tourist LLM.",
+            "Failed to set up client LLM.",
             duration=0
         )
@@ -184,10 +147,6 @@ async def bot(
     # use guide always and exclusively at first turn
     if len(history_langchain_format) <= 1:

-        guide_kwargs = copy.deepcopy(GUIDE_KWARGS)
-        guide_kwargs["api_key"] = os.getenv("HF_TOKEN")  # expert model api key
-        guide_kwargs["classifier_kwargs"]["api_key"] = os.getenv("HF_TOKEN")  # classifier api key
-
        guide_config = RecursiveBalancingGuideConfig(**guide_kwargs)
        guide = RecursiveBalancingGuide(tourist_llm=client_llm, config=guide_config)
@@ -244,6 +203,9 @@ with gr.Blocks() as demo:
     conversation_id = gr.State(str(uuid.uuid4()))
     tos_approved = gr.State(False)

+    client_kwargs, guide_kwargs = process_config(DEMO_CONFIG)
+    logging.info(f"Reasoning guide expert model is {guide_kwargs['expert_model']}.")
+

     with gr.Tab(label="Chatbot", visible=False) as chatbot_tab:
@@ -258,29 +220,14 @@
         clear = gr.ClearButton([chat_input, chatbot])
         gr.Examples([{"text": e, "files":[]} for e in EXAMPLES], chat_input)

-        # configs
-        with gr.Accordion("Client LLM Configuration", open=False):
-            gr.Markdown("Configure your client LLM that underpins this chatbot and is guided through the reasoning process.")
-            with gr.Row():
-                with gr.Column(2):
-                    client_backend = gr.Dropdown(choices=[b.value for b in LLMBackends], value=LLMBackends.HFChat.value, label="LLM Inference Backend")
-                    client_model_id = gr.Textbox(MODEL_ID, label="Model ID", max_lines=1)
-                    client_inference_url = gr.Textbox(INFERENCE_SERVER_URL.format(model_id=MODEL_ID), label="Inference Server URL", max_lines=1)
-                    client_inference_token = gr.Textbox("", label="Inference Token", max_lines=1, placeholder="Not required with HF Inference Api (default)", type="password")
-                with gr.Column(1):
-                    client_temperature = gr.Slider(0, 1.0, value = CLIENT_MODEL_KWARGS["temperature"], label="Temperature")
-
         # logic
         chat_msg = chat_input.submit(add_message, [chatbot, chat_input, conversation_id], [chatbot, chat_input, conversation_id])
         bot_msg = chat_msg.then(
             bot,
             [
                 chatbot,
-                client_model_id,
-                client_inference_url,
-                client_inference_token,
-                client_backend,
-                client_temperature,
+                client_kwargs,
+                guide_kwargs,
                 conversation_id
             ],
             chatbot,
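
Taken together, the commit moves all model settings out of app.py: the app now parses config.yaml once at import time and threads two plain dicts, client_kwargs and guide_kwargs, through the Gradio event chain. A minimal sketch of the resulting startup path (illustrative only, not part of the commit; it assumes the backend/ package and config.yaml from this commit sit next to the script):

# Sketch of the new startup path (illustrative).
import logging

import yaml

from backend.config import process_config

logging.basicConfig(level=logging.DEBUG)

# Parse the demo configuration once; a malformed or missing file aborts startup.
with open("config.yaml") as stream:
    demo_config = yaml.safe_load(stream)

# Split the YAML into kwargs for the client LLM and for the reasoning guide;
# raises ValueError if HF_TOKEN is unset or a required key is missing.
client_kwargs, guide_kwargs = process_config(demo_config)

# client_kwargs feeds create_logits_model(**client_kwargs) in setup_client_llm;
# guide_kwargs feeds RecursiveBalancingGuideConfig(**guide_kwargs) in bot().
print(client_kwargs["model_id"], guide_kwargs["expert_model"])

One caveat worth noting: client_kwargs and guide_kwargs are plain dicts in the .then(bot, [...]) inputs list, and Gradio ordinarily expects components there, so wrapping each dict in gr.State(...) would be the conventional way to pass them into the event handler.
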
backend/config.py ADDED
@@ -0,0 +1,78 @@
+import os
+
+# Default client
+INFERENCE_SERVER_URL = "https://api-inference.huggingface.co/models/{model_id}"
+MODEL_ID = "HuggingFaceH4/zephyr-7b-beta"
+CLIENT_MODEL_KWARGS = {
+    "max_tokens": 800,
+    "temperature": 0.6,
+}
+
+GUIDE_KWARGS = {
+    "expert_model": "HuggingFaceH4/zephyr-7b-beta",
+    # "meta-llama/Meta-Llama-3.1-70B-Instruct",
+    "inference_server_url": "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta",
+    # "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-70B-Instruct",
+    "llm_backend": "HFChat",
+    "classifier_kwargs": {
+        "model_id": "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
+        "inference_server_url": "https://api-inference.huggingface.co/models/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
+        "batch_size": 8,
+    },
+}
+
+
+def process_config(config):
+    if "HF_TOKEN" not in os.environ:
+        raise ValueError("Please set the HF_TOKEN environment variable.")
+    client_kwargs = {}
+    if "client_llm" in config:
+        if "model_id" in config["client_llm"]:
+            client_kwargs["model_id"] = config["client_llm"]["model_id"]
+        else:
+            raise ValueError("config.yaml is missing client model_id.")
+        if "url" in config["client_llm"]:
+            client_kwargs["inference_server_url"] = config["client_llm"]["url"]
+        else:
+            raise ValueError("config.yaml is missing client url.")
+        client_kwargs["api_key"] = os.getenv("HF_TOKEN")
+        client_kwargs["llm_backend"] = "HFChat"
+        client_kwargs["temperature"] = CLIENT_MODEL_KWARGS["temperature"]
+        client_kwargs["max_tokens"] = CLIENT_MODEL_KWARGS["max_tokens"]
+    else:
+        raise ValueError("config.yaml is missing client_llm settings.")
+
+    guide_kwargs = {"classifier_kwargs": {}}  # init nested dict so assignments below don't KeyError
+    if "expert_llm" in config:
+        if "model_id" in config["expert_llm"]:
+            guide_kwargs["expert_model"] = config["expert_llm"]["model_id"]
+        else:
+            raise ValueError("config.yaml is missing expert model_id.")
+        if "url" in config["expert_llm"]:
+            guide_kwargs["inference_server_url"] = config["expert_llm"]["url"]
+        else:
+            raise ValueError("config.yaml is missing expert url.")
+        guide_kwargs["api_key"] = os.getenv("HF_TOKEN")
+        guide_kwargs["llm_backend"] = "HFChat"
+    else:
+        raise ValueError("config.yaml is missing expert_llm settings.")
+
+    if "classifier_llm" in config:
+        if "model_id" in config["classifier_llm"]:
+            guide_kwargs["classifier_kwargs"]["model_id"] = config["classifier_llm"]["model_id"]
+        else:
+            raise ValueError("config.yaml is missing classifier model_id.")
+        if "url" in config["classifier_llm"]:
+            guide_kwargs["classifier_kwargs"]["inference_server_url"] = config["classifier_llm"]["url"]
+        else:
+            raise ValueError("config.yaml is missing classifier url.")
+        if "batch_size" in config["classifier_llm"]:
+            guide_kwargs["classifier_kwargs"]["batch_size"] = config["classifier_llm"]["batch_size"]
+        else:
+            raise ValueError("config.yaml is missing classifier batch_size.")
+        guide_kwargs["classifier_kwargs"]["api_key"] = os.getenv("HF_TOKEN")  # classifier api key
+    else:
+        raise ValueError("config.yaml is missing classifier_llm settings.")
+
+    return client_kwargs, guide_kwargs
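
For reference, a hedged usage sketch of process_config; the inline dict mirrors config.yaml below, and the HF_TOKEN placeholder is hypothetical:

# Usage sketch for backend.config.process_config (illustrative, not in the commit).
import os

from backend.config import process_config

os.environ.setdefault("HF_TOKEN", "hf_xxx")  # placeholder token; checked up front

config = {
    "client_llm": {
        "model_id": "HuggingFaceH4/zephyr-7b-beta",
        "url": "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta",
    },
    "expert_llm": {
        "model_id": "HuggingFaceH4/zephyr-7b-beta",
        "url": "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta",
    },
    "classifier_llm": {
        "model_id": "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
        "url": "https://api-inference.huggingface.co/models/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
        "batch_size": 8,
    },
}

client_kwargs, guide_kwargs = process_config(config)
# client_kwargs -> model_id, inference_server_url, api_key, llm_backend ("HFChat"),
#                  plus temperature (0.6) and max_tokens (800) from CLIENT_MODEL_KWARGS
# guide_kwargs  -> expert_model, inference_server_url, api_key, llm_backend, and
#                  classifier_kwargs (model_id, inference_server_url, batch_size, api_key)
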
config.yaml ADDED
@@ -0,0 +1,10 @@
+client_llm:
+  url: "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
+  model_id: "HuggingFaceH4/zephyr-7b-beta"
+expert_llm:
+  url: "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
+  model_id: "HuggingFaceH4/zephyr-7b-beta"
+classifier_llm:
+  model_id: "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
+  url: "https://api-inference.huggingface.co/models/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
+  batch_size: 8
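
Since process_config fails fast on any missing key, a quick sanity check of this file can catch key mismatches before the Space boots (a sketch, assuming pyyaml from requirements.txt):

# Sanity-check sketch for config.yaml (illustrative, not in the commit).
import yaml

with open("config.yaml") as stream:
    cfg = yaml.safe_load(stream)

# Each section must expose exactly the keys process_config looks up.
assert {"url", "model_id"} <= cfg["client_llm"].keys()
assert {"url", "model_id"} <= cfg["expert_llm"].keys()
assert {"url", "model_id", "batch_size"} <= cfg["classifier_llm"].keys()
assert isinstance(cfg["classifier_llm"]["batch_size"], int)  # YAML should parse it as an int
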
requirements.txt CHANGED
@@ -1 +1,2 @@
+pyyaml
 git+https://github.com/logikon-ai/logikon@v0.2.0