marijnfs Marijn Stollenga Marijn Stollenga winglian committed on
Commit
3367fca
·
unverified ·
1 Parent(s): 1ac8998

Gradio configuration parameters (#1591)

Browse files

* Gradio Configuration Settings

* Making various Gradio variables configurable instead of hardcoded

* Remove overwriting behaviour of 'default tokens' that breaks the tokenizer for llama3

* Fix type of gradio_temperature

* Revert unnecessary change and lint

---------

Co-authored-by: Marijn Stollenga <stollenga@imfusion.de>
Co-authored-by: Marijn Stollenga <stollenga@imfusion.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>

src/axolotl/cli/__init__.py CHANGED
@@ -264,8 +264,8 @@ def do_inference_gradio(
264
  with torch.no_grad():
265
  generation_config = GenerationConfig(
266
  repetition_penalty=1.1,
267
- max_new_tokens=1024,
268
- temperature=0.9,
269
  top_p=0.95,
270
  top_k=40,
271
  bos_token_id=tokenizer.bos_token_id,
@@ -300,7 +300,13 @@ def do_inference_gradio(
300
  outputs="text",
301
  title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
302
  )
303
- demo.queue().launch(show_api=False, share=True)
 
 
 
 
 
 
304
 
305
 
306
  def choose_config(path: Path):
 
264
  with torch.no_grad():
265
  generation_config = GenerationConfig(
266
  repetition_penalty=1.1,
267
+ max_new_tokens=cfg.get("gradio_max_new_tokens", 1024),
268
+ temperature=cfg.get("gradio_temperature", 0.9),
269
  top_p=0.95,
270
  top_k=40,
271
  bos_token_id=tokenizer.bos_token_id,
 
300
  outputs="text",
301
  title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
302
  )
303
+
304
+ demo.queue().launch(
305
+ show_api=False,
306
+ share=cfg.get("gradio_share", True),
307
+ server_name=cfg.get("gradio_server_name", "127.0.0.1"),
308
+ server_port=cfg.get("gradio_server_port", None),
309
+ )
310
 
311
 
312
  def choose_config(path: Path):
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -409,6 +409,17 @@ class WandbConfig(BaseModel):
409
  return data
410
 
411
 
 
 
 
 
 
 
 
 
 
 
 
412
  # pylint: disable=too-many-public-methods,too-many-ancestors
413
  class AxolotlInputConfig(
414
  ModelInputConfig,
@@ -419,6 +430,7 @@ class AxolotlInputConfig(
419
  WandbConfig,
420
  MLFlowConfig,
421
  LISAConfig,
 
422
  RemappedParameters,
423
  DeprecatedParameters,
424
  BaseModel,
 
409
  return data
410
 
411
 
412
+ class GradioConfig(BaseModel):
413
+ """Gradio configuration subset"""
414
+
415
+ gradio_title: Optional[str] = None
416
+ gradio_share: Optional[bool] = None
417
+ gradio_server_name: Optional[str] = None
418
+ gradio_server_port: Optional[int] = None
419
+ gradio_max_new_tokens: Optional[int] = None
420
+ gradio_temperature: Optional[float] = None
421
+
422
+
423
  # pylint: disable=too-many-public-methods,too-many-ancestors
424
  class AxolotlInputConfig(
425
  ModelInputConfig,
 
430
  WandbConfig,
431
  MLFlowConfig,
432
  LISAConfig,
433
+ GradioConfig,
434
  RemappedParameters,
435
  DeprecatedParameters,
436
  BaseModel,