Commit
·
9eb104f
1
Parent(s):
3999299
Updates to jepas compatability with python
Browse files- demo_jepa_encoder.py +8 -0
- params-encoder.yaml +3 -27
demo_jepa_encoder.py
CHANGED
@@ -5,10 +5,18 @@ encoder = JepaEncoder.load_model(
|
|
5 |
)
|
6 |
|
7 |
import numpy
|
|
|
8 |
img = numpy.random.random(size=(360, 480, 3))
|
9 |
|
|
|
|
|
10 |
print("Input Img:", img.shape)
|
11 |
embedding = encoder.embed_image(img)
|
12 |
|
|
|
|
|
|
|
|
|
|
|
13 |
print(embedding)
|
14 |
print(embedding.shape)
|
|
|
5 |
)
|
6 |
|
7 |
import numpy
|
8 |
+
import torch
|
9 |
img = numpy.random.random(size=(360, 480, 3))
|
10 |
|
11 |
+
x = torch.rand((32, 3, 256, 900))
|
12 |
+
|
13 |
print("Input Img:", img.shape)
|
14 |
embedding = encoder.embed_image(img)
|
15 |
|
16 |
+
print(embedding)
|
17 |
+
print(embedding.shape)
|
18 |
+
|
19 |
+
|
20 |
+
embedding = encoder.embed_image(x)
|
21 |
print(embedding)
|
22 |
print(embedding.shape)
|
params-encoder.yaml
CHANGED
@@ -1,11 +1,6 @@
|
|
1 |
app: vjepa
|
2 |
data:
|
3 |
-
batch_size: 8
|
4 |
-
clip_duration: null
|
5 |
crop_size: 224
|
6 |
-
dataset_type: VideoDataset
|
7 |
-
datasets:
|
8 |
-
- /path/to/dataset.csv
|
9 |
decode_one_clip: true
|
10 |
filter_short_videos: false
|
11 |
num_clips: 1
|
@@ -14,7 +9,7 @@ data:
|
|
14 |
patch_size: 16
|
15 |
pin_mem: true
|
16 |
sampling_rate: 4
|
17 |
-
tubelet_size:
|
18 |
data_aug:
|
19 |
auto_augment: false
|
20 |
motion_shift: false
|
@@ -26,11 +21,8 @@ data_aug:
|
|
26 |
- 1.0
|
27 |
reprob: 0.0
|
28 |
logging:
|
29 |
-
folder: /
|
30 |
write_tag: jepa
|
31 |
-
loss:
|
32 |
-
loss_exp: 1.0
|
33 |
-
reg_coeff: 0.0
|
34 |
mask:
|
35 |
- aspect_ratio:
|
36 |
- 0.75
|
@@ -60,7 +52,7 @@ meta:
|
|
60 |
dtype: bfloat16
|
61 |
eval_freq: 100
|
62 |
load_checkpoint: true
|
63 |
-
read_checkpoint: /
|
64 |
save_every_freq: 5
|
65 |
seed: 234
|
66 |
use_sdpa: true
|
@@ -71,19 +63,3 @@ model:
|
|
71 |
uniform_power: true
|
72 |
use_mask_tokens: true
|
73 |
zero_init_mask_tokens: true
|
74 |
-
nodes: 16
|
75 |
-
optimization:
|
76 |
-
clip_grad: 10.0
|
77 |
-
ema:
|
78 |
-
- 0.998
|
79 |
-
- 1.0
|
80 |
-
epochs: 25
|
81 |
-
final_lr: 1.0e-06
|
82 |
-
final_weight_decay: 0.4
|
83 |
-
ipe: 300
|
84 |
-
ipe_scale: 1.25
|
85 |
-
lr: 0.000625
|
86 |
-
start_lr: 0.0002
|
87 |
-
warmup: 40
|
88 |
-
weight_decay: 0.04
|
89 |
-
tasks_per_node: 8
|
|
|
1 |
app: vjepa
|
2 |
data:
|
|
|
|
|
3 |
crop_size: 224
|
|
|
|
|
|
|
4 |
decode_one_clip: true
|
5 |
filter_short_videos: false
|
6 |
num_clips: 1
|
|
|
9 |
patch_size: 16
|
10 |
pin_mem: true
|
11 |
sampling_rate: 4
|
12 |
+
tubelet_size: 1
|
13 |
data_aug:
|
14 |
auto_augment: false
|
15 |
motion_shift: false
|
|
|
21 |
- 1.0
|
22 |
reprob: 0.0
|
23 |
logging:
|
24 |
+
folder: /media/rpal/Drive_10TB/John/jepa/logs
|
25 |
write_tag: jepa
|
|
|
|
|
|
|
26 |
mask:
|
27 |
- aspect_ratio:
|
28 |
- 0.75
|
|
|
52 |
dtype: bfloat16
|
53 |
eval_freq: 100
|
54 |
load_checkpoint: true
|
55 |
+
read_checkpoint: /media/rpal/Drive_10TB/John/jepa/huggingface/jepa-latest.pth.tar
|
56 |
save_every_freq: 5
|
57 |
seed: 234
|
58 |
use_sdpa: true
|
|
|
63 |
uniform_power: true
|
64 |
use_mask_tokens: true
|
65 |
zero_init_mask_tokens: true
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|