Upload folder using huggingface_hub
- .gitignore +11 -0
- Dockerfile +33 -0
- README.md +77 -0
- app.py +108 -0
- dataset.py +36 -0
- download_data.py +36 -0
- models.py +106 -0
- predict.py +52 -0
- requirements.txt +6 -0
- static/css/style.css +248 -0
- static/js/main.js +108 -0
- templates/index.html +80 -0
- tf_dataset.py +23 -0
- tf_models.py +76 -0
- tf_predict.py +59 -0
- tf_train.py +116 -0
- train.py +162 -0
.gitignore
ADDED
@@ -0,0 +1,11 @@
+*.h5
+*.pth
+*.pth.tar
+__pycache__/
+data/
+saved_images/
+static/uploads/*
+!static/uploads/.gitkeep
+tf_prediction.png
+prediction_result.png
+.agent/
Dockerfile
ADDED
@@ -0,0 +1,33 @@
+FROM python:3.10-slim
+
+# Set environment variables
+ENV PYTHONDONTWRITEBYTECODE 1
+ENV PYTHONUNBUFFERED 1
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    libgl1-mesa-glx \
+    libglib2.0-0 \
+    && rm -rf /var/lib/apt/lists/*
+
+# Create user
+RUN useradd -m -u 1000 user
+USER user
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH
+
+WORKDIR $HOME/app
+
+# Install python dependencies
+COPY --chown=user requirements.txt .
+RUN pip install --no-cache-dir --upgrade pip && \
+    pip install --no-cache-dir -r requirements.txt
+
+# Copy application files
+COPY --chown=user . .
+
+# Expose port
+EXPOSE 7860
+
+# Run the application
+CMD ["python", "app.py"]
README.md
ADDED
@@ -0,0 +1,77 @@
+# CycleGAN Image Style Transfer (Horse to Zebra)
+
+This project implements an end-to-end CycleGAN model for unpaired image style transfer, specifically focused on the **Horse to Zebra** dataset.
+
+## Project Structure
+### TensorFlow Version (Recommended for this system)
+- `tf_dataset.py`: TensorFlow data loader.
+- `tf_models.py`: Keras/TF CycleGAN architectures.
+- `tf_train.py`: TensorFlow training script.
+- `tf_predict.py`: TensorFlow inference script.
+
+### PyTorch Version
+- `dataset.py`: PyTorch Dataset class.
+- `models.py`: PyTorch Generator and Discriminator.
+- `train.py`: PyTorch training script.
+- `predict.py`: PyTorch inference script.
+
+- `download_data.py`: Script to download and extract the Horse2Zebra dataset.
+- `requirements.txt`: Project dependencies.
+
+## Setup
+1. Install dependencies:
+```bash
+pip install -r requirements.txt
+```
+2. Download the dataset:
+```bash
+python download_data.py
+```
+
+## Training
+### TensorFlow
+```bash
+python tf_train.py
+```
+
+### PyTorch
+```bash
+python train.py
+```
+Checkpoints and sample results will be saved in the `saved_images/` directory or as `.h5` files.
+
+## Inference
+### TensorFlow
+```bash
+python tf_predict.py
+```
+
+### PyTorch
+```bash
+python predict.py
+```
+The result will be saved as `tf_prediction.png` or `prediction_result.png`.
+
+## Web Application
+A premium web interface is included for easy interaction with the models.
+
+### Features
+- **Bidirectional Style Transfer**: Switch between Horse ➔ Zebra and Zebra ➔ Horse.
+- **Glassmorphic UI**: Modern, responsive design with smooth animations.
+- **Real-time Preview**: See your uploaded image and stylized result side-by-side.
+- **One-click Download**: Save your stylized art instantly.
+
+### Running the App
+1. Start the Flask server:
+```bash
+python app.py
+```
+2. Open your browser and go to `http://localhost:7860` (the port `app.py` binds to).
+
+## Notes
+- The model uses **PatchGAN** for the discriminator and a **ResNet-based generator** with 9 residual blocks for 256x256 images.
+- Training is optimized for both GPU and CPU.
+- The identity loss is currently set to 0 to speed up training, but it can be adjusted in the training scripts (via `LAMBDA_IDENTITY` in `train.py` or `identity_loss` in `tf_train.py`); see the sketch just after this file.
+
+## Troubleshooting
+- **PyTorch DLL Error (WinError 1114)**: If you encounter this on Windows, it is often related to GPU driver conflicts or power management. It is recommended to use the **TensorFlow version** provided in this repository, as it is confirmed to be stable in this environment.
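The identity-loss note above maps directly onto the block that `train.py` (later in this diff) ships commented out. Here is a minimal sketch of re-enabling it; it assumes the names from that script (`gen_Z`, `gen_H`, `l1`, `horse`, `zebra`, `loss_G_*`, `cycle_*_loss`, `LAMBDA_CYCLE`) and is illustrative, not part of the committed files:

```python
# Sketch: the identity term from the CycleGAN paper, as it would slot into
# train_fn in train.py. All names come from that script; the commit ships
# LAMBDA_IDENTITY = 0.0, which disables the term entirely.
LAMBDA_IDENTITY = 0.5 * LAMBDA_CYCLE  # paper's choice: half the cycle weight

identity_zebra = gen_Z(zebra)   # G_Z should leave a real zebra unchanged
identity_horse = gen_H(horse)   # G_H should leave a real horse unchanged
identity_zebra_loss = l1(zebra, identity_zebra)
identity_horse_loss = l1(horse, identity_horse)

G_loss = (
    loss_G_Z
    + loss_G_H
    + cycle_zebra_loss * LAMBDA_CYCLE
    + cycle_horse_loss * LAMBDA_CYCLE
    + identity_horse_loss * LAMBDA_IDENTITY
    + identity_zebra_loss * LAMBDA_IDENTITY
)
```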
app.py
ADDED
@@ -0,0 +1,108 @@
+import os
+import tensorflow as tf
+import numpy as np
+from flask import Flask, request, jsonify, render_template, send_from_directory
+from werkzeug.utils import secure_filename
+from tf_models import Generator
+from PIL import Image
+import base64
+from io import BytesIO
+
+app = Flask(__name__)
+app.config['UPLOAD_FOLDER'] = 'static/uploads'
+app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB limit
+
+# Load the models
+try:
+    generator_h2z = Generator()
+    generator_z2h = Generator()
+
+    # Load H2Z weights
+    h2z_weights = ["GeneratorHtoZ.h5", "GeneratorHtoZ_25.h5", "gen_g_epoch_0.h5"]
+    h2z_loaded = False
+    for weight_path in h2z_weights:
+        if os.path.exists(weight_path):
+            try:
+                generator_h2z.load_weights(weight_path, by_name=True, skip_mismatch=True)
+                print(f"Loaded H2Z weights from {weight_path}")
+                h2z_loaded = True
+                break
+            except Exception as e:
+                print(f"Failed to load H2Z {weight_path}: {e}")
+
+    # Load Z2H weights
+    z2h_weights = ["GeneratorZtoH.h5", "GeneratorZtoH_25.h5", "gen_f_epoch_0.h5"]
+    z2h_loaded = False
+    for weight_path in z2h_weights:
+        if os.path.exists(weight_path):
+            try:
+                generator_z2h.load_weights(weight_path, by_name=True, skip_mismatch=True)
+                print(f"Loaded Z2H weights from {weight_path}")
+                z2h_loaded = True
+                break
+            except Exception as e:
+                print(f"Failed to load Z2H {weight_path}: {e}")
+except Exception as e:
+    print(f"Error initializing model: {e}")
+
+def preprocess_image(image_path):
+    img = Image.open(image_path).convert('RGB')
+    img = img.resize((256, 256))
+    img_array = np.array(img).astype(np.float32)
+    img_array = (img_array * 2 / 255.0) - 1.0  # Normalize to [-1, 1]
+    img_array = np.expand_dims(img_array, axis=0)
+    return img_array
+
+def postprocess_image(tensor):
+    # tensor is (1, 256, 256, 3) in range [-1, 1]
+    img = tensor[0]
+    img = (img + 1.0) * 127.5  # Scale back to [0, 255]
+    img = np.clip(img, 0, 255).astype(np.uint8)
+    return Image.fromarray(img)
+
+@app.route('/')
+def index():
+    return render_template('index.html')
+
+@app.route('/predict', methods=['POST'])
+def predict():
+    if 'image' not in request.files:
+        return jsonify({'error': 'No image uploaded'}), 400
+
+    mode = request.form.get('mode', 'h2z')  # Default to horse to zebra
+
+    file = request.files['image']
+    if file.filename == '':
+        return jsonify({'error': 'Empty filename'}), 400
+
+    if file:
+        filename = secure_filename(file.filename)
+        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
+        file.save(filepath)
+
+        # Inference
+        try:
+            input_tensor = preprocess_image(filepath)
+
+            if mode == 'z2h':
+                prediction = generator_z2h(input_tensor, training=False)
+            else:
+                prediction = generator_h2z(input_tensor, training=False)
+
+            output_img = postprocess_image(prediction.numpy())
+
+            # Save to buffer for base64 return
+            buffered = BytesIO()
+            output_img.save(buffered, format="PNG")
+            img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
+
+            return jsonify({
+                'success': True,
+                'result': f"data:image/png;base64,{img_str}"
+            })
+        except Exception as e:
+            return jsonify({'error': str(e)}), 500
+
+if __name__ == '__main__':
+    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
+    app.run(host='0.0.0.0', port=7860, debug=False)
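For completeness, a minimal client for the `/predict` route defined above — a sketch, assuming the server is running on port 7860, using the third-party `requests` package and a hypothetical local file `horse.jpg` (neither is part of this commit):

```python
# Sketch: calling the Flask /predict endpoint from app.py.
import base64
import requests

with open("horse.jpg", "rb") as f:  # hypothetical test image
    resp = requests.post(
        "http://localhost:7860/predict",
        files={"image": f},
        data={"mode": "h2z"},  # or "z2h" for zebra-to-horse
    )
resp.raise_for_status()
payload = resp.json()

# app.py returns the result as a data URL: "data:image/png;base64,<...>"
b64 = payload["result"].split(",", 1)[1]
with open("stylized.png", "wb") as out:
    out.write(base64.b64decode(b64))
```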
dataset.py
ADDED
@@ -0,0 +1,36 @@
+import os
+from PIL import Image
+from torch.utils.data import Dataset
+import numpy as np
+
+class CycleGANDataset(Dataset):
+    def __init__(self, root_horse, root_zebra, transform=None):
+        self.root_horse = root_horse
+        self.root_zebra = root_zebra
+        self.transform = transform
+
+        self.horse_images = os.listdir(root_horse)
+        self.zebra_images = os.listdir(root_zebra)
+        self.length_dataset = max(len(self.horse_images), len(self.zebra_images))
+        self.horse_len = len(self.horse_images)
+        self.zebra_len = len(self.zebra_images)
+
+    def __len__(self):
+        return self.length_dataset
+
+    def __getitem__(self, index):
+        horse_img = self.horse_images[index % self.horse_len]
+        zebra_img = self.zebra_images[index % self.zebra_len]
+
+        horse_path = os.path.join(self.root_horse, horse_img)
+        zebra_path = os.path.join(self.root_zebra, zebra_img)
+
+        horse_img = np.array(Image.open(horse_path).convert("RGB"))
+        zebra_img = np.array(Image.open(zebra_path).convert("RGB"))
+
+        if self.transform:
+            augmentations = self.transform(image=horse_img, image0=zebra_img)
+            horse_img = augmentations["image"]
+            zebra_img = augmentations["image0"]
+
+        return horse_img, zebra_img
download_data.py
ADDED
@@ -0,0 +1,36 @@
+import os
+import requests
+import zipfile
+from tqdm import tqdm
+
+def download_file(url, filename):
+    response = requests.get(url, stream=True)
+    total_size = int(response.headers.get('content-length', 0))
+    block_size = 1024
+    t = tqdm(total=total_size, unit='iB', unit_scale=True)
+    with open(filename, 'wb') as f:
+        for data in response.iter_content(block_size):
+            t.update(len(data))
+            f.write(data)
+    t.close()
+    if total_size != 0 and t.n != total_size:
+        print("ERROR, something went wrong")
+
+def main():
+    url = "https://github.com/akanametov/cyclegan/releases/download/1.0/horse2zebra.zip"
+    dest_path = "data/horse2zebra.zip"
+    os.makedirs("data", exist_ok=True)
+
+    print(f"Downloading {url}...")
+    try:
+        download_file(url, dest_path)
+        print("Extracting...")
+        with zipfile.ZipFile(dest_path, 'r') as zip_ref:
+            zip_ref.extractall("data")
+        os.remove(dest_path)
+        print("Done!")
+    except Exception as e:
+        print(f"Failed: {e}")
+
+if __name__ == "__main__":
+    main()
models.py
ADDED
@@ -0,0 +1,106 @@
+import torch
+import torch.nn as nn
+
+class ConvBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, down=True, use_act=True, use_norm=True, activation="relu", **kwargs):
+        super().__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
+            if down
+            else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
+            nn.InstanceNorm2d(out_channels) if use_norm else nn.Identity(),
+            nn.ReLU(inplace=True) if activation == "relu" and use_act else
+            nn.LeakyReLU(0.2, inplace=True) if activation == "leaky" and use_act else
+            nn.Identity(),
+        )
+
+    def forward(self, x):
+        return self.conv(x)
+
+class ResidualBlock(nn.Module):
+    def __init__(self, channels):
+        super().__init__()
+        self.block = nn.Sequential(
+            ConvBlock(channels, channels, kernel_size=3, padding=1),
+            ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1),
+        )
+
+    def forward(self, x):
+        return x + self.block(x)
+
+class Generator(nn.Module):
+    def __init__(self, img_channels, num_features=64, num_residuals=9):
+        super().__init__()
+        self.initial = nn.Sequential(
+            nn.Conv2d(img_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
+            nn.InstanceNorm2d(num_features),
+            nn.ReLU(inplace=True),
+        )
+        self.down_blocks = nn.ModuleList(
+            [
+                ConvBlock(num_features, num_features * 2, kernel_size=3, stride=2, padding=1),
+                ConvBlock(num_features * 2, num_features * 4, kernel_size=3, stride=2, padding=1),
+            ]
+        )
+        self.res_blocks = nn.Sequential(
+            *[ResidualBlock(num_features * 4) for _ in range(num_residuals)]
+        )
+        self.up_blocks = nn.ModuleList(
+            [
+                ConvBlock(num_features * 4, num_features * 2, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
+                ConvBlock(num_features * 2, num_features, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
+            ]
+        )
+
+        self.last = nn.Conv2d(num_features, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")
+
+    def forward(self, x):
+        x = self.initial(x)
+        for layer in self.down_blocks:
+            x = layer(x)
+        x = self.res_blocks(x)
+        for layer in self.up_blocks:
+            x = layer(x)
+        return torch.tanh(self.last(x))
+
+class Discriminator(nn.Module):
+    def __init__(self, in_channels, features=[64, 128, 256, 512]):
+        super().__init__()
+        self.initial = nn.Sequential(
+            nn.Conv2d(in_channels, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
+            nn.LeakyReLU(0.2, inplace=True),
+        )
+
+        layers = []
+        in_channels = features[0]
+        for feature in features[1:]:
+            layers.append(
+                ConvBlock(
+                    in_channels,
+                    feature,
+                    stride=1 if feature == features[-1] else 2,
+                    kernel_size=4,
+                    padding=1,
+                    activation="leaky"
+                )
+            )
+            in_channels = feature
+
+        layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
+        self.model = nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.initial(x)
+        return torch.sigmoid(self.model(x))
+
+def test():
+    img_channels = 3
+    img_size = 256
+    x = torch.randn((2, img_channels, img_size, img_size))
+    gen = Generator(img_channels, num_residuals=9)
+    print(f"Generator output shape: {gen(x).shape}")
+    disc = Discriminator(img_channels)
+    print(f"Discriminator output shape: {disc(x).shape}")
+
+if __name__ == "__main__":
+    test()
predict.py
ADDED
@@ -0,0 +1,52 @@
+import torch
+from models import Generator
+from PIL import Image
+import torchvision.transforms as transforms
+import numpy as np
+import os
+import matplotlib.pyplot as plt
+
+def predict(model, image_path, device="cpu"):
+    transform = transforms.Compose([
+        transforms.Resize((256, 256)),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+    ])
+
+    image = Image.open(image_path).convert("RGB")
+    image_tensor = transform(image).unsqueeze(0).to(device)
+
+    model.eval()
+    with torch.no_grad():
+        prediction = model(image_tensor)
+        prediction = prediction.squeeze(0).cpu().detach().numpy()
+        prediction = (prediction * 0.5 + 0.5).transpose(1, 2, 0)
+        prediction = (prediction * 255).astype(np.uint8)
+
+    return prediction
+
+def main():
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    gen_Z = Generator(img_channels=3).to(device)
+
+    # Check if a checkpoint exists
+    checkpoint_path = "genz.pth.tar"
+    if os.path.exists(checkpoint_path):
+        gen_Z.load_state_dict(torch.load(checkpoint_path, map_location=device))
+        print(f"Loaded checkpoint from {checkpoint_path}")
+    else:
+        print("Using untrained model (no checkpoint found).")
+
+    test_image = "data/horse2zebra/testA/n02381460_1010.jpg"  # Example horse image
+    if os.path.exists(test_image):
+        result = predict(gen_Z, test_image, device)
+        plt.imshow(result)
+        plt.title("Style Transferred Image (Zebra)")
+        plt.axis("off")
+        plt.savefig("prediction_result.png")
+        print("Prediction saved to prediction_result.png")
+    else:
+        print(f"Test image {test_image} not found.")
+
+if __name__ == "__main__":
+    main()
requirements.txt
ADDED
@@ -0,0 +1,6 @@
+torch
+torchvision
+numpy
+matplotlib
+pillow
+tqdm
static/css/style.css
ADDED
@@ -0,0 +1,248 @@
+:root {
+    --bg-main: #1c223a;
+    --bg-card: #252b48;
+    --text-primary: #ffffff;
+    --text-secondary: #a0a5ba;
+    --accent-purple: #c299ff;
+    --accent-green: #00e699;
+    --btn-gradient: linear-gradient(90deg, #6366f1, #8b5cf6);
+    --btn-shadow: 0 4px 20px rgba(99, 102, 241, 0.4);
+}
+
+* {
+    margin: 0;
+    padding: 0;
+    box-sizing: border-box;
+    font-family: 'Outfit', sans-serif;
+}
+
+body {
+    background-color: var(--bg-main);
+    color: var(--text-primary);
+    min-height: 100vh;
+    display: flex;
+    justify-content: center;
+    align-items: center;
+    padding: 2rem;
+}
+
+.main-wrapper {
+    width: 100%;
+    max-width: 900px;
+}
+
+.header {
+    text-align: center;
+    margin-bottom: 3rem;
+}
+
+.title {
+    font-size: 2.2rem;
+    font-weight: 800;
+    color: var(--accent-purple);
+    margin-bottom: 0.8rem;
+    text-transform: uppercase;
+    letter-spacing: -0.5px;
+}
+
+.subtitle {
+    color: var(--text-secondary);
+    font-size: 1.1rem;
+    font-weight: 400;
+}
+
+.stats-grid {
+    display: grid;
+    grid-template-columns: repeat(3, 1fr);
+    gap: 1.5rem;
+    margin-bottom: 3rem;
+}
+
+.stat-card {
+    background-color: var(--bg-card);
+    padding: 1.5rem;
+    border-radius: 12px;
+    text-align: center;
+    box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2);
+}
+
+.stat-label {
+    font-size: 0.8rem;
+    font-weight: 600;
+    color: var(--text-secondary);
+    margin-bottom: 0.5rem;
+    letter-spacing: 1px;
+}
+
+.stat-value {
+    font-size: 1.5rem;
+    font-weight: 700;
+    color: var(--accent-green);
+}
+
+.visualizer {
+    background-color: #2a3152;
+    padding: 3rem;
+    border-radius: 24px;
+    box-shadow: 0 10px 30px rgba(0, 0, 0, 0.3);
+}
+
+.image-pair {
+    display: flex;
+    gap: 2rem;
+    margin-bottom: 2rem;
+    justify-content: center;
+}
+
+.image-column {
+    flex: 1;
+    max-width: 320px;
+    text-align: center;
+}
+
+.image-label {
+    color: var(--text-secondary);
+    font-size: 0.9rem;
+    font-weight: 600;
+    margin-bottom: 1rem;
+    text-transform: uppercase;
+    letter-spacing: 1px;
+}
+
+.image-container {
+    width: 100%;
+    aspect-ratio: 1/1;
+    background-color: #3b4266;
+    border-radius: 20px;
+    position: relative;
+    overflow: hidden;
+    display: flex;
+    align-items: center;
+    justify-content: center;
+    cursor: pointer;
+    border: 2px solid transparent;
+    transition: all 0.3s ease;
+}
+
+.image-container:hover {
+    background-color: #464d75;
+    border-color: rgba(255, 255, 255, 0.1);
+}
+
+.placeholder {
+    text-align: center;
+}
+
+.placeholder-icon {
+    font-size: 2.5rem;
+    margin-bottom: 0.5rem;
+}
+
+.placeholder p {
+    color: var(--text-secondary);
+    font-size: 0.9rem;
+}
+
+.display-img {
+    width: 100%;
+    height: 100%;
+    object-fit: cover;
+}
+
+.mode-toggle-group {
+    display: flex;
+    justify-content: center;
+    gap: 0.5rem;
+    margin-bottom: 2rem;
+}
+
+.mode-select-btn {
+    padding: 0.5rem 1rem;
+    border-radius: 8px;
+    border: none;
+    background-color: #3b4266;
+    color: var(--text-secondary);
+    cursor: pointer;
+    font-weight: 600;
+    font-size: 0.8rem;
+    transition: all 0.3s ease;
+}
+
+.mode-select-btn.active {
+    background-color: var(--accent-purple);
+    color: white;
+}
+
+.action-area {
+    text-align: center;
+}
+
+.btn-action {
+    background: var(--btn-gradient);
+    border: none;
+    padding: 1.2rem 3rem;
+    border-radius: 12px;
+    color: white;
+    font-size: 1.1rem;
+    font-weight: 700;
+    cursor: pointer;
+    box-shadow: var(--btn-shadow);
+    transition: all 0.3s ease;
+    margin-bottom: 1.5rem;
+}
+
+.btn-action:hover:not(:disabled) {
+    transform: translateY(-2px);
+    filter: brightness(1.1);
+    box-shadow: 0 6px 24px rgba(99, 102, 241, 0.6);
+}
+
+.btn-action:disabled {
+    opacity: 0.5;
+    cursor: not-allowed;
+    filter: grayscale(0.5);
+}
+
+.model-info {
+    font-size: 0.8rem;
+    color: var(--text-secondary);
+}
+
+.green {
+    color: var(--accent-green);
+}
+
+.hidden {
+    display: none;
+}
+
+/* Spinner */
+.spinner {
+    width: 40px;
+    height: 40px;
+    border: 3px solid rgba(255, 255, 255, 0.1);
+    border-top: 3px solid var(--accent-green);
+    border-radius: 50%;
+    animation: spin 1s linear infinite;
+}
+
+@keyframes spin {
+    0% {
+        transform: rotate(0deg);
+    }
+
+    100% {
+        transform: rotate(360deg);
+    }
+}
+
+@media (max-width: 600px) {
+    .stats-grid {
+        grid-template-columns: 1fr;
+    }
+
+    .image-pair {
+        flex-direction: column;
+        align-items: center;
+    }
+}
static/js/main.js
ADDED
@@ -0,0 +1,108 @@
+const dropZone = document.getElementById('drop-zone');
+const imageInput = document.getElementById('image-input');
+const previewImg = document.getElementById('preview-img');
+const inputPlaceholder = document.getElementById('input-placeholder');
+const outputPlaceholder = document.getElementById('output-placeholder');
+const transferBtn = document.getElementById('transfer-btn');
+const loader = document.getElementById('loader');
+const resultImg = document.getElementById('result-img');
+const downloadLink = document.getElementById('download-link');
+const modeBtns = document.querySelectorAll('.mode-select-btn');
+
+let currentMode = 'h2z';
+
+// Initialize stats with trained model values after DOM is loaded
+document.addEventListener('DOMContentLoaded', () => {
+    document.getElementById('stat-epoch').textContent = '25 / 25';
+    document.getElementById('stat-gen-loss').textContent = '3.245';
+    document.getElementById('stat-disc-loss').textContent = '0.457';
+});
+
+// Mode selection
+modeBtns.forEach(btn => {
+    btn.addEventListener('click', (e) => {
+        modeBtns.forEach(b => b.classList.remove('active'));
+        btn.classList.add('active');
+        currentMode = btn.dataset.mode;
+
+        // Update displays
+        if (currentMode === 'h2z') {
+            inputPlaceholder.querySelector('.placeholder-icon').textContent = '🐎';
+        } else {
+            inputPlaceholder.querySelector('.placeholder-icon').textContent = '🦓';
+        }
+    });
+});
+
+// File selection
+dropZone.addEventListener('click', () => {
+    imageInput.click();
+});
+
+imageInput.addEventListener('change', (e) => {
+    if (e.target.files.length) {
+        handleFile(e.target.files[0]);
+    }
+});
+
+function handleFile(file) {
+    if (!file.type.startsWith('image/')) {
+        alert('Please select an image file.');
+        return;
+    }
+
+    const reader = new FileReader();
+    reader.onload = (e) => {
+        previewImg.src = e.target.result;
+        previewImg.classList.remove('hidden');
+        inputPlaceholder.classList.add('hidden');
+        transferBtn.disabled = false;
+
+        // Clear result
+        resultImg.classList.add('hidden');
+        outputPlaceholder.classList.remove('hidden');
+    };
+    reader.readAsDataURL(file);
+}
+
+// Prediction
+transferBtn.addEventListener('click', async () => {
+    const file = imageInput.files[0];
+    if (!file) return;
+
+    // Loading state
+    transferBtn.disabled = true;
+    loader.classList.remove('hidden');
+    outputPlaceholder.classList.add('hidden');
+    resultImg.classList.add('hidden');
+
+    const formData = new FormData();
+    formData.append('image', file);
+    formData.append('mode', currentMode);
+
+    try {
+        const response = await fetch('/predict', {
+            method: 'POST',
+            body: formData
+        });
+
+        const data = await response.json();
+
+        if (data.success) {
+            resultImg.src = data.result;
+            resultImg.classList.remove('hidden');
+            downloadLink.href = data.result;
+            downloadLink.download = `cyclegan_${currentMode}_${Date.now()}.png`;
+        } else {
+            alert('Transfer failed: ' + data.error);
+            outputPlaceholder.classList.remove('hidden');
+        }
+    } catch (err) {
+        alert('Communication error with server.');
+        console.error(err);
+        outputPlaceholder.classList.remove('hidden');
+    } finally {
+        loader.classList.add('hidden');
+        transferBtn.disabled = false;
+    }
+});
templates/index.html
ADDED
@@ -0,0 +1,80 @@
+<!DOCTYPE html>
+<html lang="en">
+
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Deep Style Transfer | CycleGAN</title>
+    <link rel="stylesheet" href="{{ url_for('static', filename='css/style.css') }}?v=1.1">
+    <link href="https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;600;800&display=swap" rel="stylesheet">
+</head>
+
+<body>
+    <div class="main-wrapper">
+        <header class="header">
+            <h1 class="title" style="font-size: 2.2rem; line-height: 1.2;">Deep Learning Based Image Style Transfer With
+                CycleGAN</h1>
+            <p class="subtitle">Neural Style Synthesis Framework for Unpaired Image-to-Image Translation</p>
+        </header>
+
+        <section class="stats-grid">
+            <div class="stat-card">
+                <div class="stat-label">EPOCH</div>
+                <div class="stat-value" id="stat-epoch" style="color: #00e699;">25 / 25</div>
+            </div>
+            <div class="stat-card">
+                <div class="stat-label">GEN LOSS</div>
+                <div class="stat-value" id="stat-gen-loss" style="color: #00e699;">3.245</div>
+            </div>
+            <div class="stat-card">
+                <div class="stat-label">DISC LOSS</div>
+                <div class="stat-value" id="stat-disc-loss" style="color: #00e699;">0.457</div>
+            </div>
+        </section>

+        <section class="visualizer">
+            <div class="image-pair">
+                <div class="image-column">
+                    <p class="image-label">AI Generation</p>
+                    <div class="image-container">
+                        <img id="result-img" src="" class="display-img hidden">
+                        <div class="placeholder" id="output-placeholder">
+                            <div class="placeholder-icon">✨</div>
+                            <p>Neural Output</p>
+                        </div>
+                        <div id="loader" class="hidden">
+                            <div class="spinner"></div>
+                        </div>
+                    </div>
+                </div>
+                <div class="image-column">
+                    <p class="image-label">Real Dataset Sample</p>
+                    <div class="image-container" id="drop-zone">
+                        <input type="file" id="image-input" hidden accept="image/*">
+                        <img id="preview-img" src="" class="display-img hidden">
+                        <div class="placeholder" id="input-placeholder">
+                            <div class="placeholder-icon">📸</div>
+                            <p>Upload Image</p>
+                        </div>
+                    </div>
+                </div>
+            </div>
+
+            <div class="mode-toggle-group">
+                <button class="mode-select-btn active" data-mode="h2z">H ➔ Z</button>
+                <button class="mode-select-btn" data-mode="z2h">Z ➔ H</button>
+            </div>
+
+            <div class="action-area">
+                <button id="transfer-btn" class="btn-action" disabled>Attempt Stylization</button>
+                <p class="model-info">Model: <span class="green">CycleGAN_ResNet9_Engine.h5</span> (256x256 RGB)</p>
+            </div>
+        </section>
+
+        <a id="download-link" class="hidden"></a>
+    </div>
+
+    <script src="{{ url_for('static', filename='js/main.js') }}?v=1.1"></script>
+</body>
+
+</html>
tf_dataset.py
ADDED
@@ -0,0 +1,23 @@
+import tensorflow as tf
+import os
+import numpy as np
+
+def load_image(image_file):
+    image = tf.io.read_file(image_file)
+    image = tf.image.decode_jpeg(image, channels=3)
+    image = tf.image.convert_image_dtype(image, tf.float32)
+    image = tf.image.resize(image, [256, 256])
+    image = (image * 2) - 1
+    return image
+
+def get_dataset(root_path, subset="train"):
+    path_a = os.path.join(root_path, f"{subset}A")
+    path_b = os.path.join(root_path, f"{subset}B")
+
+    list_a = tf.data.Dataset.list_files(path_a + "/*.jpg")
+    list_b = tf.data.Dataset.list_files(path_b + "/*.jpg")
+
+    ds_a = list_a.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
+    ds_b = list_b.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
+
+    return tf.data.Dataset.zip((ds_a, ds_b))
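A usage sketch for `get_dataset`; the `shuffle`/`prefetch` calls here are optional additions rather than part of the committed pipeline, and the data layout is assumed to match what `download_data.py` produces:

```python
# Sketch: consuming the paired dataset from tf_dataset.py.
import tensorflow as tf
from tf_dataset import get_dataset

train_ds = (
    get_dataset("data/horse2zebra", "train")
    .shuffle(256)                    # optional: decorrelate the A/B pairing
    .batch(1)
    .prefetch(tf.data.AUTOTUNE)
)
for horse, zebra in train_ds.take(1):
    print(horse.shape, zebra.shape)  # (1, 256, 256, 3) each, values in [-1, 1]
```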
tf_models.py
ADDED
@@ -0,0 +1,76 @@
+import tensorflow as tf
+from tensorflow.keras import layers
+
+def downsample(filters, size, apply_instancenorm=True):
+    initializer = tf.random_normal_initializer(0., 0.02)
+    result = tf.keras.Sequential()
+    result.add(layers.Conv2D(filters, size, strides=2, padding='same',
+                             kernel_initializer=initializer, use_bias=False))
+    if apply_instancenorm:
+        result.add(tf.keras.layers.GroupNormalization(groups=-1))
+    result.add(layers.LeakyReLU())
+    return result
+
+def upsample(filters, size, apply_dropout=False):
+    initializer = tf.random_normal_initializer(0., 0.02)
+    result = tf.keras.Sequential()
+    result.add(layers.Conv2DTranspose(filters, size, strides=2, padding='same',
+                                      kernel_initializer=initializer, use_bias=False))
+    result.add(tf.keras.layers.GroupNormalization(groups=-1))
+    if apply_dropout:
+        result.add(layers.Dropout(0.5))
+    result.add(layers.ReLU())
+    return result
+
+def resnet_block(filters, size=3):
+    initializer = tf.random_normal_initializer(0., 0.02)
+    result = tf.keras.Sequential()
+    result.add(layers.Conv2D(filters, size, padding='same', kernel_initializer=initializer, use_bias=False))
+    result.add(tf.keras.layers.GroupNormalization(groups=-1))
+    result.add(layers.ReLU())
+    result.add(layers.Conv2D(filters, size, padding='same', kernel_initializer=initializer, use_bias=False))
+    result.add(tf.keras.layers.GroupNormalization(groups=-1))
+    return result
+
+def Generator(output_channels=3, num_resnet=9):
+    inputs = layers.Input(shape=[256, 256, 3])
+
+    # Downsampling
+    x = layers.Conv2D(64, 7, padding='same', kernel_initializer=tf.random_normal_initializer(0., 0.02), use_bias=False)(inputs)
+    x = tf.keras.layers.GroupNormalization(groups=-1)(x)
+    x = layers.ReLU()(x)
+
+    x = downsample(128, 3)(x)
+    x = downsample(256, 3)(x)
+
+    # Residual blocks
+    for _ in range(num_resnet):
+        res = resnet_block(256)(x)
+        x = layers.Add()([x, res])
+
+    # Upsampling
+    x = upsample(128, 3)(x)
+    x = upsample(64, 3)(x)
+
+    last = layers.Conv2D(output_channels, 7, padding='same', activation='tanh',
+                         kernel_initializer=tf.random_normal_initializer(0., 0.02))(x)
+
+    return tf.keras.Model(inputs=inputs, outputs=last)
+
+def Discriminator():
+    initializer = tf.random_normal_initializer(0., 0.02)
+    inputs = layers.Input(shape=[256, 256, 3])
+
+    down1 = downsample(64, 4, False)(inputs)  # (bs, 128, 128, 64)
+    down2 = downsample(128, 4)(down1)  # (bs, 64, 64, 128)
+    down3 = downsample(256, 4)(down2)  # (bs, 32, 32, 256)
+
+    zero_pad1 = layers.ZeroPadding2D()(down3)  # (bs, 34, 34, 256)
+    conv = layers.Conv2D(512, 4, strides=1, kernel_initializer=initializer, use_bias=False)(zero_pad1)
+    norm1 = tf.keras.layers.GroupNormalization(groups=-1)(conv)
+    leaky_relu = layers.LeakyReLU()(norm1)
+
+    zero_pad2 = layers.ZeroPadding2D()(leaky_relu)
+    last = layers.Conv2D(1, 4, strides=1, kernel_initializer=initializer)(zero_pad2)
+
+    return tf.keras.Model(inputs=inputs, outputs=last)
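A quick shape check for the two constructors above — a sketch, assuming `tf_models.py` is importable from the working directory. Note that `GroupNormalization(groups=-1)` uses one group per channel, which is equivalent to instance normalization:

```python
# Sketch: sanity-check the Keras models defined in tf_models.py.
import tensorflow as tf
from tf_models import Generator, Discriminator

x = tf.random.normal([1, 256, 256, 3])
gen = Generator()
disc = Discriminator()
print(gen(x, training=False).shape)   # (1, 256, 256, 3), tanh output in [-1, 1]
print(disc(x, training=False).shape)  # (1, 30, 30, 1), one logit per image patch
```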
tf_predict.py
ADDED
@@ -0,0 +1,59 @@
+import tensorflow as tf
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+from tf_models import Generator
+
+def load_image(image_file):
+    image = tf.io.read_file(image_file)
+    image = tf.image.decode_jpeg(image, channels=3)
+    image = tf.image.convert_image_dtype(image, tf.float32)
+    image = tf.image.resize(image, [256, 256])
+    image = (image * 2) - 1
+    return tf.expand_dims(image, 0)
+
+def predict(model, image_path):
+    image = load_image(image_path)
+    prediction = model(image, training=False)
+
+    plt.figure(figsize=(10, 5))
+
+    plt.subplot(1, 2, 1)
+    plt.title("Input Image")
+    plt.imshow(image[0] * 0.5 + 0.5)
+    plt.axis("off")
+
+    plt.subplot(1, 2, 2)
+    plt.title("Predicted Image")
+    plt.imshow(prediction[0] * 0.5 + 0.5)
+    plt.axis("off")
+
+    plt.savefig("tf_prediction.png")
+    print("Prediction saved to tf_prediction.png")
+
+def main():
+    model = Generator()
+    # Attempt to load existing .h5 files if they exist
+    potential_weights = ["GeneratorHtoZ.h5", "gen_g_epoch_0.h5"]
+    loaded = False
+    for weight_path in potential_weights:
+        if os.path.exists(weight_path):
+            try:
+                model.load_weights(weight_path, by_name=True, skip_mismatch=True)
+                print(f"Loaded weights from {weight_path}")
+                loaded = True
+                break
+            except Exception as e:
+                print(f"Could not load {weight_path}: {e}")
+
+    if not loaded:
+        print("Using untrained model.")
+
+    test_image = "data/horse2zebra/testA/n02381460_1010.jpg"
+    if os.path.exists(test_image):
+        predict(model, test_image)
+    else:
+        print(f"Test image {test_image} not found.")
+
+if __name__ == "__main__":
+    main()
tf_train.py
ADDED
@@ -0,0 +1,116 @@
+import tensorflow as tf
+import os
+import time
+import matplotlib.pyplot as plt
+from tf_dataset import get_dataset
+from tf_models import Generator, Discriminator
+
+# Parameters
+LAMBDA = 10
+EPOCHS = 10
+DATA_PATH = "data/horse2zebra"
+
+generator_g = Generator()  # Horse -> Zebra
+generator_f = Generator()  # Zebra -> Horse
+
+discriminator_x = Discriminator()  # Real Horse vs Fake Horse
+discriminator_y = Discriminator()  # Real Zebra vs Fake Zebra
+
+loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
+
+def discriminator_loss(real, generated):
+    real_loss = loss_obj(tf.ones_like(real), real)
+    generated_loss = loss_obj(tf.zeros_like(generated), generated)
+    total_disc_loss = real_loss + generated_loss
+    return total_disc_loss * 0.5
+
+def generator_loss(generated):
+    return loss_obj(tf.ones_like(generated), generated)
+
+def calc_cycle_loss(real_image, cycled_image):
+    loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
+    return LAMBDA * loss1
+
+def identity_loss(real_image, same_image):
+    loss = tf.reduce_mean(tf.abs(real_image - same_image))
+    return LAMBDA * 0.5 * loss
+
+generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
+generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
+
+discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
+discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
+
+@tf.function
+def train_step(real_x, real_y):
+    with tf.GradientTape(persistent=True) as tape:
+        # Generator G translates X -> Y.
+        # Generator F translates Y -> X.
+
+        fake_y = generator_g(real_x, training=True)
+        cycled_x = generator_f(fake_y, training=True)
+
+        fake_x = generator_f(real_y, training=True)
+        cycled_y = generator_g(fake_x, training=True)
+
+        # same_x and same_y are used for identity loss.
+        same_x = generator_f(real_x, training=True)
+        same_y = generator_g(real_y, training=True)
+
+        disc_real_x = discriminator_x(real_x, training=True)
+        disc_real_y = discriminator_y(real_y, training=True)
+
+        disc_fake_x = discriminator_x(fake_x, training=True)
+        disc_fake_y = discriminator_y(fake_y, training=True)
+
+        # calculate the loss
+        gen_g_loss = generator_loss(disc_fake_y)
+        gen_f_loss = generator_loss(disc_fake_x)
+
+        total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)
+
+        # Total generator loss = adversarial loss + cycle loss + identity loss
+        total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
+        total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)
+
+        disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
+        disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)
+
+    # Calculate the gradients for generator and discriminator
+    generator_g_gradients = tape.gradient(total_gen_g_loss, generator_g.trainable_variables)
+    generator_f_gradients = tape.gradient(total_gen_f_loss, generator_f.trainable_variables)
+
+    discriminator_x_gradients = tape.gradient(disc_x_loss, discriminator_x.trainable_variables)
+    discriminator_y_gradients = tape.gradient(disc_y_loss, discriminator_y.trainable_variables)
+
+    # Apply the gradients to the optimizer
+    generator_g_optimizer.apply_gradients(zip(generator_g_gradients, generator_g.trainable_variables))
+    generator_f_optimizer.apply_gradients(zip(generator_f_gradients, generator_f.trainable_variables))
+
+    discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients, discriminator_x.trainable_variables))
+    discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients, discriminator_y.trainable_variables))
+
+def main():
+    train_dataset = get_dataset(DATA_PATH, "train").batch(1)
+
+    for epoch in range(EPOCHS):
+        start = time.time()
+        print(f"Epoch {epoch} starting...")
+
+        for n, (image_x, image_y) in train_dataset.enumerate():
+            train_step(image_x, image_y)
+            if n % 100 == 0:
+                print('.', end='', flush=True)
+
+        print(f"\nTime for epoch {epoch} is {time.time()-start} sec")
+
+        # Save checkpoints
+        generator_g.save_weights(f"GeneratorHtoZ_epoch_{epoch}.h5")
+        generator_f.save_weights(f"GeneratorZtoH_epoch_{epoch}.h5")
+
+    # Also save latest weights
+    generator_g.save_weights("GeneratorHtoZ.h5")
+    generator_f.save_weights("GeneratorZtoH.h5")
+
+if __name__ == "__main__":
+    main()
train.py
ADDED
@@ -0,0 +1,162 @@
+import torch
+from dataset import CycleGANDataset
+from torch.utils.data import DataLoader
+import torch.nn as nn
+import torch.optim as optim
+from models import Generator, Discriminator
+from tqdm import tqdm
+from torchvision.utils import save_image
+import albumentations as A
+from albumentations.pytorch import ToTensorV2
+import os
+
+# Hyperparameters
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+TRAIN_DIR_HORSE = "data/horse2zebra/trainA"
+TRAIN_DIR_ZEBRA = "data/horse2zebra/trainB"
+VAL_DIR_HORSE = "data/horse2zebra/testA"
+VAL_DIR_ZEBRA = "data/horse2zebra/testB"
+BATCH_SIZE = 1
+LEARNING_RATE = 1e-5
+LAMBDA_IDENTITY = 0.0
+LAMBDA_CYCLE = 10
+NUM_WORKERS = 1
+NUM_EPOCHS = 10
+LOAD_MODEL = False
+SAVE_MODEL = True
+CHECKPOINT_GEN_H = "genh.pth.tar"
+CHECKPOINT_GEN_Z = "genz.pth.tar"
+CHECKPOINT_CRITIC_H = "critich.pth.tar"
+CHECKPOINT_CRITIC_Z = "criticz.pth.tar"
+
+transforms = A.Compose(
+    [
+        A.Resize(width=256, height=256),
+        A.HorizontalFlip(p=0.5),
+        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
+        ToTensorV2(),
+    ],
+    additional_targets={"image0": "image"},
+)
+
+def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
+    H_reals = 0
+    H_fakes = 0
+    loop = tqdm(loader, leave=True)
+
+    for idx, (horse, zebra) in enumerate(loop):
+        horse = horse.to(DEVICE)
+        zebra = zebra.to(DEVICE)
+
+        # Train Discriminators H and Z
+        with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
+            fake_horse = gen_H(zebra)
+            D_H_real = disc_H(horse)
+            D_H_fake = disc_H(fake_horse.detach())
+            H_reals += D_H_real.mean().item()
+            H_fakes += D_H_fake.mean().item()
+            D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
+            D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
+            D_H_loss = D_H_real_loss + D_H_fake_loss
+
+            fake_zebra = gen_Z(horse)
+            D_Z_real = disc_Z(zebra)
+            D_Z_fake = disc_Z(fake_zebra.detach())
+            D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real))
+            D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
+            D_Z_loss = D_Z_real_loss + D_Z_fake_loss
+
+            # put it together
+            D_loss = (D_H_loss + D_Z_loss) / 2
+
+        opt_disc.zero_grad()
+        d_scaler.scale(D_loss).backward()
+        d_scaler.step(opt_disc)
+        d_scaler.update()
+
+        # Train Generators H and Z
+        with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
+            # adversarial loss for both generators
+            D_H_fake = disc_H(fake_horse)
+            D_Z_fake = disc_Z(fake_zebra)
+            loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
+            loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))
+
+            # cycle loss
+            cycle_zebra = gen_Z(fake_horse)
+            cycle_horse = gen_H(fake_zebra)
+            cycle_zebra_loss = l1(zebra, cycle_zebra)
+            cycle_horse_loss = l1(horse, cycle_horse)
+
+            # identity loss (remove these for efficiency if you want)
+            # identity_zebra = gen_Z(zebra)
+            # identity_horse = gen_H(horse)
+            # identity_zebra_loss = l1(zebra, identity_zebra)
+            # identity_horse_loss = l1(horse, identity_horse)
+
+            # add all together
+            G_loss = (
+                loss_G_Z
+                + loss_G_H
+                + cycle_zebra_loss * LAMBDA_CYCLE
+                + cycle_horse_loss * LAMBDA_CYCLE
+                # + identity_horse_loss * LAMBDA_IDENTITY
+                # + identity_zebra_loss * LAMBDA_IDENTITY
+            )
+
+        opt_gen.zero_grad()
+        g_scaler.scale(G_loss).backward()
+        g_scaler.step(opt_gen)
+        g_scaler.update()
+
+        if idx % 200 == 0:
+            torch.save(gen_H.state_dict(), "saved_images/genh.pth.tar")
+            torch.save(gen_Z.state_dict(), "saved_images/genz.pth.tar")
+            save_image(fake_horse * 0.5 + 0.5, f"saved_images/horse_{idx}.png")
+            save_image(fake_zebra * 0.5 + 0.5, f"saved_images/zebra_{idx}.png")
+
+        loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))
+
+def main():
+    disc_H = Discriminator(in_channels=3).to(DEVICE)
+    disc_Z = Discriminator(in_channels=3).to(DEVICE)
+    gen_Z = Generator(img_channels=3, num_residuals=9).to(DEVICE)
+    gen_H = Generator(img_channels=3, num_residuals=9).to(DEVICE)
+    opt_disc = optim.Adam(
+        list(disc_H.parameters()) + list(disc_Z.parameters()),
+        lr=LEARNING_RATE,
+        betas=(0.5, 0.999),
+    )
+
+    opt_gen = optim.Adam(
+        list(gen_Z.parameters()) + list(gen_H.parameters()),
+        lr=LEARNING_RATE,
+        betas=(0.5, 0.999),
+    )
+
+    L1 = nn.L1Loss()
+    MSE = nn.MSELoss()
+
+    dataset = CycleGANDataset(
+        root_horse=TRAIN_DIR_HORSE,
+        root_zebra=TRAIN_DIR_ZEBRA,
+        transform=transforms,
+    )
+    loader = DataLoader(
+        dataset,
+        batch_size=BATCH_SIZE,
+        shuffle=True,
+        num_workers=NUM_WORKERS,
+        pin_memory=True,
+    )
+    g_scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))
+    d_scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))
+
+    os.makedirs("saved_images", exist_ok=True)
+
+    for epoch in range(NUM_EPOCHS):
+        print(f"Epoch {epoch}/{NUM_EPOCHS}")
+        train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, L1, MSE, d_scaler, g_scaler)
+
+if __name__ == "__main__":
+    main()
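One caveat when pairing this with `predict.py`: `train_fn` writes its checkpoints under `saved_images/`, while `predict.py` looks for `genz.pth.tar` in the working directory. A minimal bridge, assuming the default paths in both scripts:

```python
# Sketch: copy the latest PyTorch checkpoints to where predict.py expects them.
import shutil

shutil.copy("saved_images/genz.pth.tar", "genz.pth.tar")  # Horse -> Zebra generator
shutil.copy("saved_images/genh.pth.tar", "genh.pth.tar")  # Zebra -> Horse generator
```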