tiny_clip / src /trainer.py
sachin's picture
Refactoring to disentangle modules
24d96ab
raw
history blame
962 Bytes
from src import data
from src import config
from src import vision_model
from src import tokenizer as tk
from src.lightning_module import LightningModule
from src import loss
from src import models
def train(config: config.TrainerConfig):
    """Assemble the data pipeline, encoders and Lightning module for a run.

    The steps are, in order: build the image transform and the tokenizer
    from the model configuration, fetch the training/validation data via
    ``data.get_dataset``, instantiate the vision and text encoders, and wire
    both encoders plus the configured loss into a ``LightningModule``.

    Args:
        config: Trainer configuration. Its ``_model_config`` attribute
            supplies ``vision_config``, ``text_config`` and ``loss_type``;
            the object itself is also forwarded as ``hyper_parameters``.

    Note:
        In the code shown here the Lightning module is only constructed; it
        is not returned and no fitting loop is started.
    """
    # Note: the parameter shadows the imported ``config`` module inside this body.
    model_cfg = config._model_config

    # Input-side preprocessing for each modality.
    image_transform = vision_model.get_vision_transform(model_cfg.vision_config)
    text_tokenizer = tk.Tokenizer(model_cfg.text_config)

    # The validation split is fetched but not consumed in the visible code.
    train_loader, valid_loader = data.get_dataset(
        transform=image_transform, tokenizer=text_tokenizer, hyper_parameters=config  # type: ignore
    )

    # One encoder per modality, each driven by its own sub-config.
    image_encoder = models.TinyCLIPVisionEncoder(config=model_cfg.vision_config)
    caption_encoder = models.TinyCLIPTextEncoder(config=model_cfg.text_config)

    # Resolve the loss first, then the training-set length, matching the
    # order in which the constructor arguments are evaluated.
    criterion = loss.get_loss(model_cfg.loss_type)
    steps_per_epoch = len(train_loader)

    lightning_module = LightningModule(
        vision_encoder=image_encoder,
        text_encoder=caption_encoder,
        loss_fn=criterion,
        hyper_parameters=config,
        len_train_dl=steps_per_epoch,
    )