vjepa-self-driving / demo_jepa_encoder.py
jonathanzkoch's picture
update demo files and embedding config
81cffda
raw
history blame contribute delete
465 Bytes
"""Demo: embed random inputs with a V-JEPA ``JepaEncoder`` and save it.

Intended to be run as a script. It does the following:

1. Loads the encoder described by ``CONFIG_PATH``.
2. Embeds one random NumPy image and one random torch batch.
3. Prints each embedding and its shape.
4. Writes a checkpoint to ``CHECKPOINT_PATH``.

The random inputs come from seeded generators, so they are identical
across runs.
"""
import numpy
import torch

from vjepa_encoder.vision_encoder import JepaEncoder

# Encoder configuration passed to JepaEncoder.load_model.
CONFIG_PATH = "logs/params-encoder.yaml"
# Destination of the checkpoint written at the end of the demo.
CHECKPOINT_PATH = "./test_jepa_model.tar"
# Seed for the random demo inputs.
SEED = 0

# Shape of the single NumPy image; presumably (height, width, channels) — TODO confirm
IMAGE_SHAPE = (360, 480, 3)
# Shape of the torch batch; presumably (batch, channels, height, width) — TODO confirm
BATCH_SHAPE = (32, 3, 256, 900)


def _show(embedding) -> None:
    """Print an embedding followed by its ``.shape``."""
    print(embedding)
    print(embedding.shape)


def main() -> None:
    """Load the encoder, embed two random inputs, and save a checkpoint."""
    encoder = JepaEncoder.load_model(CONFIG_PATH)

    # Seeded, local generators: reproducible inputs without touching the
    # global NumPy / torch RNG state.
    img = numpy.random.default_rng(SEED).random(size=IMAGE_SHAPE)
    x = torch.rand(BATCH_SHAPE, generator=torch.Generator().manual_seed(SEED))

    print("Input Img:", img.shape)
    # NOTE(review): if embed_image does not disable autograd internally,
    # wrapping these calls in torch.no_grad() would save memory — confirm.
    _show(encoder.embed_image(img))
    _show(encoder.embed_image(x))

    encoder.save_checkpoint(CHECKPOINT_PATH)


if __name__ == "__main__":
    main()