| """ |
| train.py β Train your mini-style-transfer model |
| |
| Usage: |
| python train.py --style starry_night.jpg --output starry_night.pth |
| |
| What this script does: |
| 1. Loads your style image (the painting) |
| 2. Loops over MS-COCO images (content images β everyday photos) |
| 3. For each photo: runs it through StyleNet, compares result to style |
| 4. Updates model weights so outputs look more like the style painting |
| 5. Saves your trained model as a .pth file |
| |
| Beginner tip: Think of training as teaching the model by example. |
| You show it thousands of photos and say "make them look like Van Gogh". |
| After enough examples, it learns to do it on its own. |
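
Once training finishes, you can load the weights back like this (a minimal
sketch; it assumes StyleNet in model.py is unchanged since training):

    import torch
    from model import StyleNet

    model = StyleNet()
    model.load_state_dict(torch.load("starry_night.pth", map_location="cpu"))
    model.eval()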
| """ |
|
|
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import argparse

from model import StyleNet
|
|
IMAGE_SIZE = 256   # training resolution; content images are resized/cropped to this
BATCH_SIZE = 4
EPOCHS = 2
LR = 1e-3          # Adam learning rate
CONTENT_W = 1.0    # weight on the content (photo structure) loss
STYLE_W = 1e5      # Gram-matrix values are tiny, hence the large style weight
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class ImageFolderDataset(Dataset):
| """Loads all images from a folder. Use MS-COCO train2017 images.""" |
| def __init__(self, folder, transform): |
| self.paths = [ |
| os.path.join(folder, f) for f in os.listdir(folder) |
| if f.lower().endswith(('.jpg', '.jpeg', '.png')) |
| ] |
| self.transform = transform |
|
|
| def __len__(self): |
| return len(self.paths) |
|
|
| def __getitem__(self, idx): |
| img = Image.open(self.paths[idx]).convert("RGB") |
| return self.transform(img) |
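
# A quick usage sketch (mirrors what train() does below; "coco/train2017"
# is just an example path):
#
#   ds = ImageFolderDataset("coco/train2017", transform)
#   loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)
#   batch = next(iter(loader))   # tensor of shape (BATCH_SIZE, 3, 256, 256)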


class VGGLoss(nn.Module):
    """Frozen VGG-16 feature extractor (despite the name, the actual losses
    are computed in train() from the features returned here)."""
    def __init__(self):
        super().__init__()
        vgg = models.vgg16(weights=models.VGG16_Weights.DEFAULT).features
        # The three slices end at relu1_2, relu2_2 and relu3_3 respectively.
        self.slice1 = nn.Sequential(*list(vgg)[:4]).eval()
        self.slice2 = nn.Sequential(*list(vgg)[4:9]).eval()
        self.slice3 = nn.Sequential(*list(vgg)[9:16]).eval()
        # The loss network stays fixed; only StyleNet gets trained.
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, x):
        h1 = self.slice1(x)
        h2 = self.slice2(h1)
        h3 = self.slice3(h2)
        return h1, h2, h3
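
# For a (B, 3, 256, 256) input batch, the returned feature maps have shapes
# (B, 64, 256, 256), (B, 128, 128, 128) and (B, 256, 64, 64).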
|
|
def gram_matrix(feat):
    """Style is captured as correlations between feature maps (Gram matrix)."""
    B, C, H, W = feat.shape
    feat = feat.view(B, C, H * W)
    return torch.bmm(feat, feat.transpose(1, 2)) / (C * H * W)
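
# Example: a (4, 128, 128, 128) batch of relu2_2 features yields a
# (4, 128, 128) Gram matrix - one C x C map of channel correlations per
# image, with the spatial layout averaged away (texture, not structure).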


def train(style_image_path, content_folder, output_path):
| print(f"Device: {DEVICE}") |
| print(f"Style: {style_image_path}") |
| print(f"Output: {output_path}\n") |
|
|
| transform = transforms.Compose([ |
| transforms.Resize(IMAGE_SIZE), |
| transforms.CenterCrop(IMAGE_SIZE), |
| transforms.ToTensor(), |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], |
| std=[0.229, 0.224, 0.225]), |
| ]) |
|
|
| |
| style_img = transform(Image.open(style_image_path).convert("RGB")) |
| style_img = style_img.unsqueeze(0).to(DEVICE) |
|
|
| dataset = ImageFolderDataset(content_folder, transform) |
| loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2) |
|
|
| model = StyleNet().to(DEVICE) |
| vgg = VGGLoss().to(DEVICE) |
| optimizer = optim.Adam(model.parameters(), lr=LR) |
| mse = nn.MSELoss() |
|
|
| |
| with torch.no_grad(): |
| s1, s2, s3 = vgg(style_img) |
| style_grams = [gram_matrix(s1), gram_matrix(s2), gram_matrix(s3)] |
|
|
| print(f"Training on {len(dataset)} images for {EPOCHS} epochs...") |
| print("β" * 50) |
|
|
    for epoch in range(EPOCHS):
        for i, content in enumerate(loader):
            content = content.to(DEVICE)
            optimizer.zero_grad()

            # Forward pass: stylize this batch of photos.
            styled = model(content)

            # One VGG pass for the styled output; the content features need
            # no gradients, so they are computed under no_grad.
            o1, o2, o3 = vgg(styled)
            with torch.no_grad():
                _, c_feat, _ = vgg(content)

            # Content loss: keep the photo's structure at relu2_2.
            content_loss = mse(o2, c_feat)

            # Style loss: match Gram matrices at all three layers. expand()
            # broadcasts the single style Gram to the current batch size.
            batch = content.size(0)
            style_loss = (
                mse(gram_matrix(o1), style_grams[0].expand(batch, -1, -1)) +
                mse(gram_matrix(o2), style_grams[1].expand(batch, -1, -1)) +
                mse(gram_matrix(o3), style_grams[2].expand(batch, -1, -1))
            )

            loss = CONTENT_W * content_loss + STYLE_W * style_loss
            loss.backward()
            optimizer.step()

            if i % 100 == 0:
                print(f"Epoch {epoch+1}/{EPOCHS} Batch {i:4d}/{len(loader)}"
                      f" Loss: {loss.item():.2f}"
                      f" (content {content_loss.item():.3f}"
                      f" style {style_loss.item():.2f})")
|
|
    torch.save(model.state_dict(), output_path)
    print(f"\nDone! Model saved to: {output_path}")
    print(f"Upload to HuggingFace: huggingface-cli upload your-username/mini-style-transfer {output_path}")
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--style", required=True, help="Path to your style painting image")
    parser.add_argument("--content", default="coco/", help="Folder of training photos (MS-COCO)")
    parser.add_argument("--output", default="style_model.pth", help="Output .pth file name")
    args = parser.parse_args()
    train(args.style, args.content, args.output)
|
|