jonathanzkoch committed on
Commit
9eb104f
·
1 Parent(s): 3999299

Updates to JEPA's compatibility with Python

Browse files
Files changed (2) hide show
  1. demo_jepa_encoder.py +8 -0
  2. params-encoder.yaml +3 -27
demo_jepa_encoder.py CHANGED
@@ -5,10 +5,18 @@ encoder = JepaEncoder.load_model(
5
  )
6
 
7
  import numpy
 
8
  img = numpy.random.random(size=(360, 480, 3))
9
 
 
 
10
  print("Input Img:", img.shape)
11
  embedding = encoder.embed_image(img)
12
 
 
 
 
 
 
13
  print(embedding)
14
  print(embedding.shape)
 
5
  )
6
 
7
  import numpy
8
+ import torch
9
  img = numpy.random.random(size=(360, 480, 3))
10
 
11
+ x = torch.rand((32, 3, 256, 900))
12
+
13
  print("Input Img:", img.shape)
14
  embedding = encoder.embed_image(img)
15
 
16
+ print(embedding)
17
+ print(embedding.shape)
18
+
19
+
20
+ embedding = encoder.embed_image(x)
21
  print(embedding)
22
  print(embedding.shape)
params-encoder.yaml CHANGED
@@ -1,11 +1,6 @@
1
  app: vjepa
2
  data:
3
- batch_size: 8
4
- clip_duration: null
5
  crop_size: 224
6
- dataset_type: VideoDataset
7
- datasets:
8
- - /path/to/dataset.csv
9
  decode_one_clip: true
10
  filter_short_videos: false
11
  num_clips: 1
@@ -14,7 +9,7 @@ data:
14
  patch_size: 16
15
  pin_mem: true
16
  sampling_rate: 4
17
- tubelet_size: 2
18
  data_aug:
19
  auto_augment: false
20
  motion_shift: false
@@ -26,11 +21,8 @@ data_aug:
26
  - 1.0
27
  reprob: 0.0
28
  logging:
29
- folder: /path/to/logs
30
  write_tag: jepa
31
- loss:
32
- loss_exp: 1.0
33
- reg_coeff: 0.0
34
  mask:
35
  - aspect_ratio:
36
  - 0.75
@@ -60,7 +52,7 @@ meta:
60
  dtype: bfloat16
61
  eval_freq: 100
62
  load_checkpoint: true
63
- read_checkpoint: /path/to/vitl16.pth.tar
64
  save_every_freq: 5
65
  seed: 234
66
  use_sdpa: true
@@ -71,19 +63,3 @@ model:
71
  uniform_power: true
72
  use_mask_tokens: true
73
  zero_init_mask_tokens: true
74
- nodes: 16
75
- optimization:
76
- clip_grad: 10.0
77
- ema:
78
- - 0.998
79
- - 1.0
80
- epochs: 25
81
- final_lr: 1.0e-06
82
- final_weight_decay: 0.4
83
- ipe: 300
84
- ipe_scale: 1.25
85
- lr: 0.000625
86
- start_lr: 0.0002
87
- warmup: 40
88
- weight_decay: 0.04
89
- tasks_per_node: 8
 
1
  app: vjepa
2
  data:
 
 
3
  crop_size: 224
 
 
 
4
  decode_one_clip: true
5
  filter_short_videos: false
6
  num_clips: 1
 
9
  patch_size: 16
10
  pin_mem: true
11
  sampling_rate: 4
12
+ tubelet_size: 1
13
  data_aug:
14
  auto_augment: false
15
  motion_shift: false
 
21
  - 1.0
22
  reprob: 0.0
23
  logging:
24
+ folder: /media/rpal/Drive_10TB/John/jepa/logs
25
  write_tag: jepa
 
 
 
26
  mask:
27
  - aspect_ratio:
28
  - 0.75
 
52
  dtype: bfloat16
53
  eval_freq: 100
54
  load_checkpoint: true
55
+ read_checkpoint: /media/rpal/Drive_10TB/John/jepa/huggingface/jepa-latest.pth.tar
56
  save_every_freq: 5
57
  seed: 234
58
  use_sdpa: true
 
63
  uniform_power: true
64
  use_mask_tokens: true
65
  zero_init_mask_tokens: true