| import wandb |
| from lightning.pytorch.loggers import WandbLogger |
| from lightning.pytorch.callbacks.early_stopping import EarlyStopping |
| from lightning.pytorch.callbacks import ModelCheckpoint |
| from lightning import Trainer |
|
|
| from mentioned.model import LitMentionDetector, ModelRegistry |
| from mentioned.data import DataRegistry |
|
|
|
|
def _build_trainer(
    project_name: str,
    run_name: str,
    stop_criterion: str,
    patience: int,
    val_interval: int,
    max_epochs: int,
) -> Trainer:
    """Build a Trainer with wandb logging, early stopping, and best-checkpoint tracking."""
    wandb_logger = WandbLogger(
        project=project_name,
        name=run_name,
    )
    # Keep only the single best checkpoint according to the monitored metric.
    best_checkpoint = ModelCheckpoint(
        monitor=stop_criterion,
        mode="max",
        save_top_k=1,
        filename=f"best-{stop_criterion}",
        verbose=True,
    )
    early_stopper = EarlyStopping(
        monitor=stop_criterion,
        min_delta=0.01,
        patience=patience,
        verbose=True,
        mode="max",
    )
    return Trainer(
        max_epochs=max_epochs,
        val_check_interval=val_interval,
        callbacks=[early_stopper, best_checkpoint],
        logger=wandb_logger,
        accelerator="auto",
    )


def _load_best_checkpoint(trainer: Trainer, fresh_bundle) -> LitMentionDetector:
    """Reload the best checkpoint into a freshly-constructed model bundle.

    The fresh bundle supplies tokenizer/encoder/head objects that are not
    serialized in the Lightning checkpoint itself.
    """
    # Optional components: older model factories may not define these.
    labeler = getattr(fresh_bundle, "mention_labeler", None)
    l2id = getattr(fresh_bundle, "label2id", None)
    # NOTE(review): weights_only=False allows full pickle loading — safe here
    # because the checkpoint was produced locally by this run.
    return LitMentionDetector.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path,
        tokenizer=fresh_bundle.tokenizer,
        encoder=fresh_bundle.encoder,
        mention_detector=fresh_bundle.mention_detector,
        label2id=l2id,
        mention_labeler=labeler,
        weights_only=False,
    )


def _verify_hub_upload(repo_id: str, fresh_bundle, data) -> None:
    """Pull the pushed model back from the Hub and re-evaluate on the test set."""
    labeler = getattr(fresh_bundle, "mention_labeler", None)
    l2id = getattr(fresh_bundle, "label2id", None)
    remote_model = LitMentionDetector.from_pretrained(
        repo_id,
        tokenizer=fresh_bundle.tokenizer,
        encoder=fresh_bundle.encoder,
        mention_detector=fresh_bundle.mention_detector,
        label2id=l2id,
        mention_labeler=labeler,
    )
    # Fresh trainer without a logger: verification must not pollute the wandb run.
    verify_trainer = Trainer(accelerator="auto", logger=False)
    verify_trainer.test(model=remote_model, dataloaders=data.test_loader)


def train(
    model_factory: str = "model_v1",
    data_factory: str = "litbank_mentions",
    repo_id: str = "kadarakos/mention-detector-poc-dry-run",
    project_name: str = "mention-detector-poc",
    encoder_id: str = "distilroberta-base",
    patience: int = 5,
    val_interval: int = 1000,
    stop_criterion: str = "val_f1_mention",
    max_epochs: int | None = None,
):
    """Train a mention detector, push the best checkpoint to the Hub, and verify it.

    Pipeline: build data/model from the registries, fit with early stopping
    and best-checkpoint tracking, evaluate on the test split, push the best
    model to ``repo_id``, then pull it back from the Hub and re-evaluate as
    a round-trip sanity check.

    Args:
        model_factory: Key into ``ModelRegistry`` for the model bundle factory.
        data_factory: Key into ``DataRegistry`` for the data module factory.
        repo_id: Hugging Face Hub repository the best model is pushed to.
        project_name: Weights & Biases project name.
        encoder_id: Pretrained encoder identifier; also used as the wandb run name.
        patience: Early-stopping patience, in validation rounds.
        val_interval: Run validation every this many training batches.
        stop_criterion: Metric monitored (maximized) for both early stopping
            and best-checkpoint selection.
        max_epochs: Upper bound on epochs; ``None`` means 1000 (early
            stopping is expected to end training well before that).
    """
    if max_epochs is None:
        # Effectively "train until early stopping fires".
        max_epochs = 1000
    data = DataRegistry.get(data_factory)()
    model = ModelRegistry.get(model_factory)(data, encoder_id)

    trainer = _build_trainer(
        project_name, encoder_id, stop_criterion, patience, val_interval, max_epochs
    )
    print(f"Starting Trainer for {max_epochs} epochs.")
    trainer.fit(
        model=model,
        train_dataloaders=data.train_loader,
        val_dataloaders=data.val_loader,
    )
    # ckpt_path="best" evaluates the top checkpoint, not the last epoch's weights.
    trainer.test(dataloaders=data.test_loader, ckpt_path="best", weights_only=False)

    print(f"Pushing best model to: {repo_id}")
    # Rebuild a fresh bundle so the pushed model carries clean component objects.
    fresh_bundle = ModelRegistry.get(model_factory)(data, encoder_id)
    best_model = _load_best_checkpoint(trainer, fresh_bundle)
    best_model.push_to_hub(repo_id, private=True)
    wandb.finish()

    print("Verifying Hub upload by pulling and re-evaluating...")
    _verify_hub_upload(repo_id, fresh_bundle, data)
|
|