"""Smoke test for ``JepaEncoder`` from ``vjepa_encoder.vision_encoder``.

Builds an encoder via ``JepaEncoder.load_model`` using
``logs/params-encoder.yaml``, then passes two random inputs to
``encoder.embed_image``:

* a NumPy array of shape (360, 480, 3);
* a torch tensor of shape (32, 3, 256, 900).

It prints each resulting embedding and its ``.shape``. Finally it calls
``encoder.save_checkpoint`` with ``./test_jepa_model.tar``.

Everything runs at import time; this file is meant to be executed directly
as a script.
"""
import numpy
import torch

from vjepa_encoder.vision_encoder import JepaEncoder

encoder = JepaEncoder.load_model(
    "logs/params-encoder.yaml"
)

# Random float64 array with values in [0.0, 1.0) and shape (360, 480, 3).
# NOTE(review): presumably interpreted as a single (height, width, channels)
# image by embed_image — confirm against JepaEncoder.embed_image.
img = numpy.random.random(size=(360, 480, 3))

# Random float tensor with values in [0, 1) and shape (32, 3, 256, 900).
# NOTE(review): presumably (batch, channels, height, width). 900 is not a
# multiple of common ViT patch sizes (e.g. 16) — confirm embed_image
# resizes/crops or otherwise accepts this width.
x = torch.rand((32, 3, 256, 900))

print("Input Img:", img.shape)
embedding = encoder.embed_image(img)
print(embedding)
print(embedding.shape)

# Reuses the same name for the batched-tensor result.
embedding = encoder.embed_image(x)
print(embedding)
print(embedding.shape)

encoder.save_checkpoint("./test_jepa_model.tar")