File size: 459 Bytes
df07554 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import sys
from dataset import GridDataset
from Trainer import Trainer
# Smoke test: build the trainer, pull one batch from the training set, run a
# single forward pass through the network and print the raw output.
trainer = Trainer(write_logs=False)
trainer.load_datasets()
trainer.create_model()
# num_workers=0 is passed through to the project's dataloader factory;
# presumably this keeps loading in-process for easier debugging — confirm.
dataloader = trainer.dataset2dataloader(
    trainer.train_dataset, num_workers=0
)

# Take just the first batch. Using next(iter(...)) instead of a for/break loop
# means an empty loader fails with a clear message rather than leaving `batch`
# unbound and crashing later with a NameError.
try:
    batch = next(iter(dataloader))
except StopIteration:
    raise RuntimeError("training dataloader yielded no batches") from None

# Subscript instead of .get(): a missing key then raises KeyError naming the
# key, rather than AttributeError from calling .cuda() on None.
# NOTE(review): assumes each batch is a Mapping of tensors — confirm against
# the dataset's collate output.
vid = batch['vid'].cuda()
# Only `vid` is fed to the network below; the remaining fields are still moved
# to the GPU so a malformed batch is caught here.
txt = batch['txt'].cuda()
vid_len = batch['vid_len'].cuda()
txt_len = batch['txt_len'].cuda()

y = trainer.net(vid)
print(y)
print('>>> ')
|