sachin commited on
Commit
f872c60
1 Parent(s): 571c526

Added click to run the model

Browse files
Files changed (1) hide show
  1. src/trainer.py +13 -3
src/trainer.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
 
 
3
  from huggingface_hub import HfApi
4
  from loguru import logger
5
 
@@ -48,9 +49,17 @@ def _upload_model_to_hub(
48
  ) # type: ignore
49
 
50
 
51
- def train(trainer_config: config.TrainerConfig):
 
 
 
 
 
 
 
52
  if "HF_TOKEN" not in os.environ:
53
  raise ValueError("Please set the HF_TOKEN environment variable.")
 
54
  transform = vision_model.get_vision_transform(trainer_config._model_config.vision_config)
55
  tokenizer = tk.Tokenizer(trainer_config._model_config.text_config)
56
  train_dl, valid_dl = data.get_dataset(
@@ -73,6 +82,7 @@ def train(trainer_config: config.TrainerConfig):
73
  _upload_model_to_hub(vision_encoder, text_encoder, trainer_config.debug)
74
 
75
 
 
 
76
  if __name__ == "__main__":
77
- trainer_config = config.TrainerConfig(debug=True)
78
- train(trainer_config)
 
1
  import os
2
 
3
+ import click
4
  from huggingface_hub import HfApi
5
  from loguru import logger
6
 
 
49
  ) # type: ignore
50
 
51
 
52
+ @click.group()
53
+ def cli():
54
+ pass
55
+
56
+
57
+ @click.command()
58
+ @click.option("--trainer-config-json", required=False, default="{}", type=str)
59
+ def train(trainer_config_json: str):
60
  if "HF_TOKEN" not in os.environ:
61
  raise ValueError("Please set the HF_TOKEN environment variable.")
62
+ trainer_config = config.TrainerConfig.model_validate_json(trainer_config_json)
63
  transform = vision_model.get_vision_transform(trainer_config._model_config.vision_config)
64
  tokenizer = tk.Tokenizer(trainer_config._model_config.text_config)
65
  train_dl, valid_dl = data.get_dataset(
 
82
  _upload_model_to_hub(vision_encoder, text_encoder, trainer_config.debug)
83
 
84
 
85
+ cli.add_command(train)
86
+
87
  if __name__ == "__main__":
88
+ cli()