AsadAnalyst commited on
Commit
32d67cd
·
verified ·
1 Parent(s): ed55e71

Upload 8 files

Browse files
app.py ADDED
@@ -0,0 +1,660 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CycleGAN Image-to-Image Translation
3
+ Beautiful Gradio UI for HuggingFace Spaces
4
+ Sketch ↔ Photo Translation with Loss Visualizations
5
+ """
6
+
7
+ import os
8
+ import json
9
+ import torch
10
+ import numpy as np
11
+ import gradio as gr
12
+ from pathlib import Path
13
+ from PIL import Image
14
+ import matplotlib.pyplot as plt
15
+ import matplotlib
16
+ import io
17
+
18
+ matplotlib.use('Agg')
19
+
20
+ # ==================== CONFIGURATION ====================
21
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
22
+ IMG_SIZE = 256
23
+ NGF = NDF = 64
24
+ N_RES = 9
25
+
26
+
27
+ # ==================== MODEL ARCHITECTURES ====================
28
+ import torch.nn as nn
29
+
30
+
31
+ class ResBlock(nn.Module):
32
+ def __init__(self, dim):
33
+ super().__init__()
34
+ self.block = nn.Sequential(
35
+ nn.ReflectionPad2d(1), nn.Conv2d(dim, dim, 3),
36
+ nn.InstanceNorm2d(dim), nn.ReLU(True),
37
+ nn.ReflectionPad2d(1), nn.Conv2d(dim, dim, 3),
38
+ nn.InstanceNorm2d(dim))
39
+
40
+ def forward(self, x):
41
+ return x + self.block(x)
42
+
43
+
44
+ class Generator(nn.Module):
45
+ def __init__(self, in_ch=3, out_ch=3, ngf=64, n_res=9):
46
+ super().__init__()
47
+ m = [nn.ReflectionPad2d(3), nn.Conv2d(in_ch, ngf, 7),
48
+ nn.InstanceNorm2d(ngf), nn.ReLU(True)]
49
+ for i in range(2):
50
+ f = 2**i
51
+ m += [nn.Conv2d(ngf*f, ngf*f*2, 3, 2, 1),
52
+ nn.InstanceNorm2d(ngf*f*2), nn.ReLU(True)]
53
+ for _ in range(n_res):
54
+ m.append(ResBlock(ngf*4))
55
+ for i in range(2, 0, -1):
56
+ f = 2**i
57
+ m += [nn.ConvTranspose2d(ngf*f, ngf*f//2, 3, 2, 1, 1),
58
+ nn.InstanceNorm2d(ngf*f//2), nn.ReLU(True)]
59
+ m += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, out_ch, 7), nn.Tanh()]
60
+ self.model = nn.Sequential(*m)
61
+
62
+ def forward(self, x):
63
+ return self.model(x)
64
+
65
+
66
+ class PatchDisc(nn.Module):
67
+ def __init__(self, in_ch=3, ndf=64):
68
+ super().__init__()
69
+ def blk(i, o, norm=True, s=2):
70
+ layers = [nn.Conv2d(i, o, 4, s, 1)]
71
+ if norm:
72
+ layers.append(nn.InstanceNorm2d(o))
73
+ return layers + [nn.LeakyReLU(0.2, True)]
74
+
75
+ self.model = nn.Sequential(
76
+ *blk(in_ch, ndf, norm=False),
77
+ *blk(ndf, ndf*2),
78
+ *blk(ndf*2, ndf*4),
79
+ *blk(ndf*4, ndf*8, s=1),
80
+ nn.Conv2d(ndf*8, 1, 4, 1, 1))
81
+
82
+ def forward(self, x):
83
+ return self.model(x)
84
+
85
+
86
+ # ==================== MODEL INITIALIZATION ====================
87
+ def init_w(m):
88
+ if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
89
+ nn.init.normal_(m.weight, 0.0, 0.02)
90
+ if m.bias is not None:
91
+ nn.init.zeros_(m.bias)
92
+ elif isinstance(m, nn.InstanceNorm2d) and m.weight is not None:
93
+ nn.init.ones_(m.weight)
94
+ nn.init.zeros_(m.bias)
95
+
96
+
97
+ def load_models():
98
+ """Load pre-trained models from HuggingFace Hub or local checkpoints"""
99
+ G_AB = Generator(3, 3, NGF, N_RES).to(DEVICE)
100
+ G_BA = Generator(3, 3, NGF, N_RES).to(DEVICE)
101
+ D_A = PatchDisc(3, NDF).to(DEVICE)
102
+ D_B = PatchDisc(3, NDF).to(DEVICE)
103
+
104
+ G_AB.apply(init_w)
105
+ G_BA.apply(init_w)
106
+ D_A.apply(init_w)
107
+ D_B.apply(init_w)
108
+
109
+ # Try to load from HuggingFace Hub
110
+ try:
111
+ from huggingface_hub import hf_hub_download
112
+ # Download models from your HuggingFace repo
113
+ # This is a placeholder - replace with your actual repo
114
+ model_path = hf_hub_download(
115
+ repo_id="hamzaAvvan/cyclegan-sketch-photo",
116
+ filename="cyclegan_best.pth",
117
+ repo_type="model"
118
+ )
119
+ checkpoint = torch.load(model_path, map_location=DEVICE)
120
+ if 'G_AB' in checkpoint:
121
+ G_AB.load_state_dict(checkpoint['G_AB'])
122
+ G_BA.load_state_dict(checkpoint['G_BA'])
123
+ except:
124
+ print("Models not found on HuggingFace Hub. Using initialized models.")
125
+
126
+ return G_AB, G_BA, D_A, D_B
127
+
128
+
129
+ def load_training_history():
130
+ """Load training history from JSON if available"""
131
+ try:
132
+ from huggingface_hub import hf_hub_download
133
+ history_path = hf_hub_download(
134
+ repo_id="hamzaAvvan/cyclegan-sketch-photo",
135
+ filename="training_history.json",
136
+ repo_type="model"
137
+ )
138
+ with open(history_path, 'r') as f:
139
+ return json.load(f)
140
+ except:
141
+ # Return dummy data for demonstration
142
+ return {
143
+ "num_epochs_completed": 5,
144
+ "total_epochs": 5,
145
+ "best_cycle_loss": 0.0523,
146
+ "training_losses": {
147
+ "generator": [0.8234, 0.7123, 0.6234, 0.5891, 0.5234],
148
+ "discriminator_a": [0.6234, 0.5891, 0.5123, 0.4891, 0.4523],
149
+ "discriminator_b": [0.6891, 0.6123, 0.5345, 0.5123, 0.4678],
150
+ "cycle_loss": [1.2345, 1.0234, 0.8923, 0.7456, 0.6234],
151
+ "identity_loss": [0.5234, 0.4891, 0.4123, 0.3891, 0.3456],
152
+ }
153
+ }
154
+
155
+
156
+ # ==================== IMAGE PROCESSING ====================
157
+ def tensor_to_image(tensor):
158
+ """Convert tensor to PIL Image"""
159
+ with torch.no_grad():
160
+ img_np = ((tensor.squeeze().cpu() + 1) / 2).clamp(0, 1).permute(1, 2, 0).numpy()
161
+ return Image.fromarray((img_np * 255).astype(np.uint8))
162
+
163
+
164
+ def image_to_tensor(pil_image):
165
+ """Convert PIL Image to normalized tensor"""
166
+ img_resized = pil_image.resize((IMG_SIZE, IMG_SIZE), Image.LANCZOS)
167
+ img_array = np.array(img_resized) / 255.0
168
+ if len(img_array.shape) == 2: # Grayscale
169
+ img_array = np.stack([img_array] * 3, axis=-1)
170
+ img_tensor = torch.from_numpy(img_array).float().permute(2, 0, 1)
171
+ img_tensor = (img_tensor * 2) - 1 # Normalize to [-1, 1]
172
+ return img_tensor.unsqueeze(0).to(DEVICE)
173
+
174
+
175
+ # ==================== LOSS FUNCTION EXPLANATIONS ====================
176
+ LOSS_EXPLANATIONS = {
177
+ "Adversarial Loss (LSGAN)": {
178
+ "formula": "L_GAN = E[(D(x) - 1)²] + E[(D(G(z)))²]",
179
+ "description": """
180
+ <b>Purpose:</b> Encourages the generator to produce realistic images that fool the discriminator.
181
+
182
+ <b>How it works:</b>
183
+ • Generator tries to minimize: E[(D(G(x)) - 1)²] (fool discriminator)
184
+ • Discriminator tries to minimize: E[(D(x) - 1)²] + E[(D(G(x)))²] (correct classification)
185
+
186
+ <b>Why LSGAN:</b> Provides stable training compared to standard GAN loss. Uses MSE instead of cross-entropy.
187
+ """,
188
+ "weight": "1.0 (baseline)"
189
+ },
190
+
191
+ "Cycle Consistency Loss": {
192
+ "formula": "L_cyc = E[||G_BA(G_AB(x)) - x||₁] + E[||G_AB(G_BA(y)) - y||₁]",
193
+ "description": """
194
+ <b>Purpose:</b> Ensures unpaired image-to-image translation maintains content.
195
+
196
+ <b>How it works:</b>
197
+ • Translation Forward: Sketch → Photo (G_AB)
198
+ • Translation Backward: Photo → Sketch (G_BA)
199
+ • Cycle: Sketch → Photo → Sketch should reconstruct original
200
+ • This prevents mode collapse and maintains structural information
201
+
202
+ <b>Why crucial:</b> Enables training WITHOUT paired data. Critical for unpaired translation.
203
+
204
+ <b>Weight:</b> λ_cyc = 10.0 (heavily weighted to preserve structure)
205
+ """,
206
+ "weight": "10.0 (most important)"
207
+ },
208
+
209
+ "Identity Loss": {
210
+ "formula": "L_idt = E[||G_AB(y) - y||₁] + E[||G_BA(x) - x||₁]",
211
+ "description": """
212
+ <b>Purpose:</b> Encourages generators to preserve image characteristics when translating similar domains.
213
+
214
+ <b>How it works:</b>
215
+ • If photo is translated through photo-generator, it should remain unchanged
216
+ • If sketch is translated through sketch-generator, it should remain unchanged
217
+ • Prevents unnecessary transformations when input is already in target domain
218
+
219
+ <b>Benefit:</b> Improves image quality and visual stability. Prevents artifacts.
220
+
221
+ <b>Weight:</b> λ_idt = 5.0 (secondary importance)
222
+ """,
223
+ "weight": "5.0 (secondary)"
224
+ }
225
+ }
226
+
227
+
228
+ def create_loss_explanation_tab():
229
+ """Create detailed loss function explanation with formulas"""
230
+ html_content = """
231
+ <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
232
+ padding: 30px; border-radius: 15px; color: white; margin-bottom: 20px;">
233
+ <h1 style="margin: 0; font-size: 2.5em;">🎨 CycleGAN Loss Functions</h1>
234
+ <p style="margin: 10px 0 0 0; font-size: 1.1em; opacity: 0.95;">
235
+ Understanding the training objectives for unpaired image translation
236
+ </p>
237
+ </div>
238
+ """
239
+
240
+ for loss_name, loss_info in LOSS_EXPLANATIONS.items():
241
+ html_content += f"""
242
+ <div style="background: #f8f9fa; padding: 20px; border-radius: 10px; margin: 20px 0;
243
+ border-left: 5px solid #667eea;">
244
+ <h2 style="color: #667eea; margin-top: 0;">{loss_name}</h2>
245
+
246
+ <div style="background: #e8eaf6; padding: 15px; border-radius: 8px;
247
+ font-family: 'Courier New', monospace; font-size: 1.05em;
248
+ margin: 15px 0; color: #333;">
249
+ <strong>Formula:</strong> {loss_info['formula']}
250
+ </div>
251
+
252
+ <div style="color: #333; line-height: 1.8;">
253
+ {loss_info['description']}
254
+ </div>
255
+
256
+ <div style="background: #fff3e0; padding: 10px 15px; border-radius: 8px;
257
+ margin-top: 15px; color: #e65100;">
258
+ <strong>⚖️ Weight:</strong> {loss_info['weight']}
259
+ </div>
260
+ </div>
261
+ """
262
+
263
+ html_content += """
264
+ <div style="background: #e3f2fd; padding: 20px; border-radius: 10px; margin: 20px 0;">
265
+ <h3 style="color: #1976d2; margin-top: 0;">🔬 Training Dynamics</h3>
266
+ <p style="color: #333; line-height: 1.8;">
267
+ <strong>Total Loss = L_GAN + λ_cyc × L_cyc + λ_idt × L_idt</strong><br><br>
268
+ The generator learns to balance three objectives:
269
+ <ul style="color: #333;">
270
+ <li><strong>Realism</strong>: Fool the discriminator (L_GAN)</li>
271
+ <li><strong>Content Preservation</strong>: Maintain structure through cycle (L_cyc) ⭐</li>
272
+ <li><strong>Domain Consistency</strong>: Preserve domain characteristics (L_idt)</li>
273
+ </ul>
274
+ The cycle consistency loss dominates, ensuring quality unpaired translation.
275
+ </p>
276
+ </div>
277
+ """
278
+
279
+ return html_content
280
+
281
+
282
+ # ==================== VISUALIZATION FUNCTIONS ====================
283
+ def plot_training_losses(history):
284
+ """Create matplotlib figure with training loss curves"""
285
+ if not history or 'training_losses' not in history:
286
+ return None
287
+
288
+ losses = history['training_losses']
289
+ epochs = range(1, len(losses['generator']) + 1)
290
+
291
+ fig, axes = plt.subplots(2, 2, figsize=(14, 10))
292
+ fig.patch.set_facecolor('white')
293
+
294
+ # Generator Loss
295
+ axes[0, 0].plot(epochs, losses['generator'], 'o-', linewidth=2.5,
296
+ markersize=6, color='#667eea', label='Generator')
297
+ axes[0, 0].set_title('Generator Loss', fontsize=12, fontweight='bold')
298
+ axes[0, 0].set_xlabel('Epoch')
299
+ axes[0, 0].set_ylabel('Loss')
300
+ axes[0, 0].grid(True, alpha=0.3)
301
+ axes[0, 0].legend()
302
+
303
+ # Discriminator Losses
304
+ axes[0, 1].plot(epochs, losses['discriminator_a'], 'o-', linewidth=2.5,
305
+ markersize=6, color='#f57c00', label='Discriminator A (Sketch)')
306
+ axes[0, 1].plot(epochs, losses['discriminator_b'], 's-', linewidth=2.5,
307
+ markersize=6, color='#c62828', label='Discriminator B (Photo)')
308
+ axes[0, 1].set_title('Discriminator Losses', fontsize=12, fontweight='bold')
309
+ axes[0, 1].set_xlabel('Epoch')
310
+ axes[0, 1].set_ylabel('Loss')
311
+ axes[0, 1].grid(True, alpha=0.3)
312
+ axes[0, 1].legend()
313
+
314
+ # Cycle & Identity Loss
315
+ axes[1, 0].plot(epochs, losses['cycle_loss'], 'o-', linewidth=2.5,
316
+ markersize=6, color='#2e7d32', label='Cycle Loss')
317
+ axes[1, 0].plot(epochs, losses['identity_loss'], 's-', linewidth=2.5,
318
+ markersize=6, color='#7b1fa2', label='Identity Loss')
319
+ axes[1, 0].set_title('Cycle & Identity Losses', fontsize=12, fontweight='bold')
320
+ axes[1, 0].set_xlabel('Epoch')
321
+ axes[1, 0].set_ylabel('Loss')
322
+ axes[1, 0].grid(True, alpha=0.3)
323
+ axes[1, 0].legend()
324
+
325
+ # Combined Loss
326
+ total_loss = [g + d_a + d_b + c + i
327
+ for g, d_a, d_b, c, i in zip(
328
+ losses['generator'],
329
+ losses['discriminator_a'],
330
+ losses['discriminator_b'],
331
+ losses['cycle_loss'],
332
+ losses['identity_loss'])]
333
+ axes[1, 1].plot(epochs, total_loss, 'o-', linewidth=2.5, markersize=6,
334
+ color='#d32f2f', label='Total Loss')
335
+ axes[1, 1].fill_between(epochs, total_loss, alpha=0.3, color='#d32f2f')
336
+ axes[1, 1].set_title('Total Loss', fontsize=12, fontweight='bold')
337
+ axes[1, 1].set_xlabel('Epoch')
338
+ axes[1, 1].set_ylabel('Loss')
339
+ axes[1, 1].grid(True, alpha=0.3)
340
+ axes[1, 1].legend()
341
+
342
+ plt.tight_layout()
343
+
344
+ # Convert to PIL Image
345
+ buf = io.BytesIO()
346
+ plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
347
+ buf.seek(0)
348
+ img = Image.open(buf)
349
+ plt.close(fig)
350
+ return img
351
+
352
+
353
+ def create_model_info_html():
354
+ """Create HTML with model architecture information"""
355
+ html = """
356
+ <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
357
+ padding: 30px; border-radius: 15px; color: white; margin-bottom: 20px;">
358
+ <h1 style="margin: 0; font-size: 2.5em;">⚙️ Model Architecture</h1>
359
+ <p style="margin: 10px 0 0 0; font-size: 1.1em; opacity: 0.95;">
360
+ CycleGAN for Unpaired Sketch ↔ Photo Translation
361
+ </p>
362
+ </div>
363
+
364
+ <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 20px; margin: 20px 0;">
365
+ <div style="background: #e3f2fd; padding: 20px; border-radius: 10px;">
366
+ <h3 style="color: #1976d2; margin-top: 0;">🎬 Generator (G)</h3>
367
+ <ul style="color: #333; line-height: 2;">
368
+ <li><strong>Components:</strong> Encoder → Residual Blocks → Decoder</li>
369
+ <li><strong>Encoder:</strong> 2 conv layers (stride 2)</li>
370
+ <li><strong>Residual:</strong> 9 ResBlocks</li>
371
+ <li><strong>Decoder:</strong> 2 transpose conv layers</li>
372
+ <li><strong>Normalization:</strong> Instance Normalization</li>
373
+ <li><strong>Activation:</strong> ReLU (encoder), Tanh (output)</li>
374
+ <li><strong>Features:</strong> 64 → 128 → 256 → 128 → 64</li>
375
+ </ul>
376
+ </div>
377
+
378
+ <div style="background: #fff3e0; padding: 20px; border-radius: 10px;">
379
+ <h3 style="color: #e65100; margin-top: 0;">🕵️ Discriminator (D)</h3>
380
+ <ul style="color: #333; line-height: 2;">
381
+ <li><strong>Type:</strong> PatchGAN Discriminator</li>
382
+ <li><strong>Input:</strong> 256×256 images</li>
383
+ <li><strong>Patch Size:</strong> 70×70 receptive field</li>
384
+ <li><strong>Layers:</strong> 4 Conv blocks + 1 output conv</li>
385
+ <li><strong>Normalization:</strong> Instance Normalization</li>
386
+ <li><strong>Activation:</strong> LeakyReLU (slope 0.2)</li>
387
+ <li><strong>Output:</strong> 1 channel (real/fake prediction)</li>
388
+ </ul>
389
+ </div>
390
+ </div>
391
+
392
+ <div style="background: #f3e5f5; padding: 20px; border-radius: 10px; margin: 20px 0;">
393
+ <h3 style="color: #6a1b9a; margin-top: 0;">📊 Hyperparameters</h3>
394
+ <div style="display: grid; grid-template-columns: 1fr 1fr 1fr; gap: 15px; color: #4a148c;">
395
+ <div><strong>Image Size:</strong> 256×256</div>
396
+ <div><strong>Batch Size:</strong> 4</div>
397
+ <div><strong>Learning Rate:</strong> 2e-4</div>
398
+ <div><strong>Optimizer:</strong> Adam</div>
399
+ <div><strong>β₁, β₂:</strong> 0.5, 0.999</div>
400
+ <div><strong>Epochs:</strong> 5</div>
401
+ <div><strong>λ (Cycle):</strong> 10.0</div>
402
+ <div><strong>λ (Identity):</strong> 5.0</div>
403
+ <div><strong>Pool Size:</strong> 50 (image replay)</div>
404
+ </div>
405
+ </div>
406
+ """
407
+ return html
408
+
409
+
410
+ # ==================== MAIN INFERENCE FUNCTION ====================
411
+ def translate_image(input_image, translation_direction):
412
+ """Perform image translation"""
413
+ if input_image is None:
414
+ return None, "❌ Please upload an image first"
415
+
416
+ try:
417
+ # Ensure image is RGB
418
+ if input_image.mode != 'RGB':
419
+ input_image = input_image.convert('RGB')
420
+
421
+ # Convert to tensor
422
+ img_tensor = image_to_tensor(input_image)
423
+
424
+ # Select appropriate generator
425
+ if translation_direction == "Sketch → Photo":
426
+ generator = G_AB
427
+ else:
428
+ generator = G_BA
429
+
430
+ # Forward pass
431
+ with torch.no_grad():
432
+ output_tensor = generator(img_tensor)
433
+
434
+ output_image = tensor_to_image(output_tensor)
435
+ return output_image, "✅ Translation successful!"
436
+
437
+ except Exception as e:
438
+ return None, f"❌ Error: {str(e)}"
439
+
440
+
441
+ def create_comparison_figure(original, translated, direction):
442
+ """Create comparison image with labels"""
443
+ fig, axes = plt.subplots(1, 2, figsize=(12, 5))
444
+
445
+ axes[0].imshow(original)
446
+ axes[0].set_title(f"Original ({direction.split('→')[0].strip()})",
447
+ fontsize=12, fontweight='bold')
448
+ axes[0].axis('off')
449
+
450
+ axes[1].imshow(translated)
451
+ axes[1].set_title(f"Translated ({direction.split('→')[1].strip()})",
452
+ fontsize=12, fontweight='bold')
453
+ axes[1].axis('off')
454
+
455
+ plt.tight_layout()
456
+
457
+ buf = io.BytesIO()
458
+ plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
459
+ buf.seek(0)
460
+ comparison = Image.open(buf)
461
+ plt.close(fig)
462
+ return comparison
463
+
464
+
465
+ # ==================== GRADIO INTERFACE ====================
466
+ def create_interface():
467
+ """Create beautiful Gradio interface"""
468
+
469
+ # Load models and history
470
+ G_AB, G_BA, _, _ = load_models()
471
+ history = load_training_history()
472
+
473
+ with gr.Blocks(title="CycleGAN: Sketch ↔ Photo Translation") as demo:
474
+
475
+ gr.HTML("""
476
+ <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
477
+ padding: 40px 20px; border-radius: 15px; text-align: center;
478
+ margin-bottom: 30px; color: white;">
479
+ <h1 style="margin: 0; font-size: 3em;">🎨 CycleGAN Translation</h1>
480
+ <p style="margin: 15px 0 0 0; font-size: 1.2em; opacity: 0.95;">
481
+ 🖼️ Sketch ↔ Photo Translation | Beautiful Unpaired Image-to-Image Learning
482
+ </p>
483
+ <p style="margin: 10px 0 0 0; font-size: 0.95em; opacity: 0.85;">
484
+ Powered by Cycle Consistency Loss | Running on 🔥 {DEVICE}
485
+ </p>
486
+ </div>
487
+ """.format(DEVICE=str(DEVICE).upper()))
488
+
489
+ with gr.Tabs():
490
+
491
+ # ============ TAB 1: IMAGE TRANSLATION ============
492
+ with gr.Tab("🎨 Image Translation", id=0):
493
+ with gr.Row():
494
+ with gr.Column(scale=1):
495
+ gr.HTML("<h2 style='color: #667eea;'>Upload & Translate</h2>")
496
+
497
+ input_image = gr.Image(label="📸 Input Image",
498
+ type="pil", height=400)
499
+
500
+ direction = gr.Radio(
501
+ ["Sketch → Photo", "Photo → Sketch"],
502
+ value="Sketch → Photo",
503
+ label="🔄 Translation Direction"
504
+ )
505
+
506
+ translate_btn = gr.Button("🚀 Translate Image",
507
+ size="lg",
508
+ variant="primary")
509
+
510
+ output_status = gr.Textbox(label="Status",
511
+ interactive=False,
512
+ value="Ready")
513
+
514
+ with gr.Column(scale=1):
515
+ gr.HTML("<h2 style='color: #667eea;'>Result</h2>")
516
+ output_image = gr.Image(label="🎯 Translated Image",
517
+ type="pil", height=400)
518
+
519
+ translate_btn.click(
520
+ fn=translate_image,
521
+ inputs=[input_image, direction],
522
+ outputs=[output_image, output_status]
523
+ )
524
+
525
+ # Comparison gallery
526
+ gr.HTML("""
527
+ <div style="margin-top: 30px; padding: 20px; background: #f5f5f5;
528
+ border-radius: 10px;">
529
+ <h3 style="color: #667eea;">📖 Example Translations</h3>
530
+ <p style="color: #666;">
531
+ This model translates between sketches and photos using <b>Cycle Consistency Loss</b>,
532
+ enabling unpaired training. The cycle loss ensures that sketch→photo→sketch
533
+ reconstruction matches the original.
534
+ </p>
535
+ </div>
536
+ """)
537
+
538
+ # ============ TAB 2: LOSS FUNCTIONS ============
539
+ with gr.Tab("📚 Loss Functions", id=1):
540
+ gr.HTML(create_loss_explanation_tab())
541
+
542
+ # ============ TAB 3: TRAINING HISTORY ============
543
+ with gr.Tab("📊 Training History", id=2):
544
+ gr.HTML("<h2 style='color: #667eea; text-align: center;'>Training Loss Curves</h2>")
545
+
546
+ loss_plot = plot_training_losses(history)
547
+ if loss_plot:
548
+ gr.Image(value=loss_plot, label="Loss Visualization",
549
+ show_label=True)
550
+ else:
551
+ gr.HTML("<p style='text-align: center; color: #999;'>Loading training data...</p>")
552
+
553
+ # Statistics
554
+ if history:
555
+ gr.HTML(f"""
556
+ <div style="display: grid; grid-template-columns: 1fr 1fr 1fr 1fr; gap: 15px; margin-top: 20px;">
557
+ <div style="background: #e3f2fd; padding: 20px; border-radius: 10px; text-align: center;">
558
+ <h3 style="color: #1976d2; margin: 0;">Epochs</h3>
559
+ <p style="font-size: 1.5em; color: #1565c0; margin: 10px 0 0 0;">
560
+ {history.get('num_epochs_completed', 0)}/{history.get('total_epochs', 5)}
561
+ </p>
562
+ </div>
563
+
564
+ <div style="background: #fff3e0; padding: 20px; border-radius: 10px; text-align: center;">
565
+ <h3 style="color: #e65100; margin: 0;">Best Cycle Loss</h3>
566
+ <p style="font-size: 1.5em; color: #e65100; margin: 10px 0 0 0;">
567
+ {history.get('best_cycle_loss', 0):.4f}
568
+ </p>
569
+ </div>
570
+
571
+ <div style="background: #f3e5f5; padding: 20px; border-radius: 10px; text-align: center;">
572
+ <h3 style="color: #6a1b9a; margin: 0;">Final LR</h3>
573
+ <p style="font-size: 1.5em; color: #6a1b9a; margin: 10px 0 0 0;">
574
+ 2e-4 → 0
575
+ </p>
576
+ </div>
577
+
578
+ <div style="background: #e8f5e9; padding: 20px; border-radius: 10px; text-align: center;">
579
+ <h3 style="color: #2e7d32; margin: 0;">Status</h3>
580
+ <p style="font-size: 1.5em; color: #2e7d32; margin: 10px 0 0 0;">
581
+ ✅ Complete
582
+ </p>
583
+ </div>
584
+ </div>
585
+ """)
586
+
587
+ # ============ TAB 4: MODEL INFO ============
588
+ with gr.Tab("⚙️ Model Architecture", id=3):
589
+ gr.HTML(create_model_info_html())
590
+
591
+ # ============ TAB 5: ABOUT ============
592
+ with gr.Tab("ℹ️ About", id=4):
593
+ gr.HTML("""
594
+ <div style="padding: 30px;">
595
+ <h2 style="color: #667eea;">About CycleGAN</h2>
596
+
597
+ <div style="background: #f5f5f5; padding: 20px; border-radius: 10px; margin: 20px 0;">
598
+ <h3>What is CycleGAN?</h3>
599
+ <p>
600
+ CycleGAN is a deep learning model for unpaired image-to-image translation.
601
+ Unlike pix2pix, it doesn't require paired training data. Instead, it uses
602
+ <b>cycle consistency loss</b> to ensure that translating an image and then
603
+ translating it back recovers the original image.
604
+ </p>
605
+ </div>
606
+
607
+ <div style="background: #e3f2fd; padding: 20px; border-radius: 10px; margin: 20px 0;">
608
+ <h3 style="color: #1976d2;">Key Innovation: Cycle Consistency</h3>
609
+ <p>
610
+ <b>Traditional Approach:</b> x → y (requires paired data)<br>
611
+ <b>CycleGAN Approach:</b> x → G(x) → G(F(G(x))) ≈ x<br><br>
612
+ This enables training on unpaired image collections, making it applicable
613
+ to many real-world scenarios where paired data is unavailable.
614
+ </p>
615
+ </div>
616
+
617
+ <div style="background: #fff3e0; padding: 20px; border-radius: 10px; margin: 20px 0;">
618
+ <h3 style="color: #e65100;">Applications</h3>
619
+ <ul style="color: #333;">
620
+ <li>🖼️ Sketch → Photo / Photo → Sketch (this project)</li>
621
+ <li>🌅 Photo style transfer (summer ↔ winter)</li>
622
+ <li>🎨 Artistic style transfer</li>
623
+ <li>🐎 Object morphing (horses ↔ zebras)</li>
624
+ <li>🌃 Domain adaptation for autonomous driving</li>
625
+ </ul>
626
+ </div>
627
+
628
+ <div style="background: #f3e5f5; padding: 20px; border-radius: 10px; margin: 20px 0;">
629
+ <h3 style="color: #6a1b9a;">Paper & Resources</h3>
630
+ <ul style="color: #333;">
631
+ <li><b>Original Paper:</b> CycleGAN: Unpaired Image-to-Image Translation
632
+ (Zhu et al., 2017)</li>
633
+ <li><b>Repository:</b> junyanz/CycleGAN</li>
634
+ <li><b>This Implementation:</b> PyTorch with Instance Normalization</li>
635
+ </ul>
636
+ </div>
637
+
638
+ <hr style="border: none; border-top: 2px solid #ddd; margin: 30px 0;">
639
+
640
+ <div style="text-align: center; color: #999;">
641
+ <p>Made with ❤️ for HuggingFace Spaces</p>
642
+ <p>Dataset: TU-Berlin, Sketchy, QuickDraw, COCO</p>
643
+ </div>
644
+ </div>
645
+ """)
646
+
647
+ return demo
648
+
649
+
650
+ # ==================== MAIN ====================
651
+ if __name__ == "__main__":
652
+ G_AB, G_BA, D_A, D_B = load_models()
653
+
654
+ demo = create_interface()
655
+ demo.launch(
656
+ server_name="0.0.0.0",
657
+ server_port=7860,
658
+ share=False,
659
+ theme=gr.themes.Soft()
660
+ )
cyclegan_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5d51fbffa4fd5f3502118fe8c43bf843a46fd4b3c97d596d7b9366e114c39783
3
+ size 339554077
cyclegan_latest.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:70b20a530d5bda7d9c312a64a4ff520f829651ec164b377de04620ce0159d610
3
+ size 339570313
generator_photo_to_sketch.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:92e0fcd35e37ef289e58b501c3abe1a26836e915d92041ac910491fcd708503f
3
+ size 45533279
generator_sketch_to_photo.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:048ddd25ae2e9f2d45cb3513de35b1823c27199d200c3f82e30c6c3d3856c2dd
3
+ size 45533279
model_config.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model": "CycleGAN",
3
+ "timestamp": "2026-03-31T10:25:57.786837",
4
+ "device": "cuda",
5
+ "hyperparameters": {
6
+ "img_size": 256,
7
+ "batch_size": 4,
8
+ "num_epochs": 5,
9
+ "learning_rate": 0.0002,
10
+ "betas": [
11
+ 0.5,
12
+ 0.999
13
+ ],
14
+ "lambda_cycle": 10.0,
15
+ "lambda_identity": 5.0,
16
+ "num_residual_blocks": 9,
17
+ "ngf": 64,
18
+ "ndf": 64,
19
+ "pool_size": 50
20
+ },
21
+ "architecture": {
22
+ "generator": {
23
+ "name": "Generator",
24
+ "in_channels": 3,
25
+ "out_channels": 3,
26
+ "ngf": 64,
27
+ "num_residual_blocks": 9
28
+ },
29
+ "discriminator": {
30
+ "name": "PatchDiscriminator",
31
+ "in_channels": 3,
32
+ "ndf": 64
33
+ }
34
+ }
35
+ }
requirements.txt ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CycleGAN Gradio UI for HuggingFace Spaces
2
+ # Core Dependencies
3
+ torch>=2.5.0
4
+ torchvision>=0.20.0
5
+
6
+ # Web Framework
7
+ gradio>=4.40.0
8
+
9
+ # Image Processing
10
+ Pillow>=10.0.0
11
+ opencv-python>=4.8.0
12
+
13
+ # Data & Visualization
14
+ numpy>=1.24.0
15
+ matplotlib>=3.8.0
16
+ scikit-image>=0.21.0
17
+
18
+ # HuggingFace Integration
19
+ huggingface-hub>=0.23.0
20
+ datasets>=2.16.0
21
+
22
+ # Utilities
23
+ tqdm>=4.66.0
24
+ requests>=2.31.0
25
+
26
+ # Optional: For better performance
27
+ tensorboard>=2.14.0
training_history.json ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "num_epochs_completed": 5,
3
+ "total_epochs": 5,
4
+ "best_cycle_loss": 2.237016999010454,
5
+ "training_losses": {
6
+ "generator": [
7
+ 5.824810002728513,
8
+ 4.782709208956936,
9
+ 4.3199727865687585,
10
+ 4.312878673871358,
11
+ 4.070108941814356
12
+ ],
13
+ "discriminator_a": [
14
+ 0.18428663494686284,
15
+ 0.18036553872520464,
16
+ 0.17618322817510682,
17
+ 0.1273807303881959,
18
+ 0.11236962005888161
19
+ ],
20
+ "discriminator_b": [
21
+ 0.220113595857432,
22
+ 0.20027516328321215,
23
+ 0.21516741084686497,
24
+ 0.14414526903838443,
25
+ 0.12007981108338163
26
+ ],
27
+ "cycle_loss": [
28
+ 3.3778758166965686,
29
+ 2.7572627780730263,
30
+ 2.507615771753746,
31
+ 2.4373167831855906,
32
+ 2.237016999010454
33
+ ],
34
+ "identity_loss": [
35
+ 1.3829263848170898,
36
+ 1.064698489360642,
37
+ 0.9370824978853527,
38
+ 0.8602104993870384,
39
+ 0.7876052052305456
40
+ ]
41
+ },
42
+ "final_metrics": {
43
+ "ssim_A": 0.9793083667755127,
44
+ "psnr_A": 34.2286615181765,
45
+ "ssim_B": 0.5627841353416443,
46
+ "psnr_B": 18.380654660795532
47
+ }
48
+ }