tiny_clip / src /trainer.py
sachin's picture
Added click to run the model
f872c60
raw
history blame
No virus
2.8 kB
import os
import click
from huggingface_hub import HfApi
from loguru import logger
from src import config
from src import data
from src import loss
from src import models
from src import tokenizer as tk
from src import vision_model
from src import utils
from src.lightning_module import LightningModule
def _upload_model_to_hub(
vision_encoder: models.TinyCLIPVisionEncoder,
text_encoder: models.TinyCLIPTextEncoder,
debug: bool = False,
):
vision_encoder.save_pretrained(
str(config.VISION_MODEL_PATH),
safe_serialization=True,
)
text_encoder.save_pretrained(
str(config.TEXT_MODEL_PATH),
safe_serialization=True,
)
api = HfApi()
if debug:
repo_components = config.REPO_ID.split("/", maxsplit=1)
repo_components[1] = f"debug-{repo_components[1]}"
repo_id = "/".join(repo_components)
else:
repo_id = config.REPO_ID
common_hf_api_params = {
"repo_id": repo_id,
"repo_type": "model",
}
if not api.repo_exists(**common_hf_api_params):
logger.info(f"Creating repo {repo_id} on Hugging Face Hub.")
api.create_repo(**common_hf_api_params) # type: ignore
logger.info(f"Uploading models in {str(config.MODEL_PATH)} to {repo_id}.")
api.upload_folder(
folder_path=config.MODEL_PATH,
**common_hf_api_params, # type: ignore
) # type: ignore
@click.group()
def cli():
pass
@click.command()
@click.option("--trainer-config-json", required=False, default="{}", type=str)
def train(trainer_config_json: str):
if "HF_TOKEN" not in os.environ:
raise ValueError("Please set the HF_TOKEN environment variable.")
trainer_config = config.TrainerConfig.model_validate_json(trainer_config_json)
transform = vision_model.get_vision_transform(trainer_config._model_config.vision_config)
tokenizer = tk.Tokenizer(trainer_config._model_config.text_config)
train_dl, valid_dl = data.get_dataset(
transform=transform, tokenizer=tokenizer, hyper_parameters=trainer_config # type: ignore
)
vision_encoder = models.TinyCLIPVisionEncoder(config=trainer_config._model_config.vision_config)
text_encoder = models.TinyCLIPTextEncoder(config=trainer_config._model_config.text_config)
lightning_module = LightningModule(
vision_encoder=vision_encoder,
text_encoder=text_encoder,
loss_fn=loss.get_loss(trainer_config._model_config.loss_type),
hyper_parameters=trainer_config,
len_train_dl=len(train_dl),
)
trainer = utils.get_trainer(trainer_config)
trainer.fit(lightning_module, train_dl, valid_dl)
_upload_model_to_hub(vision_encoder, text_encoder, trainer_config.debug)
cli.add_command(train)
if __name__ == "__main__":
cli()