zetavg commited on
Commit
d8a2e9e
1 Parent(s): 0e73bed

support setting auth and load config from a custom file path

Browse files
Files changed (2) hide show
  1. app.py +22 -7
  2. llama_lora/config.py +3 -0
app.py CHANGED
@@ -23,12 +23,14 @@ def main(
23
  server_name: str = "127.0.0.1",
24
  share: bool = False,
25
  skip_loading_base_model: bool = False,
 
26
  load_8bit: Union[bool, None] = None,
27
  ui_show_sys_info: Union[bool, None] = None,
28
  ui_dev_mode: Union[bool, None] = None,
29
  wandb_api_key: Union[str, None] = None,
30
  wandb_project: Union[str, None] = None,
31
  timezone: Union[str, None] = None,
 
32
  ):
33
  '''
34
  Start the LLaMA-LoRA Tuner UI.
@@ -45,15 +47,17 @@ def main(
45
  :param wandb_project: The default project name for Weights & Biases. Setting either this or `wandb_api_key` will enable Weights & Biases.
46
  '''
47
 
48
- config_from_file = read_yaml_config()
49
  if config_from_file:
50
  for key, value in config_from_file.items():
51
  if key == "server_name":
52
  server_name = value
53
  continue
54
  if not hasattr(Config, key):
55
- available_keys = [k for k in vars(Config) if not k.startswith('__')]
56
- raise ValueError(f"Invalid config key '{key}' in config.yaml. Available keys: {', '.join(available_keys)}")
 
 
57
  setattr(Config, key, value)
58
 
59
  if base_model is not None:
@@ -71,6 +75,12 @@ def main(
71
  if load_8bit is not None:
72
  Config.load_8bit = load_8bit
73
 
 
 
 
 
 
 
74
  if wandb_api_key is not None:
75
  Config.wandb_api_key = wandb_api_key
76
 
@@ -106,12 +116,17 @@ def main(
106
  main_page()
107
 
108
  demo.queue(concurrency_count=1).launch(
109
- server_name=server_name, share=share)
 
 
 
 
110
 
111
 
112
- def read_yaml_config():
113
- app_dir = os.path.dirname(os.path.abspath(__file__))
114
- config_path = os.path.join(app_dir, 'config.yaml')
 
115
 
116
  if not os.path.exists(config_path):
117
  return None
 
23
  server_name: str = "127.0.0.1",
24
  share: bool = False,
25
  skip_loading_base_model: bool = False,
26
+ auth: Union[str, None] = None,
27
  load_8bit: Union[bool, None] = None,
28
  ui_show_sys_info: Union[bool, None] = None,
29
  ui_dev_mode: Union[bool, None] = None,
30
  wandb_api_key: Union[str, None] = None,
31
  wandb_project: Union[str, None] = None,
32
  timezone: Union[str, None] = None,
33
+ config: Union[str, None] = None,
34
  ):
35
  '''
36
  Start the LLaMA-LoRA Tuner UI.
 
47
  :param wandb_project: The default project name for Weights & Biases. Setting either this or `wandb_api_key` will enable Weights & Biases.
48
  '''
49
 
50
+ config_from_file = read_yaml_config(config_path=config)
51
  if config_from_file:
52
  for key, value in config_from_file.items():
53
  if key == "server_name":
54
  server_name = value
55
  continue
56
  if not hasattr(Config, key):
57
+ available_keys = [k for k in vars(
58
+ Config) if not k.startswith('__')]
59
+ raise ValueError(
60
+ f"Invalid config key '{key}' in config.yaml. Available keys: {', '.join(available_keys)}")
61
  setattr(Config, key, value)
62
 
63
  if base_model is not None:
 
75
  if load_8bit is not None:
76
  Config.load_8bit = load_8bit
77
 
78
+ if auth is not None:
79
+ try:
80
+ [Config.auth_username, Config.auth_password] = auth.split(':')
81
+ except ValueError:
82
+ raise ValueError("--auth must be in the format <username>:<password>, e.g.: --auth='username:password'")
83
+
84
  if wandb_api_key is not None:
85
  Config.wandb_api_key = wandb_api_key
86
 
 
116
  main_page()
117
 
118
  demo.queue(concurrency_count=1).launch(
119
+ server_name=server_name,
120
+ share=share,
121
+ auth=((Config.auth_username, Config.auth_password)
122
+ if Config.auth_username and Config.auth_password else None)
123
+ )
124
 
125
 
126
+ def read_yaml_config(config_path: Union[str, None] = None):
127
+ if not config_path:
128
+ app_dir = os.path.dirname(os.path.abspath(__file__))
129
+ config_path = os.path.join(app_dir, 'config.yaml')
130
 
131
  if not os.path.exists(config_path):
132
  return None
llama_lora/config.py CHANGED
@@ -18,6 +18,9 @@ class Config:
18
 
19
  timezone: Any = pytz.UTC
20
 
 
 
 
21
  # WandB
22
  enable_wandb: Union[bool, None] = None
23
  wandb_api_key: Union[str, None] = None
 
18
 
19
  timezone: Any = pytz.UTC
20
 
21
+ auth_username: Union[str, None] = None
22
+ auth_password: Union[str, None] = None
23
+
24
  # WandB
25
  enable_wandb: Union[bool, None] = None
26
  wandb_api_key: Union[str, None] = None