File size: 459 Bytes
df07554
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import sys

from dataset import GridDataset
from Trainer import Trainer

# Smoke test: build a Trainer, pull a single training batch, run the
# network forward on the batch's video tensor and print the output.
trainer = Trainer(write_logs=False)
trainer.load_datasets()
trainer.create_model()

# NOTE(review): num_workers=0 is forwarded to Trainer.dataset2dataloader;
# presumably it keeps data loading in the main process -- confirm there.
dataloader = trainer.dataset2dataloader(
    trainer.train_dataset, num_workers=0
)

# Only the first batch is needed. Using next() with a default (instead of
# a `for ...: break` loop) lets us fail with a clear message when the
# dataloader is empty, rather than a NameError on `batch` further down.
batch = next(iter(dataloader), None)
if batch is None:
    raise RuntimeError('train dataloader produced no batches')

# Move every field of the batch to the GPU. Only `vid` is fed to the
# network below; the other tensors are kept as module-level names
# (presumably for interactive inspection -- confirm before removing).
vid = batch.get('vid').cuda()
txt = batch.get('txt').cuda()
vid_len = batch.get('vid_len').cuda()
txt_len = batch.get('txt_len').cuda()

# Forward pass on the video input only.
y = trainer.net(vid)

print(y)
print('>>> ')