# SecondFussion / train.py
# Training script for a caption-conditioned diffusion model on COCO train2017.
import os
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
from model import DiffusionModel, UNet
from torchvision.datasets import CocoCaptions
import argparse
from tqdm import tqdm
# Config
IMAGE_SIZE = 256  # square edge length every image is resized to
BATCH_SIZE = 16  # images per batch (caption expansion multiplies the effective size)
EPOCHS = 50  # full passes over the dataset
LR = 2e-5  # AdamW learning rate
TIMESTEPS = 1000  # diffusion steps in the noise schedule
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # train on GPU when available
def _collate_image_captions(batch):
    """Collate (image, captions) pairs: stack images, keep caption lists.

    Defined at module level (rather than as a lambda inside
    load_coco_dataset) so it can be pickled by DataLoader worker
    processes on platforms that use the 'spawn' start method
    (Windows, macOS) — lambdas are not picklable.
    """
    images = torch.stack([item[0] for item in batch])
    captions = [item[1] for item in batch]
    return images, captions


def load_coco_dataset():
    """Build a DataLoader over COCO train2017 image/caption pairs.

    Returns:
        DataLoader yielding (images, captions) where images is a stacked
        float tensor of shape (B, 3, IMAGE_SIZE, IMAGE_SIZE) normalized to
        [-1, 1], and captions is a list of per-image caption lists.
    """
    transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        # Map [0, 1] -> [-1, 1], the input range diffusion models expect
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    dataset = CocoCaptions(
        root='./train2017',
        annFile='./annotations/captions_train2017.json',
        transform=transform
    )
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        collate_fn=_collate_image_captions,
    )
def train():
    """Train the caption-conditioned diffusion model on COCO.

    Builds the UNet and diffusion wrapper, then runs EPOCHS passes over the
    dataset, optimizing the diffusion loss with AdamW. A state_dict
    checkpoint is written after every epoch.
    """
    # Setup
    model = UNet().to(DEVICE)
    # Linear noise schedule from 1e-4 to 0.02 over TIMESTEPS steps
    # (the standard DDPM defaults)
    betas = torch.linspace(1e-4, 0.02, TIMESTEPS).to(DEVICE)
    diffusion = DiffusionModel(model, betas, DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    dataloader = load_coco_dataset()

    model.train()  # ensure dropout/batchnorm are in training mode
    for epoch in range(EPOCHS):
        pbar = tqdm(dataloader)
        for images, captions in pbar:
            images = images.to(DEVICE)
            # COCO images have a *variable* number of captions (usually 5,
            # sometimes 6 or 7). Repeat each image once per caption instead
            # of assuming exactly 5, so images and captions stay aligned.
            counts = torch.tensor(
                [len(caps) for caps in captions], device=DEVICE
            )
            images = images.repeat_interleave(counts, dim=0)
            captions = [cap for caps in captions for cap in caps]
            # Sample an independent random timestep for every example
            t = torch.randint(
                0, TIMESTEPS, (images.shape[0],), device=DEVICE
            ).long()
            # Compute the diffusion training loss
            loss = diffusion.p_losses(images, captions, t)
            # Optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pbar.set_description(f"Epoch {epoch}, Loss: {loss.item():.4f}")
        # Checkpoint after every epoch so training can be resumed
        torch.save(model.state_dict(), f"diffusion_model_epoch_{epoch}.pth")
# Script entry point: start training when run directly (not on import).
if __name__ == "__main__":
    train()