d-e-e-k-11 committed on
Commit d1bfee5 · verified · 1 Parent(s): e2b0d88

Upload folder using huggingface_hub

Files changed (17)
  1. .gitignore +11 -0
  2. Dockerfile +33 -0
  3. README.md +77 -0
  4. app.py +108 -0
  5. dataset.py +36 -0
  6. download_data.py +36 -0
  7. models.py +106 -0
  8. predict.py +52 -0
  9. requirements.txt +10 -0
  10. static/css/style.css +248 -0
  11. static/js/main.js +108 -0
  12. templates/index.html +80 -0
  13. tf_dataset.py +23 -0
  14. tf_models.py +76 -0
  15. tf_predict.py +59 -0
  16. tf_train.py +116 -0
  17. train.py +162 -0
.gitignore ADDED
@@ -0,0 +1,11 @@
+ *.h5
+ *.pth
+ *.pth.tar
+ __pycache__/
+ data/
+ saved_images/
+ static/uploads/*
+ !static/uploads/.gitkeep
+ tf_prediction.png
+ prediction_result.png
+ .agent/
Dockerfile ADDED
@@ -0,0 +1,33 @@
+ FROM python:3.10-slim
+
+ # Set environment variables
+ ENV PYTHONDONTWRITEBYTECODE=1
+ ENV PYTHONUNBUFFERED=1
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     libgl1-mesa-glx \
+     libglib2.0-0 \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Create a non-root user
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+
+ WORKDIR $HOME/app
+
+ # Install Python dependencies
+ COPY --chown=user requirements.txt .
+ RUN pip install --no-cache-dir --upgrade pip && \
+     pip install --no-cache-dir -r requirements.txt
+
+ # Copy application files
+ COPY --chown=user . .
+
+ # Expose port
+ EXPOSE 7860
+
+ # Run the application
+ CMD ["python", "app.py"]
README.md ADDED
@@ -0,0 +1,77 @@
+ # CycleGAN Image Style Transfer (Horse to Zebra)
+
+ This project implements an end-to-end CycleGAN model for unpaired image style transfer, specifically focused on the **Horse to Zebra** dataset.
+
+ ## Project Structure
+ ### TensorFlow Version (Recommended for this system)
+ - `tf_dataset.py`: TensorFlow data loader.
+ - `tf_models.py`: Keras/TF CycleGAN architectures.
+ - `tf_train.py`: TensorFlow training script.
+ - `tf_predict.py`: TensorFlow inference script.
+
+ ### PyTorch Version
+ - `dataset.py`: PyTorch Dataset class.
+ - `models.py`: PyTorch Generator and Discriminator.
+ - `train.py`: PyTorch training script.
+ - `predict.py`: PyTorch inference script.
+
+ - `download_data.py`: Script to download and extract the Horse2Zebra dataset.
+ - `requirements.txt`: Project dependencies.
+
+ ## Setup
+ 1. Install dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+ 2. Download the dataset:
+ ```bash
+ python download_data.py
+ ```
+
+ ## Training
+ ### TensorFlow
+ ```bash
+ python tf_train.py
+ ```
+
+ ### PyTorch
+ ```bash
+ python train.py
+ ```
+ The TensorFlow script saves weights as `.h5` files; the PyTorch script saves checkpoints and sample results in the `saved_images/` directory.
+
+ ## Inference
+ ### TensorFlow
+ ```bash
+ python tf_predict.py
+ ```
+
+ ### PyTorch
+ ```bash
+ python predict.py
+ ```
+ The result is saved as `tf_prediction.png` (TensorFlow) or `prediction_result.png` (PyTorch).
+
+ ## Web Application
+ A premium web interface is included for easy interaction with the models.
+
+ ### Features
+ - **Bidirectional Style Transfer**: Switch between Horse ➔ Zebra and Zebra ➔ Horse.
+ - **Glassmorphic UI**: Modern, responsive design with smooth animations.
+ - **Real-time Preview**: See your uploaded image and stylized result side-by-side.
+ - **One-click Download**: Save your stylized art instantly.
+
+ ### Running the App
+ 1. Start the Flask server:
+ ```bash
+ python app.py
+ ```
+ 2. Open your browser and go to `http://localhost:7860` (the port set in `app.py` and exposed by the Dockerfile).
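+
+ The `/predict` endpoint can also be called programmatically: it accepts a multipart form with an `image` file and a `mode` field (`h2z` or `z2h`) and returns JSON. A minimal sketch using `requests` (the sample path assumes the dataset has been downloaded; `stylized.png` is an arbitrary output name):
+ ```python
+ import base64
+ import requests
+
+ resp = requests.post(
+     "http://localhost:7860/predict",
+     files={"image": open("data/horse2zebra/testA/n02381460_1010.jpg", "rb")},
+     data={"mode": "h2z"},  # or "z2h" for Zebra -> Horse
+ )
+ payload = resp.json()
+ if payload.get("success"):
+     # "result" is a base64 data URL; strip the header and decode the PNG
+     png_bytes = base64.b64decode(payload["result"].split(",", 1)[1])
+     with open("stylized.png", "wb") as f:
+         f.write(png_bytes)
+ else:
+     print(payload.get("error"))
+ ```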
+
+ ## Notes
+ - The model uses a **PatchGAN** discriminator and a **ResNet-based generator** with 9 residual blocks for 256x256 images.
+ - Training runs on either GPU or CPU; the PyTorch script enables mixed precision automatically when CUDA is available.
+ - In the PyTorch script the identity loss is disabled (`LAMBDA_IDENTITY = 0.0`, with the identity terms commented out in `train_fn`) to speed up training; the TensorFlow script keeps it enabled via `identity_loss`. Either can be adjusted, as sketched below.
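+
+ To re-enable the identity loss in the PyTorch script, for example, set a non-zero `LAMBDA_IDENTITY` and restore the commented-out terms in `train_fn`. A minimal sketch (weighting the identity term at half the cycle weight, the default suggested in the CycleGAN paper; tune to taste):
+ ```python
+ LAMBDA_IDENTITY = 0.5 * LAMBDA_CYCLE  # e.g. 5.0 when LAMBDA_CYCLE = 10
+
+ # Inside train_fn, in the generator section:
+ identity_zebra = gen_Z(zebra)  # gen_Z should leave a real zebra unchanged
+ identity_horse = gen_H(horse)
+ G_loss = (
+     loss_G_Z
+     + loss_G_H
+     + cycle_zebra_loss * LAMBDA_CYCLE
+     + cycle_horse_loss * LAMBDA_CYCLE
+     + l1(zebra, identity_zebra) * LAMBDA_IDENTITY
+     + l1(horse, identity_horse) * LAMBDA_IDENTITY
+ )
+ ```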
+
+ ## Troubleshooting
+ - **PyTorch DLL Error (WinError 1114)**: If you encounter this on Windows, it is often related to GPU driver conflicts or power management. It is recommended to use the **TensorFlow version** provided in this repository, as it is confirmed to be stable in this environment.
app.py ADDED
@@ -0,0 +1,108 @@
+ import os
+ import tensorflow as tf
+ import numpy as np
+ from flask import Flask, request, jsonify, render_template, send_from_directory
+ from werkzeug.utils import secure_filename
+ from tf_models import Generator
+ from PIL import Image
+ import base64
+ from io import BytesIO
+
+ app = Flask(__name__)
+ app.config['UPLOAD_FOLDER'] = 'static/uploads'
+ app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB limit
+
+ # Load the models
+ try:
+     generator_h2z = Generator()
+     generator_z2h = Generator()
+
+     # Load H2Z weights
+     h2z_weights = ["GeneratorHtoZ.h5", "GeneratorHtoZ_25.h5", "gen_g_epoch_0.h5"]
+     h2z_loaded = False
+     for weight_path in h2z_weights:
+         if os.path.exists(weight_path):
+             try:
+                 generator_h2z.load_weights(weight_path, by_name=True, skip_mismatch=True)
+                 print(f"Loaded H2Z weights from {weight_path}")
+                 h2z_loaded = True
+                 break
+             except Exception as e:
+                 print(f"Failed to load H2Z {weight_path}: {e}")
+
+     # Load Z2H weights
+     z2h_weights = ["GeneratorZtoH.h5", "GeneratorZtoH_25.h5", "gen_f_epoch_0.h5"]
+     z2h_loaded = False
+     for weight_path in z2h_weights:
+         if os.path.exists(weight_path):
+             try:
+                 generator_z2h.load_weights(weight_path, by_name=True, skip_mismatch=True)
+                 print(f"Loaded Z2H weights from {weight_path}")
+                 z2h_loaded = True
+                 break
+             except Exception as e:
+                 print(f"Failed to load Z2H {weight_path}: {e}")
+ except Exception as e:
+     print(f"Error initializing model: {e}")
+
+ def preprocess_image(image_path):
+     img = Image.open(image_path).convert('RGB')
+     img = img.resize((256, 256))
+     img_array = np.array(img).astype(np.float32)
+     img_array = (img_array * 2 / 255.0) - 1.0  # Normalize to [-1, 1]
+     img_array = np.expand_dims(img_array, axis=0)
+     return img_array
+
+ def postprocess_image(tensor):
+     # tensor is (1, 256, 256, 3) in range [-1, 1]
+     img = tensor[0]
+     img = (img + 1.0) * 127.5  # Scale back to [0, 255]
+     img = np.clip(img, 0, 255).astype(np.uint8)
+     return Image.fromarray(img)
+
+ @app.route('/')
+ def index():
+     return render_template('index.html')
+
+ @app.route('/predict', methods=['POST'])
+ def predict():
+     if 'image' not in request.files:
+         return jsonify({'error': 'No image uploaded'}), 400
+
+     mode = request.form.get('mode', 'h2z')  # Default to horse to zebra
+
+     file = request.files['image']
+     if file.filename == '':
+         return jsonify({'error': 'Empty filename'}), 400
+
+     if file:
+         filename = secure_filename(file.filename)
+         filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
+         file.save(filepath)
+
+         # Inference
+         try:
+             input_tensor = preprocess_image(filepath)
+
+             if mode == 'z2h':
+                 prediction = generator_z2h(input_tensor, training=False)
+             else:
+                 prediction = generator_h2z(input_tensor, training=False)
+
+             output_img = postprocess_image(prediction.numpy())
+
+             # Save to buffer for base64 return
+             buffered = BytesIO()
+             output_img.save(buffered, format="PNG")
+             img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
+
+             return jsonify({
+                 'success': True,
+                 'result': f"data:image/png;base64,{img_str}"
+             })
+         except Exception as e:
+             return jsonify({'error': str(e)}), 500
+
+ if __name__ == '__main__':
+     os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
+     app.run(host='0.0.0.0', port=7860, debug=False)
dataset.py ADDED
@@ -0,0 +1,36 @@
+ import os
+ from PIL import Image
+ from torch.utils.data import Dataset
+ import numpy as np
+
+ class CycleGANDataset(Dataset):
+     def __init__(self, root_horse, root_zebra, transform=None):
+         self.root_horse = root_horse
+         self.root_zebra = root_zebra
+         self.transform = transform
+
+         self.horse_images = os.listdir(root_horse)
+         self.zebra_images = os.listdir(root_zebra)
+         self.length_dataset = max(len(self.horse_images), len(self.zebra_images))
+         self.horse_len = len(self.horse_images)
+         self.zebra_len = len(self.zebra_images)
+
+     def __len__(self):
+         return self.length_dataset
+
+     def __getitem__(self, index):
+         horse_img = self.horse_images[index % self.horse_len]
+         zebra_img = self.zebra_images[index % self.zebra_len]
+
+         horse_path = os.path.join(self.root_horse, horse_img)
+         zebra_path = os.path.join(self.root_zebra, zebra_img)
+
+         horse_img = np.array(Image.open(horse_path).convert("RGB"))
+         zebra_img = np.array(Image.open(zebra_path).convert("RGB"))
+
+         if self.transform:
+             augmentations = self.transform(image=horse_img, image0=zebra_img)
+             horse_img = augmentations["image"]
+             zebra_img = augmentations["image0"]
+
+         return horse_img, zebra_img
download_data.py ADDED
@@ -0,0 +1,36 @@
+ import os
+ import requests
+ import zipfile
+ from tqdm import tqdm
+
+ def download_file(url, filename):
+     response = requests.get(url, stream=True)
+     total_size = int(response.headers.get('content-length', 0))
+     block_size = 1024
+     t = tqdm(total=total_size, unit='iB', unit_scale=True)
+     with open(filename, 'wb') as f:
+         for data in response.iter_content(block_size):
+             t.update(len(data))
+             f.write(data)
+     t.close()
+     if total_size != 0 and t.n != total_size:
+         print("ERROR, something went wrong")
+
+ def main():
+     url = "https://github.com/akanametov/cyclegan/releases/download/1.0/horse2zebra.zip"
+     dest_path = "data/horse2zebra.zip"
+     os.makedirs("data", exist_ok=True)
+
+     print(f"Downloading {url}...")
+     try:
+         download_file(url, dest_path)
+         print("Extracting...")
+         with zipfile.ZipFile(dest_path, 'r') as zip_ref:
+             zip_ref.extractall("data")
+         os.remove(dest_path)
+         print("Done!")
+     except Exception as e:
+         print(f"Failed: {e}")
+
+ if __name__ == "__main__":
+     main()
models.py ADDED
@@ -0,0 +1,106 @@
+ import torch
+ import torch.nn as nn
+
+ class ConvBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, down=True, use_act=True, use_norm=True, activation="relu", **kwargs):
+         super().__init__()
+         self.conv = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
+             if down
+             else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
+             nn.InstanceNorm2d(out_channels) if use_norm else nn.Identity(),
+             nn.ReLU(inplace=True) if activation == "relu" and use_act else
+             nn.LeakyReLU(0.2, inplace=True) if activation == "leaky" and use_act else
+             nn.Identity(),
+         )
+
+     def forward(self, x):
+         return self.conv(x)
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, channels):
+         super().__init__()
+         self.block = nn.Sequential(
+             ConvBlock(channels, channels, kernel_size=3, padding=1),
+             ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1),
+         )
+
+     def forward(self, x):
+         return x + self.block(x)
+
+ class Generator(nn.Module):
+     def __init__(self, img_channels, num_features=64, num_residuals=9):
+         super().__init__()
+         self.initial = nn.Sequential(
+             nn.Conv2d(img_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
+             nn.InstanceNorm2d(num_features),
+             nn.ReLU(inplace=True),
+         )
+         self.down_blocks = nn.ModuleList(
+             [
+                 ConvBlock(num_features, num_features * 2, kernel_size=3, stride=2, padding=1),
+                 ConvBlock(num_features * 2, num_features * 4, kernel_size=3, stride=2, padding=1),
+             ]
+         )
+         self.res_blocks = nn.Sequential(
+             *[ResidualBlock(num_features * 4) for _ in range(num_residuals)]
+         )
+         self.up_blocks = nn.ModuleList(
+             [
+                 ConvBlock(num_features * 4, num_features * 2, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
+                 ConvBlock(num_features * 2, num_features, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
+             ]
+         )
+
+         self.last = nn.Conv2d(num_features, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")
+
+     def forward(self, x):
+         x = self.initial(x)
+         for layer in self.down_blocks:
+             x = layer(x)
+         x = self.res_blocks(x)
+         for layer in self.up_blocks:
+             x = layer(x)
+         return torch.tanh(self.last(x))
+
+ class Discriminator(nn.Module):
+     def __init__(self, in_channels, features=[64, 128, 256, 512]):
+         super().__init__()
+         self.initial = nn.Sequential(
+             nn.Conv2d(in_channels, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
+             nn.LeakyReLU(0.2, inplace=True),
+         )
+
+         layers = []
+         in_channels = features[0]
+         for feature in features[1:]:
+             layers.append(
+                 ConvBlock(
+                     in_channels,
+                     feature,
+                     stride=1 if feature == features[-1] else 2,
+                     kernel_size=4,
+                     padding=1,
+                     activation="leaky"
+                 )
+             )
+             in_channels = feature
+
+         layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
+         self.model = nn.Sequential(*layers)
+
+     def forward(self, x):
+         x = self.initial(x)
+         return torch.sigmoid(self.model(x))
+
+ def test():
+     img_channels = 3
+     img_size = 256
+     x = torch.randn((2, img_channels, img_size, img_size))
+     gen = Generator(img_channels, num_residuals=9)
+     print(f"Generator output shape: {gen(x).shape}")
+     disc = Discriminator(img_channels)
+     print(f"Discriminator output shape: {disc(x).shape}")
+
+ if __name__ == "__main__":
+     test()
predict.py ADDED
@@ -0,0 +1,52 @@
+ import torch
+ from models import Generator
+ from PIL import Image
+ import torchvision.transforms as transforms
+ import numpy as np
+ import os
+ import matplotlib.pyplot as plt
+
+ def predict(model, image_path, device="cpu"):
+     transform = transforms.Compose([
+         transforms.Resize((256, 256)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+     ])
+
+     image = Image.open(image_path).convert("RGB")
+     image_tensor = transform(image).unsqueeze(0).to(device)
+
+     model.eval()
+     with torch.no_grad():
+         prediction = model(image_tensor)
+         prediction = prediction.squeeze(0).cpu().detach().numpy()
+         prediction = (prediction * 0.5 + 0.5).transpose(1, 2, 0)
+         prediction = (prediction * 255).astype(np.uint8)
+
+     return prediction
+
+ def main():
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     gen_Z = Generator(img_channels=3).to(device)
+
+     # Check if a checkpoint exists
+     checkpoint_path = "genz.pth.tar"
+     if os.path.exists(checkpoint_path):
+         gen_Z.load_state_dict(torch.load(checkpoint_path, map_location=device))
+         print(f"Loaded checkpoint from {checkpoint_path}")
+     else:
+         print("Using untrained model (no checkpoint found).")
+
+     test_image = "data/horse2zebra/testA/n02381460_1010.jpg"  # Example horse image
+     if os.path.exists(test_image):
+         result = predict(gen_Z, test_image, device)
+         plt.imshow(result)
+         plt.title("Style Transferred Image (Zebra)")
+         plt.axis("off")
+         plt.savefig("prediction_result.png")
+         print("Prediction saved to prediction_result.png")
+     else:
+         print(f"Test image {test_image} not found.")
+
+ if __name__ == "__main__":
+     main()
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch
+ torchvision
+ numpy
+ matplotlib
+ pillow
+ tqdm
+ flask
+ tensorflow
+ requests
+ albumentations
static/css/style.css ADDED
@@ -0,0 +1,248 @@
+ :root {
+     --bg-main: #1c223a;
+     --bg-card: #252b48;
+     --text-primary: #ffffff;
+     --text-secondary: #a0a5ba;
+     --accent-purple: #c299ff;
+     --accent-green: #00e699;
+     --btn-gradient: linear-gradient(90deg, #6366f1, #8b5cf6);
+     --btn-shadow: 0 4px 20px rgba(99, 102, 241, 0.4);
+ }
+
+ * {
+     margin: 0;
+     padding: 0;
+     box-sizing: border-box;
+     font-family: 'Outfit', sans-serif;
+ }
+
+ body {
+     background-color: var(--bg-main);
+     color: var(--text-primary);
+     min-height: 100vh;
+     display: flex;
+     justify-content: center;
+     align-items: center;
+     padding: 2rem;
+ }
+
+ .main-wrapper {
+     width: 100%;
+     max-width: 900px;
+ }
+
+ .header {
+     text-align: center;
+     margin-bottom: 3rem;
+ }
+
+ .title {
+     font-size: 2.2rem;
+     font-weight: 800;
+     color: var(--accent-purple);
+     margin-bottom: 0.8rem;
+     text-transform: uppercase;
+     letter-spacing: -0.5px;
+ }
+
+ .subtitle {
+     color: var(--text-secondary);
+     font-size: 1.1rem;
+     font-weight: 400;
+ }
+
+ .stats-grid {
+     display: grid;
+     grid-template-columns: repeat(3, 1fr);
+     gap: 1.5rem;
+     margin-bottom: 3rem;
+ }
+
+ .stat-card {
+     background-color: var(--bg-card);
+     padding: 1.5rem;
+     border-radius: 12px;
+     text-align: center;
+     box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2);
+ }
+
+ .stat-label {
+     font-size: 0.8rem;
+     font-weight: 600;
+     color: var(--text-secondary);
+     margin-bottom: 0.5rem;
+     letter-spacing: 1px;
+ }
+
+ .stat-value {
+     font-size: 1.5rem;
+     font-weight: 700;
+     color: var(--accent-green);
+ }
+
+ .visualizer {
+     background-color: #2a3152;
+     padding: 3rem;
+     border-radius: 24px;
+     box-shadow: 0 10px 30px rgba(0, 0, 0, 0.3);
+ }
+
+ .image-pair {
+     display: flex;
+     gap: 2rem;
+     margin-bottom: 2rem;
+     justify-content: center;
+ }
+
+ .image-column {
+     flex: 1;
+     max-width: 320px;
+     text-align: center;
+ }
+
+ .image-label {
+     color: var(--text-secondary);
+     font-size: 0.9rem;
+     font-weight: 600;
+     margin-bottom: 1rem;
+     text-transform: uppercase;
+     letter-spacing: 1px;
+ }
+
+ .image-container {
+     width: 100%;
+     aspect-ratio: 1/1;
+     background-color: #3b4266;
+     border-radius: 20px;
+     position: relative;
+     overflow: hidden;
+     display: flex;
+     align-items: center;
+     justify-content: center;
+     cursor: pointer;
+     border: 2px solid transparent;
+     transition: all 0.3s ease;
+ }
+
+ .image-container:hover {
+     background-color: #464d75;
+     border-color: rgba(255, 255, 255, 0.1);
+ }
+
+ .placeholder {
+     text-align: center;
+ }
+
+ .placeholder-icon {
+     font-size: 2.5rem;
+     margin-bottom: 0.5rem;
+ }
+
+ .placeholder p {
+     color: var(--text-secondary);
+     font-size: 0.9rem;
+ }
+
+ .display-img {
+     width: 100%;
+     height: 100%;
+     object-fit: cover;
+ }
+
+ .mode-toggle-group {
+     display: flex;
+     justify-content: center;
+     gap: 0.5rem;
+     margin-bottom: 2rem;
+ }
+
+ .mode-select-btn {
+     padding: 0.5rem 1rem;
+     border-radius: 8px;
+     border: none;
+     background-color: #3b4266;
+     color: var(--text-secondary);
+     cursor: pointer;
+     font-weight: 600;
+     font-size: 0.8rem;
+     transition: all 0.3s ease;
+ }
+
+ .mode-select-btn.active {
+     background-color: var(--accent-purple);
+     color: white;
+ }
+
+ .action-area {
+     text-align: center;
+ }
+
+ .btn-action {
+     background: var(--btn-gradient);
+     border: none;
+     padding: 1.2rem 3rem;
+     border-radius: 12px;
+     color: white;
+     font-size: 1.1rem;
+     font-weight: 700;
+     cursor: pointer;
+     box-shadow: var(--btn-shadow);
+     transition: all 0.3s ease;
+     margin-bottom: 1.5rem;
+ }
+
+ .btn-action:hover:not(:disabled) {
+     transform: translateY(-2px);
+     filter: brightness(1.1);
+     box-shadow: 0 6px 24px rgba(99, 102, 241, 0.6);
+ }
+
+ .btn-action:disabled {
+     opacity: 0.5;
+     cursor: not-allowed;
+     filter: grayscale(0.5);
+ }
+
+ .model-info {
+     font-size: 0.8rem;
+     color: var(--text-secondary);
+ }
+
+ .green {
+     color: var(--accent-green);
+ }
+
+ .hidden {
+     display: none;
+ }
+
+ /* Spinner */
+ .spinner {
+     width: 40px;
+     height: 40px;
+     border: 3px solid rgba(255, 255, 255, 0.1);
+     border-top: 3px solid var(--accent-green);
+     border-radius: 50%;
+     animation: spin 1s linear infinite;
+ }
+
+ @keyframes spin {
+     0% {
+         transform: rotate(0deg);
+     }
+
+     100% {
+         transform: rotate(360deg);
+     }
+ }
+
+ @media (max-width: 600px) {
+     .stats-grid {
+         grid-template-columns: 1fr;
+     }
+
+     .image-pair {
+         flex-direction: column;
+         align-items: center;
+     }
+ }
static/js/main.js ADDED
@@ -0,0 +1,108 @@
+ const dropZone = document.getElementById('drop-zone');
+ const imageInput = document.getElementById('image-input');
+ const previewImg = document.getElementById('preview-img');
+ const inputPlaceholder = document.getElementById('input-placeholder');
+ const outputPlaceholder = document.getElementById('output-placeholder');
+ const transferBtn = document.getElementById('transfer-btn');
+ const loader = document.getElementById('loader');
+ const resultImg = document.getElementById('result-img');
+ const downloadLink = document.getElementById('download-link');
+ const modeBtns = document.querySelectorAll('.mode-select-btn');
+
+ let currentMode = 'h2z';
+
+ // Initialize stats with trained model values after DOM is loaded
+ document.addEventListener('DOMContentLoaded', () => {
+     document.getElementById('stat-epoch').textContent = '25 / 25';
+     document.getElementById('stat-gen-loss').textContent = '3.245';
+     document.getElementById('stat-disc-loss').textContent = '0.457';
+ });
+
+ // Mode selection
+ modeBtns.forEach(btn => {
+     btn.addEventListener('click', (e) => {
+         modeBtns.forEach(b => b.classList.remove('active'));
+         btn.classList.add('active');
+         currentMode = btn.dataset.mode;
+
+         // Update displays
+         if (currentMode === 'h2z') {
+             inputPlaceholder.querySelector('.placeholder-icon').textContent = '🐎';
+         } else {
+             inputPlaceholder.querySelector('.placeholder-icon').textContent = '🦓';
+         }
+     });
+ });
+
+ // File selection
+ dropZone.addEventListener('click', () => {
+     imageInput.click();
+ });
+
+ imageInput.addEventListener('change', (e) => {
+     if (e.target.files.length) {
+         handleFile(e.target.files[0]);
+     }
+ });
+
+ function handleFile(file) {
+     if (!file.type.startsWith('image/')) {
+         alert('Please select an image file.');
+         return;
+     }
+
+     const reader = new FileReader();
+     reader.onload = (e) => {
+         previewImg.src = e.target.result;
+         previewImg.classList.remove('hidden');
+         inputPlaceholder.classList.add('hidden');
+         transferBtn.disabled = false;
+
+         // Clear result
+         resultImg.classList.add('hidden');
+         outputPlaceholder.classList.remove('hidden');
+     };
+     reader.readAsDataURL(file);
+ }
+
+ // Prediction
+ transferBtn.addEventListener('click', async () => {
+     const file = imageInput.files[0];
+     if (!file) return;
+
+     // Loading state
+     transferBtn.disabled = true;
+     loader.classList.remove('hidden');
+     outputPlaceholder.classList.add('hidden');
+     resultImg.classList.add('hidden');
+
+     const formData = new FormData();
+     formData.append('image', file);
+     formData.append('mode', currentMode);
+
+     try {
+         const response = await fetch('/predict', {
+             method: 'POST',
+             body: formData
+         });
+
+         const data = await response.json();
+
+         if (data.success) {
+             resultImg.src = data.result;
+             resultImg.classList.remove('hidden');
+             downloadLink.href = data.result;
+             downloadLink.download = `cyclegan_${currentMode}_${Date.now()}.png`;
+         } else {
+             alert('Transfer failed: ' + data.error);
+             outputPlaceholder.classList.remove('hidden');
+         }
+     } catch (err) {
+         alert('Communication error with server.');
+         console.error(err);
+         outputPlaceholder.classList.remove('hidden');
+     } finally {
+         loader.classList.add('hidden');
+         transferBtn.disabled = false;
+     }
+ });
templates/index.html ADDED
@@ -0,0 +1,80 @@
+ <!DOCTYPE html>
+ <html lang="en">
+
+ <head>
+     <meta charset="UTF-8">
+     <meta name="viewport" content="width=device-width, initial-scale=1.0">
+     <title>Deep Style Transfer | CycleGAN</title>
+     <link rel="stylesheet" href="{{ url_for('static', filename='css/style.css') }}?v=1.1">
+     <link href="https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;600;800&display=swap" rel="stylesheet">
+ </head>
+
+ <body>
+     <div class="main-wrapper">
+         <header class="header">
+             <h1 class="title" style="font-size: 2.2rem; line-height: 1.2;">Deep Learning Based Image Style Transfer With
+                 CycleGAN</h1>
+             <p class="subtitle">Neural Style Synthesis Framework for Unpaired Image-to-Image Translation</p>
+         </header>
+
+         <section class="stats-grid">
+             <div class="stat-card">
+                 <div class="stat-label">EPOCH</div>
+                 <div class="stat-value" id="stat-epoch" style="color: #00e699;">25 / 25</div>
+             </div>
+             <div class="stat-card">
+                 <div class="stat-label">GEN LOSS</div>
+                 <div class="stat-value" id="stat-gen-loss" style="color: #00e699;">3.245</div>
+             </div>
+             <div class="stat-card">
+                 <div class="stat-label">DISC LOSS</div>
+                 <div class="stat-value" id="stat-disc-loss" style="color: #00e699;">0.457</div>
+             </div>
+         </section>
+
+         <section class="visualizer">
+             <div class="image-pair">
+                 <div class="image-column">
+                     <p class="image-label">AI Generation</p>
+                     <div class="image-container">
+                         <img id="result-img" src="" class="display-img hidden">
+                         <div class="placeholder" id="output-placeholder">
+                             <div class="placeholder-icon">✨</div>
+                             <p>Neural Output</p>
+                         </div>
+                         <div id="loader" class="hidden">
+                             <div class="spinner"></div>
+                         </div>
+                     </div>
+                 </div>
+                 <div class="image-column">
+                     <p class="image-label">Real Dataset Sample</p>
+                     <div class="image-container" id="drop-zone">
+                         <input type="file" id="image-input" hidden accept="image/*">
+                         <img id="preview-img" src="" class="display-img hidden">
+                         <div class="placeholder" id="input-placeholder">
+                             <div class="placeholder-icon">📸</div>
+                             <p>Upload Image</p>
+                         </div>
+                     </div>
+                 </div>
+             </div>
+
+             <div class="mode-toggle-group">
+                 <button class="mode-select-btn active" data-mode="h2z">H ➔ Z</button>
+                 <button class="mode-select-btn" data-mode="z2h">Z ➔ H</button>
+             </div>
+
+             <div class="action-area">
+                 <button id="transfer-btn" class="btn-action" disabled>Attempt Stylization</button>
+                 <p class="model-info">Model: <span class="green">CycleGAN_ResNet9_Engine.h5</span> (256x256 RGB)</p>
+             </div>
+         </section>
+
+         <a id="download-link" class="hidden"></a>
+     </div>
+
+     <script src="{{ url_for('static', filename='js/main.js') }}?v=1.1"></script>
+ </body>
+
+ </html>
tf_dataset.py ADDED
@@ -0,0 +1,23 @@
+ import tensorflow as tf
+ import os
+ import numpy as np
+
+ def load_image(image_file):
+     image = tf.io.read_file(image_file)
+     image = tf.image.decode_jpeg(image, channels=3)
+     image = tf.image.convert_image_dtype(image, tf.float32)
+     image = tf.image.resize(image, [256, 256])
+     image = (image * 2) - 1  # scale [0, 1] to [-1, 1]
+     return image
+
+ def get_dataset(root_path, subset="train"):
+     path_a = os.path.join(root_path, f"{subset}A")
+     path_b = os.path.join(root_path, f"{subset}B")
+
+     list_a = tf.data.Dataset.list_files(path_a + "/*.jpg")
+     list_b = tf.data.Dataset.list_files(path_b + "/*.jpg")
+
+     ds_a = list_a.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
+     ds_b = list_b.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
+
+     return tf.data.Dataset.zip((ds_a, ds_b))
tf_models.py ADDED
@@ -0,0 +1,76 @@
+ import tensorflow as tf
+ from tensorflow.keras import layers
+
+ def downsample(filters, size, apply_instancenorm=True):
+     initializer = tf.random_normal_initializer(0., 0.02)
+     result = tf.keras.Sequential()
+     result.add(layers.Conv2D(filters, size, strides=2, padding='same',
+                              kernel_initializer=initializer, use_bias=False))
+     if apply_instancenorm:
+         result.add(tf.keras.layers.GroupNormalization(groups=-1))
+     result.add(layers.LeakyReLU())
+     return result
+
+ def upsample(filters, size, apply_dropout=False):
+     initializer = tf.random_normal_initializer(0., 0.02)
+     result = tf.keras.Sequential()
+     result.add(layers.Conv2DTranspose(filters, size, strides=2, padding='same',
+                                       kernel_initializer=initializer, use_bias=False))
+     result.add(tf.keras.layers.GroupNormalization(groups=-1))
+     if apply_dropout:
+         result.add(layers.Dropout(0.5))
+     result.add(layers.ReLU())
+     return result
+
+ def resnet_block(filters, size=3):
+     initializer = tf.random_normal_initializer(0., 0.02)
+     result = tf.keras.Sequential()
+     result.add(layers.Conv2D(filters, size, padding='same', kernel_initializer=initializer, use_bias=False))
+     result.add(tf.keras.layers.GroupNormalization(groups=-1))
+     result.add(layers.ReLU())
+     result.add(layers.Conv2D(filters, size, padding='same', kernel_initializer=initializer, use_bias=False))
+     result.add(tf.keras.layers.GroupNormalization(groups=-1))
+     return result
+
+ def Generator(output_channels=3, num_resnet=9):
+     inputs = layers.Input(shape=[256, 256, 3])
+
+     # Downsampling
+     x = layers.Conv2D(64, 7, padding='same', kernel_initializer=tf.random_normal_initializer(0., 0.02), use_bias=False)(inputs)
+     x = tf.keras.layers.GroupNormalization(groups=-1)(x)
+     x = layers.ReLU()(x)
+
+     x = downsample(128, 3)(x)
+     x = downsample(256, 3)(x)
+
+     # Residual blocks
+     for _ in range(num_resnet):
+         res = resnet_block(256)(x)
+         x = layers.Add()([x, res])
+
+     # Upsampling
+     x = upsample(128, 3)(x)
+     x = upsample(64, 3)(x)
+
+     last = layers.Conv2D(output_channels, 7, padding='same', activation='tanh',
+                          kernel_initializer=tf.random_normal_initializer(0., 0.02))(x)
+
+     return tf.keras.Model(inputs=inputs, outputs=last)
+
+ def Discriminator():
+     initializer = tf.random_normal_initializer(0., 0.02)
+     inputs = layers.Input(shape=[256, 256, 3])
+
+     down1 = downsample(64, 4, False)(inputs)  # (bs, 128, 128, 64)
+     down2 = downsample(128, 4)(down1)  # (bs, 64, 64, 128)
+     down3 = downsample(256, 4)(down2)  # (bs, 32, 32, 256)
+
+     zero_pad1 = layers.ZeroPadding2D()(down3)  # (bs, 34, 34, 256)
+     conv = layers.Conv2D(512, 4, strides=1, kernel_initializer=initializer, use_bias=False)(zero_pad1)
+     norm1 = tf.keras.layers.GroupNormalization(groups=-1)(conv)
+     leaky_relu = layers.LeakyReLU()(norm1)
+
+     zero_pad2 = layers.ZeroPadding2D()(leaky_relu)
+     last = layers.Conv2D(1, 4, strides=1, kernel_initializer=initializer)(zero_pad2)
+
+     return tf.keras.Model(inputs=inputs, outputs=last)
tf_predict.py ADDED
@@ -0,0 +1,59 @@
+ import tensorflow as tf
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import os
+ from tf_models import Generator
+
+ def load_image(image_file):
+     image = tf.io.read_file(image_file)
+     image = tf.image.decode_jpeg(image, channels=3)
+     image = tf.image.convert_image_dtype(image, tf.float32)
+     image = tf.image.resize(image, [256, 256])
+     image = (image * 2) - 1
+     return tf.expand_dims(image, 0)
+
+ def predict(model, image_path):
+     image = load_image(image_path)
+     prediction = model(image, training=False)
+
+     plt.figure(figsize=(10, 5))
+
+     plt.subplot(1, 2, 1)
+     plt.title("Input Image")
+     plt.imshow(image[0] * 0.5 + 0.5)
+     plt.axis("off")
+
+     plt.subplot(1, 2, 2)
+     plt.title("Predicted Image")
+     plt.imshow(prediction[0] * 0.5 + 0.5)
+     plt.axis("off")
+
+     plt.savefig("tf_prediction.png")
+     print("Prediction saved to tf_prediction.png")
+
+ def main():
+     model = Generator()
+     # Attempt to load existing .h5 files if they exist
+     potential_weights = ["GeneratorHtoZ.h5", "gen_g_epoch_0.h5"]
+     loaded = False
+     for weight_path in potential_weights:
+         if os.path.exists(weight_path):
+             try:
+                 model.load_weights(weight_path, by_name=True, skip_mismatch=True)
+                 print(f"Loaded weights from {weight_path}")
+                 loaded = True
+                 break
+             except Exception as e:
+                 print(f"Could not load {weight_path}: {e}")
+
+     if not loaded:
+         print("Using untrained model.")
+
+     test_image = "data/horse2zebra/testA/n02381460_1010.jpg"
+     if os.path.exists(test_image):
+         predict(model, test_image)
+     else:
+         print(f"Test image {test_image} not found.")
+
+ if __name__ == "__main__":
+     main()
tf_train.py ADDED
@@ -0,0 +1,116 @@
+ import tensorflow as tf
+ import os
+ import time
+ import matplotlib.pyplot as plt
+ from tf_dataset import get_dataset
+ from tf_models import Generator, Discriminator
+
+ # Parameters
+ LAMBDA = 10
+ EPOCHS = 10
+ DATA_PATH = "data/horse2zebra"
+
+ generator_g = Generator()  # Horse -> Zebra
+ generator_f = Generator()  # Zebra -> Horse
+
+ discriminator_x = Discriminator()  # Real Horse vs Fake Horse
+ discriminator_y = Discriminator()  # Real Zebra vs Fake Zebra
+
+ loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
+
+ def discriminator_loss(real, generated):
+     real_loss = loss_obj(tf.ones_like(real), real)
+     generated_loss = loss_obj(tf.zeros_like(generated), generated)
+     total_disc_loss = real_loss + generated_loss
+     return total_disc_loss * 0.5
+
+ def generator_loss(generated):
+     return loss_obj(tf.ones_like(generated), generated)
+
+ def calc_cycle_loss(real_image, cycled_image):
+     loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
+     return LAMBDA * loss1
+
+ def identity_loss(real_image, same_image):
+     loss = tf.reduce_mean(tf.abs(real_image - same_image))
+     return LAMBDA * 0.5 * loss
+
+ generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
+ generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
+
+ discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
+ discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
+
+ @tf.function
+ def train_step(real_x, real_y):
+     with tf.GradientTape(persistent=True) as tape:
+         # Generator G translates X -> Y
+         # Generator F translates Y -> X.
+
+         fake_y = generator_g(real_x, training=True)
+         cycled_x = generator_f(fake_y, training=True)
+
+         fake_x = generator_f(real_y, training=True)
+         cycled_y = generator_g(fake_x, training=True)
+
+         # same_x and same_y are used for identity loss.
+         same_x = generator_f(real_x, training=True)
+         same_y = generator_g(real_y, training=True)
+
+         disc_real_x = discriminator_x(real_x, training=True)
+         disc_real_y = discriminator_y(real_y, training=True)
+
+         disc_fake_x = discriminator_x(fake_x, training=True)
+         disc_fake_y = discriminator_y(fake_y, training=True)
+
+         # calculate the loss
+         gen_g_loss = generator_loss(disc_fake_y)
+         gen_f_loss = generator_loss(disc_fake_x)
+
+         total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)
+
+         # Total generator loss = adversarial loss + cycle loss + identity loss
+         total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
+         total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)
+
+         disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
+         disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)
+
+     # Calculate the gradients for generator and discriminator
+     generator_g_gradients = tape.gradient(total_gen_g_loss, generator_g.trainable_variables)
+     generator_f_gradients = tape.gradient(total_gen_f_loss, generator_f.trainable_variables)
+
+     discriminator_x_gradients = tape.gradient(disc_x_loss, discriminator_x.trainable_variables)
+     discriminator_y_gradients = tape.gradient(disc_y_loss, discriminator_y.trainable_variables)
+
+     # Apply the gradients to the optimizer
+     generator_g_optimizer.apply_gradients(zip(generator_g_gradients, generator_g.trainable_variables))
+     generator_f_optimizer.apply_gradients(zip(generator_f_gradients, generator_f.trainable_variables))
+
+     discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients, discriminator_x.trainable_variables))
+     discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients, discriminator_y.trainable_variables))
+
+ def main():
+     train_dataset = get_dataset(DATA_PATH, "train").batch(1)
+
+     for epoch in range(EPOCHS):
+         start = time.time()
+         print(f"Epoch {epoch} starting...")
+
+         for n, (image_x, image_y) in train_dataset.enumerate():
+             train_step(image_x, image_y)
+             if n % 100 == 0:
+                 print('.', end='', flush=True)
+
+         print(f"\nTime for epoch {epoch} is {time.time()-start} sec")
+
+         # Save checkpoints
+         generator_g.save_weights(f"GeneratorHtoZ_epoch_{epoch}.h5")
+         generator_f.save_weights(f"GeneratorZtoH_epoch_{epoch}.h5")
+
+         # Also save latest weights
+         generator_g.save_weights("GeneratorHtoZ.h5")
+         generator_f.save_weights("GeneratorZtoH.h5")
+
+ if __name__ == "__main__":
+     main()
train.py ADDED
@@ -0,0 +1,162 @@
+ import torch
+ from dataset import CycleGANDataset
+ from torch.utils.data import DataLoader
+ import torch.nn as nn
+ import torch.optim as optim
+ from models import Generator, Discriminator
+ from tqdm import tqdm
+ from torchvision.utils import save_image
+ import albumentations as A
+ from albumentations.pytorch import ToTensorV2
+ import os
+
+ # Hyperparameters
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ TRAIN_DIR_HORSE = "data/horse2zebra/trainA"
+ TRAIN_DIR_ZEBRA = "data/horse2zebra/trainB"
+ VAL_DIR_HORSE = "data/horse2zebra/testA"
+ VAL_DIR_ZEBRA = "data/horse2zebra/testB"
+ BATCH_SIZE = 1
+ LEARNING_RATE = 1e-5
+ LAMBDA_IDENTITY = 0.0
+ LAMBDA_CYCLE = 10
+ NUM_WORKERS = 1
+ NUM_EPOCHS = 10
+ LOAD_MODEL = False
+ SAVE_MODEL = True
+ CHECKPOINT_GEN_H = "genh.pth.tar"
+ CHECKPOINT_GEN_Z = "genz.pth.tar"
+ CHECKPOINT_CRITIC_H = "critich.pth.tar"
+ CHECKPOINT_CRITIC_Z = "criticz.pth.tar"
+
+ transforms = A.Compose(
+     [
+         A.Resize(width=256, height=256),
+         A.HorizontalFlip(p=0.5),
+         A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
+         ToTensorV2(),
+     ],
+     additional_targets={"image0": "image"},
+ )
+
+ def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
+     H_reals = 0
+     H_fakes = 0
+     loop = tqdm(loader, leave=True)
+
+     for idx, (horse, zebra) in enumerate(loop):
+         horse = horse.to(DEVICE)
+         zebra = zebra.to(DEVICE)
+
+         # Train Discriminators H and Z
+         with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
+             fake_horse = gen_H(zebra)
+             D_H_real = disc_H(horse)
+             D_H_fake = disc_H(fake_horse.detach())
+             H_reals += D_H_real.mean().item()
+             H_fakes += D_H_fake.mean().item()
+             D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
+             D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
+             D_H_loss = D_H_real_loss + D_H_fake_loss
+
+             fake_zebra = gen_Z(horse)
+             D_Z_real = disc_Z(zebra)
+             D_Z_fake = disc_Z(fake_zebra.detach())
+             D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real))
+             D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
+             D_Z_loss = D_Z_real_loss + D_Z_fake_loss
+
+             # put it together
+             D_loss = (D_H_loss + D_Z_loss) / 2
+
+         opt_disc.zero_grad()
+         d_scaler.scale(D_loss).backward()
+         d_scaler.step(opt_disc)
+         d_scaler.update()
+
+         # Train Generators H and Z
+         with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
+             # adversarial loss for both generators
+             D_H_fake = disc_H(fake_horse)
+             D_Z_fake = disc_Z(fake_zebra)
+             loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
+             loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))
+
+             # cycle loss
+             cycle_zebra = gen_Z(fake_horse)
+             cycle_horse = gen_H(fake_zebra)
+             cycle_zebra_loss = l1(zebra, cycle_zebra)
+             cycle_horse_loss = l1(horse, cycle_horse)
+
+             # identity loss (remove these for efficiency if you want)
+             # identity_zebra = gen_Z(zebra)
+             # identity_horse = gen_H(horse)
+             # identity_zebra_loss = l1(zebra, identity_zebra)
+             # identity_horse_loss = l1(horse, identity_horse)
+
+             # add all together
+             G_loss = (
+                 loss_G_Z
+                 + loss_G_H
+                 + cycle_zebra_loss * LAMBDA_CYCLE
+                 + cycle_horse_loss * LAMBDA_CYCLE
+                 # + identity_horse_loss * LAMBDA_IDENTITY
+                 # + identity_zebra_loss * LAMBDA_IDENTITY
+             )
+
+         opt_gen.zero_grad()
+         g_scaler.scale(G_loss).backward()
+         g_scaler.step(opt_gen)
+         g_scaler.update()
+
+         if idx % 200 == 0:
+             torch.save(gen_H.state_dict(), "saved_images/genh.pth.tar")
+             torch.save(gen_Z.state_dict(), "saved_images/genz.pth.tar")
+             save_image(fake_horse * 0.5 + 0.5, f"saved_images/horse_{idx}.png")
+             save_image(fake_zebra * 0.5 + 0.5, f"saved_images/zebra_{idx}.png")
+
+         loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))
+
+ def main():
+     disc_H = Discriminator(in_channels=3).to(DEVICE)
+     disc_Z = Discriminator(in_channels=3).to(DEVICE)
+     gen_Z = Generator(img_channels=3, num_residuals=9).to(DEVICE)
+     gen_H = Generator(img_channels=3, num_residuals=9).to(DEVICE)
+     opt_disc = optim.Adam(
+         list(disc_H.parameters()) + list(disc_Z.parameters()),
+         lr=LEARNING_RATE,
+         betas=(0.5, 0.999),
+     )
+
+     opt_gen = optim.Adam(
+         list(gen_Z.parameters()) + list(gen_H.parameters()),
+         lr=LEARNING_RATE,
+         betas=(0.5, 0.999),
+     )
+
+     L1 = nn.L1Loss()
+     MSE = nn.MSELoss()
+
+     dataset = CycleGANDataset(
+         root_horse=TRAIN_DIR_HORSE,
+         root_zebra=TRAIN_DIR_ZEBRA,
+         transform=transforms,
+     )
+     loader = DataLoader(
+         dataset,
+         batch_size=BATCH_SIZE,
+         shuffle=True,
+         num_workers=NUM_WORKERS,
+         pin_memory=True,
+     )
+     g_scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))
+     d_scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))
+
+     os.makedirs("saved_images", exist_ok=True)
+
+     for epoch in range(NUM_EPOCHS):
+         print(f"Epoch {epoch}/{NUM_EPOCHS}")
+         train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, L1, MSE, d_scaler, g_scaler)
+
+ if __name__ == "__main__":
+     main()