Blackroot committed on
Commit 4a9ad28 · verified · 1 Parent(s): d374f52

Upload 9 files

Files changed (10)
  1. .gitattributes +4 -0
  2. 1.png +3 -0
  3. 2.png +3 -0
  4. 3.png +3 -0
  5. 4.png +3 -0
  6. models/__init__.py +3 -0
  7. models/uvit.py +219 -0
  8. step_1799.safetensors +3 -0
  9. test_sample.py +81 -0
  10. train.py +236 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ 1.png filter=lfs diff=lfs merge=lfs -text
+ 2.png filter=lfs diff=lfs merge=lfs -text
+ 3.png filter=lfs diff=lfs merge=lfs -text
+ 4.png filter=lfs diff=lfs merge=lfs -text
1.png ADDED

Git LFS Details

  • SHA256: e61e271503cd84735a944d563d82bbe321b4f0c8fc2490ed3f0e3f23310fb903
  • Pointer size: 132 Bytes
  • Size of remote file: 2.1 MB
2.png ADDED

Git LFS Details

  • SHA256: ff9e6a8322d050cb697bfe7b63bf3d56b9607267c4e6831610e16de35b4261a5
  • Pointer size: 132 Bytes
  • Size of remote file: 2.11 MB
3.png ADDED

Git LFS Details

  • SHA256: 02dc1fde3f9601de0ec09bb75a41a068e435b8478705c28c22dee6532ac4b2b6
  • Pointer size: 132 Bytes
  • Size of remote file: 2.17 MB
4.png ADDED

Git LFS Details

  • SHA256: 4b6d4dc329903a1f0ac56d5669ecb4488b9f05778761844ce1fcbc53a9ad1092
  • Pointer size: 132 Bytes
  • Size of remote file: 2.15 MB
models/__init__.py ADDED
@@ -0,0 +1,3 @@
from .uvit import AsymmetricResidualUDiT

__all__ = ['AsymmetricResidualUDiT']
models/uvit.py ADDED
@@ -0,0 +1,219 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# This architecture was my attempt at the following Simple Diffusion paper, with some modifications:
# https://arxiv.org/pdf/2410.19324v1

# Very similar to GeGLU or SwiGLU: there's a learned gate FN, using arctan as the activation fn.
class xATGLU(nn.Module):
    def __init__(self, input_dim, output_dim, bias=True):
        super().__init__()
        # GATE path | VALUE path
        self.proj = nn.Linear(input_dim, output_dim * 2, bias=bias)
        nn.init.kaiming_normal_(self.proj.weight, nonlinearity='linear')

        self.alpha = nn.Parameter(torch.zeros(1))
        self.half_pi = torch.pi / 2
        self.inv_pi = 1 / torch.pi

    def forward(self, x):
        projected = self.proj(x)
        gate_path, value_path = projected.chunk(2, dim=-1)

        # Apply arctan gating with expanded range via learned alpha -- https://arxiv.org/pdf/2405.20768
        gate = (torch.arctan(gate_path) + self.half_pi) * self.inv_pi
        expanded_gate = gate * (1 + 2 * self.alpha) - self.alpha

        return expanded_gate * value_path  # g(x) × y

class ResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(32, channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(32, channels)

    def forward(self, x):
        h = self.conv1(F.silu(self.norm1(x)))
        h = self.conv2(F.silu(self.norm2(h)))
        return x + h

class TransformerBlock(nn.Module):
    def __init__(self, channels, num_heads=8):
        super().__init__()
        self.norm1 = nn.LayerNorm(channels)
        self.attn = nn.MultiheadAttention(channels, num_heads)
        self.norm2 = nn.LayerNorm(channels)
        self.mlp = nn.Sequential(
            xATGLU(channels, 4 * channels),
            nn.Linear(4 * channels, channels)
        )

    def forward(self, x):
        # Reshape for attention [B, C, H, W] -> [H*W, B, C]
        b, c, h, w = x.shape
        spatial_size = h * w
        x = x.flatten(2).permute(2, 0, 1)

        # Self attention
        h_attn = self.norm1(x)
        h_attn, _ = self.attn(h_attn, h_attn, h_attn)
        x = x + h_attn

        # MLP
        h_mlp = self.norm2(x)
        h_mlp = self.mlp(h_mlp)
        x = x + h_mlp

        # Reshape back [H*W, B, C] -> [B, C, H, W]
        return x.permute(1, 2, 0).reshape(b, c, h, w)

class LevelBlock(nn.Module):
    def __init__(self, channels, num_blocks, block_type='res'):
        super().__init__()
        self.blocks = nn.ModuleList()
        for _ in range(num_blocks):
            if block_type == 'transformer':
                self.blocks.append(TransformerBlock(channels))
            else:
                self.blocks.append(ResBlock(channels))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

class AsymmetricResidualUDiT(nn.Module):
    def __init__(self,
                 in_channels=3,     # Input color channels
                 base_channels=128, # Initial feature width; dramatically increases the parameter count of the network.
                 patch_size=2,      # Smaller patches dramatically increase FLOPs and compute cost. Recommend >=4 unless you have real compute.
                 num_levels=3,      # Number of feature downsamples, essentially the U-Net depth -- so we down/upsample three times. Dramatically increases parameters as you increase it.
                 encoder_blocks=3,  # Can be a different number of blocks vs. decoder_blocks
                 decoder_blocks=7,  # Can be a different number of blocks vs. encoder_blocks
                 encoder_transformer_thresh=2, # When to start using transformer blocks instead of res blocks in the encoder. (>=)
                 decoder_transformer_thresh=4, # When to stop using transformer blocks instead of res blocks in the decoder. (<=)
                 mid_blocks=16      # Number of middle transformer blocks. Relatively cheap as this is at the bottom of the U-Net feature bottleneck.
                 ):
        super().__init__()

        # Initial projection from image space
        self.patch_embed = nn.Conv2d(in_channels, base_channels,
                                     kernel_size=patch_size, stride=patch_size)

        # Create encoder levels
        self.encoders = nn.ModuleList()
        curr_channels = base_channels

        for level in range(num_levels):
            # Create the main processing blocks for this level
            use_transformer = level >= encoder_transformer_thresh # Use transformers for the later levels

            # Encoder blocks -- encoder_blocks
            self.encoders.append(
                LevelBlock(curr_channels, encoder_blocks, use_transformer)
            )

            # Add channel scaling for the next level
            # Doubles the size of the feature space for each step, except for the last level.
            if level < num_levels - 1:
                self.encoders.append(
                    nn.Conv2d(curr_channels, curr_channels * 2, 1)
                )
                curr_channels *= 2

        # Middle transformer blocks -- mid_blocks
        self.middle = nn.ModuleList([
            TransformerBlock(curr_channels) for _ in range(mid_blocks)
        ])

        # Create decoder levels
        self.decoders = nn.ModuleList()

        for level in range(num_levels):
            # Create the main processing blocks for this level
            use_transformer = level <= decoder_transformer_thresh # Use transformers for the early levels (inverse of the encoder)

            # Decoder blocks -- decoder_blocks
            self.decoders.append(
                LevelBlock(curr_channels, decoder_blocks, use_transformer)
            )

            # Add channel scaling for the next level
            # Halves the size of the feature space for each step, except for the last level.
            if level < num_levels - 1:
                self.decoders.append(
                    nn.Conv2d(curr_channels, curr_channels // 2, 1)
                )
                curr_channels //= 2

        # Final projection back to image space
        self.final_proj = nn.ConvTranspose2d(base_channels, in_channels,
                                             kernel_size=patch_size, stride=patch_size)

    def downsample(self, x):
        return F.avg_pool2d(x, kernel_size=2)

    def upsample(self, x):
        return F.interpolate(x, scale_factor=2, mode='nearest')

    def forward(self, x, t=None):
        # Start by patch embedding the inputs.
        x = self.patch_embed(x)

        # Track the residual path and features at each spatial level.
        # The paper was very specific about the residual flow path; I tried my best to copy how they described it.

        # *Per resolution, i.e. per num_levels resolution block, more or less:
        # f(x) = fu( U(fm(D(h)) - D(h)) + h )    where h = fd(x)
        #
        # Where
        # 1. h = fd(x)      : Encoder path processes the input
        # 2. D(h)           : Downsample the encoded features
        # 3. fm(D(h))       : Middle transformer blocks process the downsampled features
        # 4. fm(D(h))-D(h)  : Subtract the original downsampled features (residual connection)
        # 5. U(...)         : Upsample the processed features
        # 6. ... + h        : Add back the original encoder features (skip connection)
        # 7. fu(...)        : Decoder path processes the combined features

        residuals = []
        curr_res = x

        # Encoder path (computing h = fd(x))
        h = x
        for i, blocks in enumerate(self.encoders):
            if isinstance(blocks, LevelBlock):
                h = blocks(h)
            else:
                # Save the residual before downsampling
                residuals.append(curr_res)
                # Downsample and update the current residual
                h = self.downsample(blocks(h))
                curr_res = h

        # Middle blocks (fm)
        x = h
        for block in self.middle:
            x = block(x)

        # Subtract the residual at this level (D(h))
        x = x - curr_res

        # Decoder path (fu)
        for i, blocks in enumerate(self.decoders):
            if isinstance(blocks, LevelBlock):
                x = blocks(x)
            else:
                # Channel reduction
                x = blocks(x)
                # Upsample
                x = self.upsample(x)
                # Add the residual from the encoder at this level; LIFO, the last residual saved is the first one we want, since it's a U-shape.
                curr_res = residuals.pop()
                x = x + curr_res

        # Final projection
        x = self.final_proj(x)

        return x
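As a quick illustration of the shape contract above (an editorial sketch, not part of the commit): the patch embedding, the per-level down/upsampling, and the final ConvTranspose2d cancel out, so the predicted velocity has the same shape as the input image. Note that forward() accepts a t argument, as the training and sampling scripts pass it, but this implementation does not consume it.

import torch
from models import AsymmetricResidualUDiT

model = AsymmetricResidualUDiT(patch_size=4)  # other arguments left at their class defaults
x = torch.randn(2, 3, 64, 64)                 # small random batch; the sampling script uses 256x256
t = torch.rand(2, 1, 1, 1)                    # flow time in the broadcast shape the scripts build; unused by this forward()
v = model(x, t)                               # f(x) = fu( U(fm(D(h)) - D(h)) + h )
assert v.shape == x.shape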
step_1799.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:74718eb5a40f7e9576182888828dbc717050987f6be58dcc6a28b58e6591f013
size 383841508
test_sample.py ADDED
@@ -0,0 +1,81 @@
import torch
import torch.nn as nn
import torchvision.utils as vutils
from models import AsymmetricResidualUDiT
from safetensors.torch import load_file
import os
import argparse
from typing import Optional

def load_checkpoint(model: nn.Module, checkpoint_path: str) -> None:
    state_dict = load_file(checkpoint_path)

    # Training was done via torch.compile, which prefixes every parameter key with '_orig_mod.'.
    # Handle compiled model state dicts by removing that prefix.
    if all(k.startswith('_orig_mod.') for k in state_dict.keys()):
        state_dict = {k[10:]: v for k, v in state_dict.items()}

    model.load_state_dict(state_dict)

def sample(model, n_samples=16, n_steps=50, image_size=256, device="cuda", sigma_min=0.001, dtype=torch.float32):
    with torch.amp.autocast('cuda', dtype=dtype):
        x = torch.randn(n_samples, 3, image_size, image_size, device=device)
        ts = torch.linspace(0, 1, n_steps, device=device)
        dt = 1 / n_steps

        # Forward Euler integration from t=0 to t=1
        with torch.no_grad():
            for i in range(len(ts)):
                t = ts[i]
                t_input = t.repeat(n_samples, 1, 1, 1)

                v_t = model(x, t_input)
                x = x + v_t * dt

        return x.float()

def main():
    parser = argparse.ArgumentParser(description="Generate samples from a trained UDiT model")
    parser.add_argument("checkpoint", type=str, help="Path to the model checkpoint (.safetensors)")
    parser.add_argument("--samples", type=int, default=16, help="Number of samples to generate")
    parser.add_argument("--steps", type=int, default=50, help="Number of sampling steps")
    parser.add_argument("--output", type=str, default="output.png", help="Output filename")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to run inference on (cuda/cpu)")
    args = parser.parse_args()
    device = args.device

    model = AsymmetricResidualUDiT(
        in_channels=3,
        base_channels=128,
        num_levels=3,
        patch_size=4,
        encoder_blocks=3,
        decoder_blocks=7,
        encoder_transformer_thresh=2,
        decoder_transformer_thresh=4,
        mid_blocks=8
    ).to(device)

    # Load state dict into model
    load_checkpoint(model, args.checkpoint)
    model.eval()

    # Generate samples
    print(f"Generating {args.samples} samples with {args.steps} steps...")
    with torch.no_grad():
        samples = sample(
            model,
            n_samples=args.samples,
            n_steps=args.steps,
            device=args.device,
            dtype=torch.float32
        )

    # Save samples
    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    vutils.save_image(samples, args.output, nrow=4, padding=2)
    print(f"Samples saved to {args.output}")

if __name__ == "__main__":
    main()
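The sampler above is a fixed-step forward Euler integration of the learned velocity field from t=0 to t=1: x_{k+1} = x_k + v(x_k, t_k) * dt with dt = 1/n_steps. A minimal standalone sketch of that update, with a stand-in velocity function in place of the trained UDiT (hypothetical, for illustration only):

import torch

def euler_integrate(velocity_fn, x, n_steps=50):
    # Mirrors the loop in sample(): n_steps Euler updates over t in [0, 1].
    ts = torch.linspace(0, 1, n_steps)
    dt = 1 / n_steps
    for t in ts:
        t_input = t.repeat(x.shape[0], 1, 1, 1)  # same broadcast shape sample() builds
        x = x + velocity_fn(x, t_input) * dt
    return x

toy_velocity = lambda x, t: -x                   # stand-in field that contracts toward zero
x = euler_integrate(toy_velocity, torch.randn(4, 3, 8, 8))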
train.py ADDED
@@ -0,0 +1,236 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
from datasets import load_dataset
from torch.utils.data import DataLoader, TensorDataset
from schedulefree import AdamWScheduleFree
from torch.utils.tensorboard import SummaryWriter
from safetensors.torch import save_file, load_file
import os, time
from models import AsymmetricResidualUDiT
from torch.cuda.amp import autocast

def preload_dataset(image_size=256, device="cuda"):
    """Preload and cache the entire dataset in GPU memory"""
    print("Loading and preprocessing dataset...")
    #dataset = load_dataset("jiovine/pixel-art-nouns-2k", split="train")
    dataset = load_dataset("reach-vb/pokemon-blip-captions", split="train")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((image_size, image_size), antialias=True),
        transforms.Lambda(lambda x: (x * 2) - 1)  # Scale to [-1, 1]
    ])

    all_images = []
    for example in dataset:
        img_tensor = transform(example['image'])
        all_images.append(img_tensor)

    # Stack the entire dataset onto the GPU
    images_tensor = torch.stack(all_images).to(device)
    print(f"Dataset loaded: {images_tensor.shape} ({images_tensor.element_size() * images_tensor.nelement() / 1024/1024:.2f} MB)")

    return TensorDataset(images_tensor)

def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    print(f'Total parameters: {total_params:,} ({total_params/1e6:.2f}M)')

def save_checkpoint(model, optimizer, filename="checkpoint.safetensors"):
    model_state = model.state_dict()
    save_file(model_state, filename)

def load_checkpoint(model, optimizer, filename="checkpoint.safetensors"):
    model_state = load_file(filename)
    model.load_state_dict(model_state)

# https://arxiv.org/abs/2210.02747
class OptimalTransportLinearFlowGenerator():
    def __init__(self, sigma_min=0.001):
        self.sigma_min = sigma_min

    def loss(self, model, x1, device):
        batch_size = x1.shape[0]

        # Sample t uniformly in [0, 1]
        t = torch.rand(batch_size, 1, 1, 1, device=device)

        # Sample noise
        x0 = torch.randn_like(x1)
        x1 = x1

        # Compute OT path interpolation (equation 22)
        sigma_t = 1 - (1 - self.sigma_min) * t
        mu_t = t * x1
        x_t = sigma_t * x0 + mu_t

        # Compute target velocity (equation 23)
        target = x1 - (1 - self.sigma_min) * x0

        v_t = model(x_t, t)
        loss = F.mse_loss(v_t, target)

        return loss

def write_logs(writer, model, loss, batch_idx, epoch, epoch_time, batch_size, lr, log_gradients=True):
    """
    TensorBoard logging

    Args:
        writer: torch.utils.tensorboard.SummaryWriter instance
        model: torch.nn.Module - the model being trained
        loss: float or torch.Tensor - the loss value to log
        batch_idx: int - current batch index
        epoch: int - current epoch
        epoch_time: float - time taken for the epoch
        batch_size: int - current batch size
        lr: float - current learning rate
        samples: Optional[torch.Tensor] - generated samples to log (only passed every 50 epochs)
        log_gradients: bool - whether to log gradient norms
    """
    total_steps = epoch * batch_idx

    writer.add_scalar('Loss/batch', loss, total_steps)
    writer.add_scalar('Time/epoch', epoch_time, epoch)
    writer.add_scalar('Training/batch_size', batch_size, epoch)
    writer.add_scalar('Training/learning_rate', lr, epoch)

    if log_gradients:
        total_norm = 0.0
        for p in model.parameters():
            if p.grad is not None:
                param_norm = p.grad.detach().data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5
        writer.add_scalar('Gradients/total_norm', total_norm, total_steps)

def train_udit_flow(num_epochs=5000, initial_batch_sizes=[8, 16, 32, 64, 128], epoch_batch_drop_at=40, device="cuda", dtype=torch.float32):
    dataset = preload_dataset(device=device)
    temp_loader = DataLoader(dataset, batch_size=initial_batch_sizes[0], shuffle=True)
    first_batch = next(iter(temp_loader))
    image_shape = first_batch[0].shape[1:]

    writer = SummaryWriter('logs/current_run')

    model = AsymmetricResidualUDiT(
        in_channels=3,
        base_channels=128,
        num_levels=3,
        patch_size=4,
        encoder_blocks=3,
        decoder_blocks=7,
        encoder_transformer_thresh=2,
        decoder_transformer_thresh=4,
        mid_blocks=8
    ).to(device).to(dtype)
    model.train()

    count_parameters(model)
    optimizer = AdamWScheduleFree(
        model.parameters(),
        lr=1e-4,
        warmup_steps=100
    )
    optimizer.train()

    current_batch_sizes = initial_batch_sizes.copy()
    next_drop_epoch = epoch_batch_drop_at
    interval_multiplier = 2

    torch.set_float32_matmul_precision('high')
    model = torch.compile(
        model,
        backend='inductor',
        mode='max-autotune',
        fullgraph=True,
    )

    flow_transport = OptimalTransportLinearFlowGenerator(sigma_min=0.001)

    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        total_loss = 0

        # Batch size decay logic
        # The interval between drops grows each time: drops happen at epoch N, 3N, 6N, ... (N = epoch_batch_drop_at),
        # each time popping the largest remaining batch size off the list.
        if epoch > 0 and epoch == next_drop_epoch and len(current_batch_sizes) > 1:
            current_batch_sizes.pop()
            next_interval = epoch_batch_drop_at * interval_multiplier
            next_drop_epoch += next_interval
            interval_multiplier += 1
            print(f"\nEpoch {epoch}: Reducing batch size to {current_batch_sizes[-1]}")
            print(f"Next drop will occur at epoch {next_drop_epoch} (interval: {next_interval})")

        current_batch_size = current_batch_sizes[-1]
        dataloader = DataLoader(dataset, batch_size=current_batch_size, shuffle=True)
        curr_lr = optimizer.param_groups[0]['lr']

        with torch.amp.autocast('cuda', dtype=dtype):
            for batch_idx, batch in enumerate(dataloader):
                x1 = batch[0]
                batch_size = x1.shape[0]

                loss = flow_transport.loss(model, x1, device)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)

        epoch_time = time.time() - epoch_start_time
        print(f"Epoch {epoch}, Took: {epoch_time:.2f}s, Batch Size: {current_batch_size}, "
              f"Average Loss: {avg_loss:.4f}, Learning Rate: {curr_lr:.6f}")

        write_logs(writer, model, avg_loss, batch_idx, epoch, epoch_time, current_batch_size, curr_lr)

        if (epoch + 1) % 50 == 0:
            with torch.amp.autocast('cuda', dtype=dtype):
                sampling_start_time = time.time()
                samples = sample(model, device=device, dtype=dtype)
                os.makedirs("samples", exist_ok=True)
                vutils.save_image(samples, f"samples/epoch_{epoch}.png", nrow=4, padding=2)

                sample_time = time.time() - sampling_start_time
                print(f"Sampling took: {sample_time:.2f}s")

        if (epoch + 1) % 200 == 0:
            save_checkpoint(model, optimizer, f"step_{epoch}.safetensors")

    return model

def sample(model, n_samples=16, n_steps=50, image_size=256, device="cuda", sigma_min=0.001, dtype=torch.float32):
    with torch.amp.autocast('cuda', dtype=dtype):
        x = torch.randn(n_samples, 3, image_size, image_size, device=device)
        ts = torch.linspace(0, 1, n_steps, device=device)
        dt = 1 / n_steps

        # Forward Euler integration from t=0 to t=1
        with torch.no_grad():
            for i in range(len(ts)):
                t = ts[i]
                t_input = t.repeat(n_samples, 1, 1, 1)

                v_t = model(x, t_input)

                x = x + v_t * dt

        return x.float()

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    model = train_udit_flow(
        device=device,
        initial_batch_sizes=[8, 16],
        epoch_batch_drop_at=600,
        dtype=torch.float32
    )

    print("Training complete! Samples saved in 'samples' directory")
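For reference, OptimalTransportLinearFlowGenerator implements the optimal-transport conditional flow-matching objective from the linked paper (its equations 22 and 23): given data x1, noise x0 ~ N(0, I), and t ~ U[0, 1], it builds the interpolant x_t = (1 - (1 - sigma_min) t) x0 + t x1 and regresses the model output toward the constant conditional velocity x1 - (1 - sigma_min) x0, which is exactly d(x_t)/dt along that straight path. A standalone sketch of that identity (illustration only, not part of train.py):

import torch

sigma_min = 0.001
x0, x1 = torch.randn(3, 16, 16), torch.randn(3, 16, 16)   # noise sample and "data" sample
t = torch.rand(())

# OT path (eq. 22) and its target velocity (eq. 23), as computed in loss()
x_t = (1 - (1 - sigma_min) * t) * x0 + t * x1
target = x1 - (1 - sigma_min) * x0

# The path is linear in t, so a finite difference recovers the target (up to float error)
eps = 1e-3
x_t_eps = (1 - (1 - sigma_min) * (t + eps)) * x0 + (t + eps) * x1
assert torch.allclose((x_t_eps - x_t) / eps, target, atol=1e-3)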