from vjepa_encoder.vision_encoder import JepaEncoder encoder = JepaEncoder.load_model( "logs/params-encoder.yaml" ) import numpy import torch img = numpy.random.random(size=(360, 480, 3)) x = torch.rand((32, 3, 256, 900)) print("Input Img:", img.shape) embedding = encoder.embed_image(img) print(embedding) print(embedding.shape) embedding = encoder.embed_image(x) print(embedding) print(embedding.shape) encoder.save_checkpoint("./test_jepa_model.tar")