"""Smoke tests that invoke the training entry point on a tiny synthetic run."""
import os
import sys

import pytest
from PIL import Image
import torch

from training.main import main

# An empty CUDA_VISIBLE_DEVICES hides every GPU from CUDA, so these tests
# are expected to run on CPU only.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Every test here is skipped on macOS (sys.platform starting with 'darwin').
_skip_on_macos = pytest.mark.skipif(
    sys.platform.startswith('darwin'),
    reason="macos pickle bug with locals",
)

# Command-line options shared by all tests; each test appends its own
# '--model' selection (and any extra flags) after these.
_SHARED_ARGS = (
    '--save-frequency', '1',
    '--zeroshot-frequency', '1',
    '--dataset-type', "synthetic",
    '--train-num-samples', '16',
    '--warmup', '1',
    '--batch-size', '4',
    '--lr', '1e-3',
    '--wd', '0.1',
    '--epochs', '1',
    '--workers', '2',
)


def _train(model_name, *extra_args):
    """Call ``main`` with the shared options, then the model, then extras.

    The resulting argument list is: the shared options, ``'--model'``,
    ``model_name``, followed by ``extra_args`` in the order given.
    """
    main([*_SHARED_ARGS, '--model', model_name, *extra_args])


@_skip_on_macos
def test_training():
    """Run the training entry point with the 'RN50' model."""
    _train('RN50')


@_skip_on_macos
def test_training_mt5():
    """Run training with 'mt5-base-ViT-B-32' and the text-locking flags.

    NOTE(review): presumably '--lock-text' freezes the text tower and
    '--lock-text-unlocked-layers 2' leaves two layers trainable; confirm
    against the argument parser in ``training``.
    """
    _train(
        'mt5-base-ViT-B-32',
        '--lock-text',
        '--lock-text-unlocked-layers', '2',
    )


@_skip_on_macos
def test_training_unfreezing_vit():
    """Run training with 'ViT-B-32' and the image-locking flags.

    NOTE(review): presumably '--lock-image' freezes the image tower and
    '--lock-image-unlocked-groups 5' leaves five groups trainable; confirm
    against the argument parser in ``training``.
    """
    _train(
        'ViT-B-32',
        '--lock-image',
        '--lock-image-unlocked-groups', '5',
    )