Gradio configuration parameters (#1591)
Browse files
* Gradio Configuration Settings
* Making various Gradio variables configurable instead of hardcoded
* Remove overwriting behaviour of 'default tokens' that breaks the tokenizer for llama3
* Fix type of gradio_temperature
* Revert unnecessary change and lint
---------
Co-authored-by: Marijn Stollenga <stollenga@imfusion.de>
Co-authored-by: Marijn Stollenga <stollenga@imfusion.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
src/axolotl/cli/__init__.py
CHANGED
@@ -264,8 +264,8 @@ def do_inference_gradio(
|
|
264 |
with torch.no_grad():
|
265 |
generation_config = GenerationConfig(
|
266 |
repetition_penalty=1.1,
|
267 |
-
max_new_tokens=1024,
|
268 |
-
temperature=0.9,
|
269 |
top_p=0.95,
|
270 |
top_k=40,
|
271 |
bos_token_id=tokenizer.bos_token_id,
|
@@ -300,7 +300,13 @@ def do_inference_gradio(
|
|
300 |
outputs="text",
|
301 |
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
302 |
)
|
303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
304 |
|
305 |
|
306 |
def choose_config(path: Path):
|
|
|
264 |
with torch.no_grad():
|
265 |
generation_config = GenerationConfig(
|
266 |
repetition_penalty=1.1,
|
267 |
+
max_new_tokens=cfg.get("gradio_max_new_tokens", 1024),
|
268 |
+
temperature=cfg.get("gradio_temperature", 0.9),
|
269 |
top_p=0.95,
|
270 |
top_k=40,
|
271 |
bos_token_id=tokenizer.bos_token_id,
|
|
|
300 |
outputs="text",
|
301 |
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
302 |
)
|
303 |
+
|
304 |
+
demo.queue().launch(
|
305 |
+
show_api=False,
|
306 |
+
share=cfg.get("gradio_share", True),
|
307 |
+
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
308 |
+
server_port=cfg.get("gradio_server_port", None),
|
309 |
+
)
|
310 |
|
311 |
|
312 |
def choose_config(path: Path):
|
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
CHANGED
@@ -409,6 +409,17 @@ class WandbConfig(BaseModel):
|
|
409 |
return data
|
410 |
|
411 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
412 |
# pylint: disable=too-many-public-methods,too-many-ancestors
|
413 |
class AxolotlInputConfig(
|
414 |
ModelInputConfig,
|
@@ -419,6 +430,7 @@ class AxolotlInputConfig(
|
|
419 |
WandbConfig,
|
420 |
MLFlowConfig,
|
421 |
LISAConfig,
|
|
|
422 |
RemappedParameters,
|
423 |
DeprecatedParameters,
|
424 |
BaseModel,
|
|
|
409 |
return data
|
410 |
|
411 |
|
412 |
+
class GradioConfig(BaseModel):
|
413 |
+
"""Gradio configuration subset"""
|
414 |
+
|
415 |
+
gradio_title: Optional[str] = None
|
416 |
+
gradio_share: Optional[bool] = None
|
417 |
+
gradio_server_name: Optional[str] = None
|
418 |
+
gradio_server_port: Optional[int] = None
|
419 |
+
gradio_max_new_tokens: Optional[int] = None
|
420 |
+
gradio_temperature: Optional[float] = None
|
421 |
+
|
422 |
+
|
423 |
# pylint: disable=too-many-public-methods,too-many-ancestors
|
424 |
class AxolotlInputConfig(
|
425 |
ModelInputConfig,
|
|
|
430 |
WandbConfig,
|
431 |
MLFlowConfig,
|
432 |
LISAConfig,
|
433 |
+
GradioConfig,
|
434 |
RemappedParameters,
|
435 |
DeprecatedParameters,
|
436 |
BaseModel,
|