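"""Training script for ImageToVideoModel.

Loads an image/video dataset, encodes it with VideoTokenizer, and trains the model
with teacher forcing to predict the next frame at a randomly sampled speed level.
"""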
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from model import ImageToVideoModel
from de_en.tokenizer import VideoTokenizer
import torch.optim as optim
from torch.nn import MSELoss
from tqdm import tqdm
import argparse

def prepare_datasets(dataset_name, batch_size, resolution):
    dataset = load_dataset(dataset_name)

    # Build the tokenizer once rather than re-creating it for every mapped batch
    tokenizer = VideoTokenizer(resolution)

    # Encode raw images and videos into model-ready tensors
    def preprocess(examples):
        examples['image'] = [tokenizer.encode_image(img) for img in examples['image']]
        examples['video'] = [tokenizer.encode_video(vid) for vid in examples['video']]
        return examples
    
    dataset = dataset.map(preprocess, batched=True)
    dataset.set_format(type='torch', columns=['image', 'video'])
    
    train_loader = DataLoader(dataset['train'], batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(dataset['validation'], batch_size=batch_size)
    
    return train_loader, val_loader

def train_model(config):
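    """Run the training loop defined by `config` and save the final model weights."""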
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Initialize model
    model = ImageToVideoModel(
        encoder_config=config['encoder'],
        decoder_config=config['decoder'],
        transformer_config=config['transformer']
    ).to(device)
    
    # Load datasets
    train_loader, val_loader = prepare_datasets(
        config['dataset_name'],
        config['batch_size'],
        config['resolution']
    )
    
    # Optimizer and loss
    optimizer = optim.AdamW(model.parameters(), lr=config['lr'])
    criterion = MSELoss()
    
    # Training loop
    for epoch in range(config['epochs']):
        model.train()
        train_loss = 0.0
        
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            images = batch['image'].to(device)
            videos = batch['video'].to(device)
            
            # Random speed level for each sample in batch
            speed_levels = torch.randint(0, 10, (images.size(0),), device=device)
            
            optimizer.zero_grad()
            
            # Predict all frames at once (teacher forcing)
            outputs = model(images, videos[:, :-1], speed_levels)
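            # outputs[:, t] predicts videos[:, t+1], so the loss targets are the frames shifted by one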
            
            loss = criterion(outputs, videos[:, 1:])
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
        
        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                images = batch['image'].to(device)
                videos = batch['video'].to(device)
                speed_levels = torch.randint(0, 10, (images.size(0),), device=device)
                
                outputs = model(images, videos[:, :-1], speed_levels)
                val_loss += criterion(outputs, videos[:, 1:]).item()
        
        print(f"Epoch {epoch+1}, Train Loss: {train_loss/len(train_loader):.4f}, Val Loss: {val_loss/len(val_loader):.4f}")
    
    # Save model
    torch.save(model.state_dict(), config['save_path'])

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="ucf101")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--resolution", type=int, default=128)
    parser.add_argument("--save_path", type=str, default="image_to_video_model.pth")
    args = parser.parse_args()
    
    config = {
        'dataset_name': args.dataset,
        'batch_size': args.batch_size,
        'epochs': args.epochs,
        'lr': args.lr,
        'resolution': args.resolution,
        'save_path': args.save_path,
        'encoder': {
            'in_channels': 3,
            'hidden_dims': [64, 128, 256, 512],
            'embed_dim': 512
        },
        'decoder': {
            'embed_dim': 512,
            'hidden_dims': [512, 256, 128, 64],
            'out_channels': 3
        },
        'transformer': {
            'd_model': 512,
            'nhead': 8,
            'num_encoder_layers': 3,
            'num_decoder_layers': 3,
            'dim_feedforward': 2048,
            'dropout': 0.1
        }
    }
    
    train_model(config)