from src import data
from src import config
from src import vision_model
from src import tokenizer as tk
from src.lightning_module import LightningModule
from src import loss
from src import models
def train(config: config.TrainerConfig):
    """Assemble the data loaders, encoders and Lightning module for training.

    Args:
        config: Trainer configuration. Its ``_model_config`` attribute supplies
            the vision config, the text config and the loss type used below;
            the whole object is also forwarded as ``hyper_parameters``.

    NOTE(review): the visible source stops once the ``LightningModule`` has
    been constructed; trainer creation/fitting presumably follows -- confirm.
    """
    # Inside this function ``config`` is the TrainerConfig argument, which
    # shadows the ``config`` module used in the signature annotation.
    model_config = config._model_config

    # Input-side preprocessing, each driven by its own sub-config.
    transform = vision_model.get_vision_transform(model_config.vision_config)
    tokenizer = tk.Tokenizer(model_config.text_config)

    train_dl, valid_dl = data.get_dataset(
        transform=transform, tokenizer=tokenizer, hyper_parameters=config  # type: ignore
    )

    vision_encoder = models.TinyCLIPVisionEncoder(config=model_config.vision_config)
    text_encoder = models.TinyCLIPTextEncoder(config=model_config.text_config)

    lightning_module = LightningModule(
        vision_encoder=vision_encoder,
        text_encoder=text_encoder,
        loss_fn=loss.get_loss(model_config.loss_type),
        hyper_parameters=config,
        # The module receives the number of training batches.
        len_train_dl=len(train_dl),
    )