|
|
import os |
|
|
import torch |
|
|
import torchvision |
|
|
from torch.utils.data import DataLoader |
|
|
from torchvision import transforms |
|
|
from model import DiffusionModel, UNet |
|
|
from torchvision.datasets import CocoCaptions |
|
|
import argparse |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
# Training hyperparameters and runtime configuration.

# Side length, in pixels, that every image is resized to (images become square).
IMAGE_SIZE = 256

# Number of dataset samples (images) drawn per DataLoader batch.
BATCH_SIZE = 16

# Number of full passes over the training set.
EPOCHS = 50

# Learning rate handed to the AdamW optimizer.
LR = 2e-5

# Number of entries in the beta (noise) schedule; timesteps are sampled
# uniformly from [0, TIMESTEPS).
TIMESTEPS = 1000

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
def _collate_coco_batch(batch):
    """Collate ``(image, captions)`` samples into ``(image_batch, captions)``.

    Image tensors are stacked along a new leading batch dimension. The
    caption entries are returned as a plain Python list, one entry per
    sample, without being converted to tensors.

    This is a module-level function rather than a lambda so that it can be
    pickled when DataLoader workers are started with the ``spawn`` method.
    """
    images = torch.stack([sample[0] for sample in batch])
    captions = [sample[1] for sample in batch]
    return images, captions


def load_coco_dataset():
    """Build the shuffled training DataLoader over the COCO 2017 captions set.

    Each image is resized to ``IMAGE_SIZE`` x ``IMAGE_SIZE``, converted to a
    tensor, and normalized per channel with mean 0.5 and std 0.5.

    The dataset is read from ``./train2017`` (images) and
    ``./annotations/captions_train2017.json`` (annotations), both relative
    to the current working directory.

    Returns:
        A ``DataLoader`` yielding ``(images, captions)`` tuples, where
        ``images`` is the stacked image tensor for the batch and
        ``captions`` is a list holding each sample's caption entry.
    """
    transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    dataset = CocoCaptions(
        root='./train2017',
        annFile='./annotations/captions_train2017.json',
        transform=transform,
    )

    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        # A named function (not a lambda) keeps the loader picklable for
        # worker processes on platforms that use the "spawn" start method.
        collate_fn=_collate_coco_batch,
    )
    return dataloader
|
|
|
|
|
def train():
    """Train the diffusion model on COCO image/caption pairs.

    Builds the UNet, a linear beta schedule of ``TIMESTEPS`` values between
    1e-4 and 0.02, the diffusion wrapper, and an AdamW optimizer, then runs
    ``EPOCHS`` passes over the DataLoader from ``load_coco_dataset``.

    Every image is paired with each of its captions: the image tensor is
    repeated once per caption and the caption lists are flattened in the
    same order, so entry ``i`` of the image batch matches caption ``i``.

    Side effects:
        Writes ``diffusion_model_epoch_{epoch}.pth`` (the UNet state dict)
        to the current working directory.
    """
    model = UNet().to(DEVICE)
    betas = torch.linspace(1e-4, 0.02, TIMESTEPS).to(DEVICE)
    diffusion = DiffusionModel(model, betas, DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    dataloader = load_coco_dataset()

    for epoch in range(EPOCHS):
        pbar = tqdm(dataloader)
        for images, captions in pbar:
            images = images.to(DEVICE)

            # Repeat each image by its own caption count instead of a fixed 5:
            # COCO images do not all have exactly five captions, and a fixed
            # factor would misalign images with the flattened caption list.
            repeats = torch.tensor(
                [len(caps) for caps in captions], device=images.device
            )
            images = images.repeat_interleave(repeats, dim=0)
            captions = [cap for caps in captions for cap in caps]

            # One random diffusion timestep per (image, caption) pair.
            t = torch.randint(0, TIMESTEPS, (images.shape[0],), device=DEVICE)

            loss = diffusion.p_losses(images, captions, t)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_description(f"Epoch {epoch}, Loss: {loss.item():.4f}")

        # NOTE(review): the original indentation was lost; the epoch-numbered
        # filename suggests one checkpoint per epoch -- confirm.
        torch.save(model.state_dict(), f"diffusion_model_epoch_{epoch}.pth")
|
|
|
|
|
# Start training only when this file is executed as a script, so importing
# the module does not kick off a training run.
if __name__ == "__main__":
    train()