munjed committed
Commit 78daa12 · 1 Parent(s): 613a292
Files changed (4)
  1. app.py +3 -11
  2. feature_extractor.py +1 -21
  3. model.py +79 -93
  4. models/{chotic.pth → model_32.pth} +0 -0
app.py CHANGED
@@ -4,34 +4,27 @@ import numpy as np
4
  import gradio as gr
5
  import os
6
 
7
- from model import ChaoticCoherentGenerator
8
- # from model_256 import EfficientChaoticGenerator
9
  from feature_extractor import CodeFeatureExtractor
10
 
11
- # ------------------- Device -------------------
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
14
- # ------------------- Load Model -------------------
15
- model = ChaoticCoherentGenerator().to(device)
16
- checkpoint = torch.load("models/chotic.pth", map_location=device)
17
  model.load_state_dict(checkpoint["state_dict"])
18
  model.eval()
19
 
20
  extractor = CodeFeatureExtractor()
21
 
22
- # ------------------- Image Enhancement -------------------
23
  def enhance_image(image: Image.Image, upscale_size: int):
24
- # Upscale smoothly
25
  image = image.resize((upscale_size, upscale_size), Image.Resampling.BICUBIC)
26
 
27
- # Optional post-processing
28
  image = image.filter(ImageFilter.GaussianBlur(radius=0.8))
29
  image = ImageEnhance.Color(image).enhance(1.2)
30
  image = ImageEnhance.Sharpness(image).enhance(1.1)
31
 
32
  return image
33
 
34
- # ------------------- Generation Function -------------------
35
  def generate_from_code(code_text, upscale_size):
36
  temp_file = "temp.py"
37
  with open(temp_file, "w", encoding="utf-8") as f:
@@ -51,7 +44,6 @@ def generate_from_code(code_text, upscale_size):
51
  enhanced = enhance_image(img, upscale_size)
52
  return enhanced
53
 
54
- # ------------------- Gradio UI -------------------
55
  demo = gr.Interface(
56
  fn=generate_from_code,
57
  inputs=[
 
4
  import gradio as gr
5
  import os
6
 
7
+ from model import Generator
 
8
  from feature_extractor import CodeFeatureExtractor
9
 
 
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
+ model = Generator().to(device)
13
+ checkpoint = torch.load("models/model_32.pth", map_location=device)
 
14
  model.load_state_dict(checkpoint["state_dict"])
15
  model.eval()
16
 
17
  extractor = CodeFeatureExtractor()
18
 
 
19
  def enhance_image(image: Image.Image, upscale_size: int):
 
20
  image = image.resize((upscale_size, upscale_size), Image.Resampling.BICUBIC)
21
 
 
22
  image = image.filter(ImageFilter.GaussianBlur(radius=0.8))
23
  image = ImageEnhance.Color(image).enhance(1.2)
24
  image = ImageEnhance.Sharpness(image).enhance(1.1)
25
 
26
  return image
27
 
 
28
  def generate_from_code(code_text, upscale_size):
29
  temp_file = "temp.py"
30
  with open(temp_file, "w", encoding="utf-8") as f:
 
44
  enhanced = enhance_image(img, upscale_size)
45
  return enhanced
46
 
 
47
  demo = gr.Interface(
48
  fn=generate_from_code,
49
  inputs=[
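With the generator class and weights file renamed, a quick way to confirm app.py's loading code still works is a standalone smoke test. This is a minimal sketch, not part of the commit; it assumes the checkpoint stores the weights under a "state_dict" key (as loaded above) and that Generator keeps its default input_dim=25 and 32x32 RGB output:

import torch
from model import Generator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Generator().to(device)
checkpoint = torch.load("models/model_32.pth", map_location=device)  # renamed weights file
model.load_state_dict(checkpoint["state_dict"])
model.eval()

# Feed one dummy 25-dimensional feature vector and check the output shape.
dummy = torch.randn(1, 25, device=device)
with torch.no_grad():
    out = model(dummy)
print(out.shape)  # expected: torch.Size([1, 3, 32, 32])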
feature_extractor.py CHANGED
@@ -38,15 +38,6 @@ class CodeFeatureExtractor:
38
  ]
39
 
40
  def extract_from_file(self, filepath):
41
- """
42
- Extract features from file
43
-
44
- Args:
45
- filepath (str)
46
-
47
- Returns:
48
- list: list of dfeatures
49
- """
50
  try:
51
  with open(filepath, 'r', encoding='utf-8') as f:
52
  code = f.read()
@@ -229,13 +220,6 @@ class CodeFeatureExtractor:
229
  return volume
230
 
231
  def extract_from_directory(self, directory, output_file='data/features.npy'):
232
- """
233
- Extract from a dir
234
-
235
- Args:
236
- directory (str): dir name
237
- output_file (str): file to save
238
- """
239
  print(f"Extracting from: {directory}")
240
 
241
  features_list = []
@@ -258,7 +242,7 @@ class CodeFeatureExtractor:
258
  failed += 1
259
 
260
  print(f"\n Features extracted from {len(features_list)} files")
261
- print(f"⚠️Failed: {failed} files")
262
 
263
  features_array = np.array(features_list)
264
  np.save(output_file, features_array)
@@ -290,10 +274,6 @@ class CodeFeatureExtractor:
290
  def main():
291
  extractor = CodeFeatureExtractor()
292
 
293
- print("="*60)
294
- print("Feature Extractor Code")
295
- print("="*60)
296
-
297
  #features = extractor.extract_from_directory('data/raw_code')
298
 
299
  features = extractor.extract_from_file('src/model.py')
 
38
  ]
39
 
40
  def extract_from_file(self, filepath):
 
 
 
 
 
 
 
 
 
41
  try:
42
  with open(filepath, 'r', encoding='utf-8') as f:
43
  code = f.read()
 
220
  return volume
221
 
222
  def extract_from_directory(self, directory, output_file='data/features.npy'):
 
 
 
 
 
 
 
223
  print(f"Extracting from: {directory}")
224
 
225
  features_list = []
 
242
  failed += 1
243
 
244
  print(f"\n Features extracted from {len(features_list)} files")
245
+ print(f"Failed: {failed} files")
246
 
247
  features_array = np.array(features_list)
248
  np.save(output_file, features_array)
 
274
  def main():
275
  extractor = CodeFeatureExtractor()
276
 
 
 
 
 
277
  #features = extractor.extract_from_directory('data/raw_code')
278
 
279
  features = extractor.extract_from_file('src/model.py')
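Although the commit strips the docstrings, the public surface of CodeFeatureExtractor is unchanged: extract_from_file(filepath) returns the feature list for one source file, and extract_from_directory(directory, output_file) walks a folder and saves the stacked features to a .npy file. A minimal usage sketch (paths are illustrative, mirroring main() above):

from feature_extractor import CodeFeatureExtractor

extractor = CodeFeatureExtractor()

# Single file -> one list of numeric features
features = extractor.extract_from_file("src/model.py")
print(len(features))

# Whole directory -> features stacked and saved to data/features.npy
extractor.extract_from_directory("data/raw_code", output_file="data/features.npy")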
model.py CHANGED
@@ -7,7 +7,6 @@ from torch.utils.data import Dataset, DataLoader
7
  import matplotlib.pyplot as plt
8
 
9
 
10
- # ===================== Feature Names =====================
11
 
12
  FEATURE_NAMES = [
13
  "lines_of_code", "num_functions", "num_classes", "num_loops",
@@ -20,7 +19,6 @@ FEATURE_NAMES = [
20
  ]
21
 
22
 
23
- # ===================== Dataset =====================
24
 
25
  class CodeToImageDataset(Dataset):
26
  def __init__(self, features):
@@ -33,28 +31,27 @@ class CodeToImageDataset(Dataset):
33
  return self.features[idx]
34
 
35
 
36
- # ===================== Feature-Driven Generator with CNN =====================
37
 
38
- class ChaoticCoherentGenerator(nn.Module):
39
  """
40
- Generates images with controlled chaos through:
41
- 1. Frequency-domain decomposition (low/mid/high)
42
- 2. Spatial composition layers
43
- 3. Feature-dependent color palettes
44
- 4. Texture injection at multiple scales
45
  """
46
  def __init__(self, input_dim=25, image_size=32):
47
  super().__init__()
48
  self.image_size = image_size
49
 
50
- # Feature grouping (keep your existing groups)
51
  self.structure_idx = [0, 1, 2, 10, 19, 20]
52
  self.control_idx = [3, 4, 9]
53
  self.operations_idx = [5, 7, 8, 6]
54
  self.style_idx = [14, 15, 24, 23]
55
  self.advanced_idx = [12, 13, 16, 17, 18, 11, 21, 22]
56
 
57
- # Low-freq: overall composition, large shapes (structure features)
58
  self.low_freq_encoder = nn.Sequential(
59
  nn.Linear(len(self.structure_idx), 256),
60
  nn.LayerNorm(256),
@@ -63,7 +60,7 @@ class ChaoticCoherentGenerator(nn.Module):
63
  nn.Linear(256, 512)
64
  )
65
 
66
- # Mid-freq: patterns, textures (control + operations)
67
  self.mid_freq_encoder = nn.Sequential(
68
  nn.Linear(len(self.control_idx) + len(self.operations_idx), 256),
69
  nn.LayerNorm(256),
@@ -72,7 +69,7 @@ class ChaoticCoherentGenerator(nn.Module):
72
  nn.Linear(256, 512)
73
  )
74
 
75
- # High-freq: details, noise (style + advanced)
76
  self.high_freq_encoder = nn.Sequential(
77
  nn.Linear(len(self.style_idx) + len(self.advanced_idx), 256),
78
  nn.LayerNorm(256),
@@ -88,7 +85,7 @@ class ChaoticCoherentGenerator(nn.Module):
88
  nn.Linear(128, 12) # 4 colors × 3 channels (HSV or RGB)
89
  )
90
 
91
- # Controls WHERE patterns appear (position, scale, rotation)
92
  self.composition_net = nn.Sequential(
93
  nn.Linear(512, 256),
94
  nn.LayerNorm(256),
@@ -96,12 +93,12 @@ class ChaoticCoherentGenerator(nn.Module):
96
  nn.Linear(256, 6) # [center_x, center_y, scale, rotation, flow_x, flow_y]
97
  )
98
 
99
- # Projection to spatial maps (separate for each frequency)
100
  self.low_to_spatial = nn.Linear(512, 8 * 8 * 128)
101
  self.mid_to_spatial = nn.Linear(512, 16 * 16 * 64)
102
  self.high_to_spatial = nn.Linear(512, 32 * 32 * 32)
103
 
104
- # Low-frequency path (8x8 -> 32x32)
105
  self.low_decoder = nn.Sequential(
106
  nn.ConvTranspose2d(128, 64, 4, 2, 1), # 8->16
107
  nn.GroupNorm(8, 64),
@@ -111,14 +108,14 @@ class ChaoticCoherentGenerator(nn.Module):
111
  nn.GELU()
112
  )
113
 
114
- # Mid-frequency path (16x16 -> 32x32)
115
  self.mid_decoder = nn.Sequential(
116
  nn.ConvTranspose2d(64, 32, 4, 2, 1), # 16->32
117
  nn.GroupNorm(8, 32),
118
  nn.GELU()
119
  )
120
 
121
- # High-frequency path (already 32x32)
122
  self.high_decoder = nn.Sequential(
123
  nn.Conv2d(32, 32, 3, 1, 1),
124
  nn.GroupNorm(8, 32),
@@ -126,7 +123,7 @@ class ChaoticCoherentGenerator(nn.Module):
126
  )
127
 
128
  self.fusion = nn.Sequential(
129
- nn.Conv2d(96, 64, 3, 1, 1), # 32+32+32 channels
130
  nn.GroupNorm(8, 64),
131
  nn.GELU(),
132
  nn.Conv2d(64, 32, 3, 1, 1),
@@ -152,14 +149,12 @@ class ChaoticCoherentGenerator(nn.Module):
152
  nn.init.constant_(m.bias, 0)
153
 
154
  def generate_perlin_noise(self, batch_size, size, device):
155
- """Generate smooth Perlin-like noise for organic texture"""
156
  grid_size = 4
157
  grid = torch.randn(batch_size, 2, grid_size, grid_size, device=device)
158
  noise = F.interpolate(grid, size=(size, size), mode='bicubic', align_corners=True)
159
- return noise[:, 0:1] # Single channel
160
 
161
  def apply_color_palette(self, grayscale, palette):
162
- """Map grayscale values to feature-driven color palette"""
163
  batch_size = grayscale.size(0)
164
  palette = palette.view(batch_size, 4, 3) # 4 colors, 3 channels
165
 
@@ -172,7 +167,6 @@ class ChaoticCoherentGenerator(nn.Module):
172
  channel_colors = palette[:, :, i] # (batch, 4)
173
  stops = torch.linspace(0, 1, 4, device=grayscale.device)
174
 
175
- # Interpolate between color stops
176
  result = torch.zeros_like(gray_norm)
177
  for j in range(3):
178
  mask = (gray_norm >= stops[j]) & (gray_norm < stops[j+1])
@@ -191,37 +185,36 @@ class ChaoticCoherentGenerator(nn.Module):
191
  batch_size = x.size(0)
192
  device = x.device
193
 
194
- # Extract feature groups
195
  structure = x[:, self.structure_idx]
196
  control = x[:, self.control_idx]
197
  operations = x[:, self.operations_idx]
198
  style = x[:, self.style_idx]
199
  advanced = x[:, self.advanced_idx]
200
 
201
- # Encode to frequency bands
202
  low_freq = self.low_freq_encoder(structure)
203
  mid_freq = self.mid_freq_encoder(torch.cat([control, operations], dim=1))
204
  high_freq = self.high_freq_encoder(torch.cat([style, advanced], dim=1))
205
 
206
- # Generate color palette
207
  palette = self.color_palette_net(x)
208
 
209
- # Generate composition parameters
210
  composition_params = self.composition_net(low_freq)
211
  composition_params = torch.tanh(composition_params)
212
 
213
- # Project to spatial maps
214
  low_spatial = self.low_to_spatial(low_freq).view(batch_size, 128, 8, 8)
215
  mid_spatial = self.mid_to_spatial(mid_freq).view(batch_size, 64, 16, 16)
216
  high_spatial = self.high_to_spatial(high_freq).view(batch_size, 32, 32, 32)
217
 
218
- # Decode each frequency band
219
  low_decoded = self.low_decoder(low_spatial) # (batch, 32, 32, 32)
220
  mid_decoded = self.mid_decoder(mid_spatial) # (batch, 32, 32, 32)
221
  high_decoded = self.high_decoder(high_spatial) # (batch, 32, 32, 32)
222
 
223
- # Apply spatial transformations based on composition params
224
- # (rotation, translation via grid_sample)
225
  theta = composition_params[:, 3:4] # rotation
226
  scale = 0.5 + composition_params[:, 2:3] # scale
227
  tx, ty = composition_params[:, 0:1], composition_params[:, 1:2]
@@ -239,32 +232,42 @@ class ChaoticCoherentGenerator(nn.Module):
239
  grid = F.affine_grid(affine_matrix, low_decoded.size(), align_corners=False)
240
  low_decoded = F.grid_sample(low_decoded, grid, align_corners=False, padding_mode='reflection')
241
 
242
- # Merge frequency bands
243
  merged = torch.cat([low_decoded, mid_decoded, high_decoded], dim=1)
244
- fused = self.fusion(merged) # (batch, 3, 32, 32)
245
 
246
- # Add controlled Perlin noise for organic texture
247
  noise_scale = self.noise_strength(x)
248
  perlin = self.generate_perlin_noise(batch_size, 32, device)
249
  perlin = perlin.expand(-1, 3, -1, -1)
250
  fused = fused + noise_scale.view(-1, 1, 1, 1) * perlin * 0.3
251
 
252
- # Convert to grayscale for palette mapping
253
  grayscale = fused.mean(dim=1, keepdim=True)
254
 
255
- # Apply feature-driven color palette
256
  colored = self.apply_color_palette(grayscale, palette)
257
 
258
- # Final normalization
259
  output = torch.tanh(colored)
260
 
261
  return output
262
 
263
-
264
- # ===================== Feature-Aware Loss Functions =====================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
 
266
  def frequency_coherence_loss(images):
267
- """Enforce low-frequency dominance with high-frequency details"""
268
  # FFT decomposition
269
  fft = torch.fft.fft2(images)
270
  fft_shifted = torch.fft.fftshift(fft)
@@ -290,7 +293,6 @@ def frequency_coherence_loss(images):
290
 
291
 
292
  def color_harmony_loss(images):
293
- """Encourage harmonious color distributions"""
294
  batch_size = images.size(0)
295
 
296
  # convert to LAB-like space (approximation)
@@ -308,7 +310,6 @@ def color_harmony_loss(images):
308
 
309
 
310
  def texture_diversity_loss(images):
311
- """Encourage multi-scale texture patterns"""
312
  textures = []
313
  for scale in [1, 2, 4]:
314
  if scale > 1:
@@ -324,9 +325,10 @@ def texture_diversity_loss(images):
324
 
325
  return F.relu(0.1 - diversity) # Penalize if diversity < 0.1
326
 
 
 
327
 
328
  def combined_aesthetic_loss(features, images):
329
- """Enhanced loss with BALANCED weights"""
330
 
331
  consistency = feature_consistency_loss(features, images)
332
  variance = feature_variance_preservation_loss(features, images)
@@ -339,21 +341,21 @@ def combined_aesthetic_loss(features, images):
339
  smoothness = tv_h + tv_w
340
 
341
  total = (
342
- 10.0 * consistency + # ~1.5 → 15
343
- 0.05 * variance + # ~50 → 2.5
344
- 0.5 * freq_coherence + # ~14 → 7
345
- 0.2 * color_harmony + # ~31 → 6.2
346
- 0.01 * texture_div + # ~126 → 1.26
347
- 0.1 * smoothness # ~38 → 3.8
348
  )
349
 
350
  return total, {
351
- 'consistency': consistency.item(),
352
- 'variance': variance.item(),
353
- 'freq_coherence': freq_coherence.item(),
354
- 'color_harmony': color_harmony.item(),
355
- 'texture_div': texture_div.item(),
356
- 'smoothness': smoothness.item()
357
  }
358
 
359
  def frequency_coherence_loss(images):
@@ -385,7 +387,7 @@ def frequency_coherence_loss(images):
385
 
386
 
387
  def feature_variance_preservation_loss(features, images):
388
- """Ensure each feature dimension contributes to image variance"""
389
  batch_size = features.size(0)
390
  if batch_size < 4:
391
  return torch.tensor(0.0, device=features.device)
@@ -417,22 +419,19 @@ def hsv_to_rgb(h, s, v):
417
 
418
  # ===================== Trainer =====================
419
  def create_feature_explanation_grid(model, features, save_path):
420
- """
421
- Show how specific features affect the output
422
- Systematic perturbation of individual features
423
- """
424
  model.eval()
425
- base_idx = 42 # Choose interesting sample
426
  base_features = features[base_idx].copy()
427
 
428
- # Select 6 interesting features to visualize
429
  feature_indices = [0, 3, 10, 14, 19, 24] # LOC, loops, nesting, comments, complexity, indentation
430
  feature_names = [FEATURE_NAMES[i] for i in feature_indices]
431
 
432
  fig, axes = plt.subplots(len(feature_indices), 5, figsize=(15, 3*len(feature_indices)))
433
 
434
  for row, feat_idx in enumerate(feature_indices):
435
- # Create variations: [-2σ, -1σ, 0, +1σ, +2σ]
436
  variations = []
437
  for multiplier in [-2, -1, 0, 1, 2]:
438
  varied = base_features.copy()
@@ -466,28 +465,23 @@ def compute_color_palette_from_features(features):
466
  # Normalize to [0, 1]
467
  f_norm = (features - features.min(0)) / (features.max(0) - features.min(0) + 1e-8)
468
 
469
- # Base hue from structure
470
  structure_score = f_norm[[0, 1, 2]].mean()
471
  base_hue = structure_score * 0.7 # 0 (red) to 0.7 (blue)
472
 
473
- # Saturation from control flow
474
  control_score = f_norm[[3, 4]].mean()
475
  saturation = 0.3 + control_score * 0.6
476
 
477
- # Value from style
478
  style_score = f_norm[[14, 15]].mean()
479
  value = 0.4 + style_score * 0.5
480
 
481
- # Create 4 colors with variation
482
  palette = []
483
  for shift in [-0.1, 0, 0.1, 0.2]:
484
  hue = (base_hue + shift) % 1.0
485
  palette.append(hsv_to_rgb(hue, saturation, value))
486
 
487
- return np.array(palette).flatten() # (12,) vector
488
 
489
  def interpolate_code_features(model, features, idx1, idx2, steps=10):
490
- """Smooth transition between two code samples"""
491
  model.eval()
492
 
493
  f1 = features[idx1]
@@ -505,7 +499,7 @@ def interpolate_code_features(model, features, idx1, idx2, steps=10):
505
  return images
506
 
507
 
508
- class FeatureDrivenTrainer:
509
  def __init__(self, model, device='cpu'):
510
  self.model = model.to(device)
511
  self.device = device
@@ -564,12 +558,11 @@ class FeatureDrivenTrainer:
564
  min_lr=1e-6
565
  )
566
 
567
- print(f"Training chaotic-coherent model for {epochs} epochs on {self.device}...")
568
- print("Focus: Frequency coherence + color harmony + texture diversity")
569
 
570
  for epoch in range(epochs):
571
  avg_loss, components = self.train_epoch(train_loader, optimizer)
572
- scheduler.step(avg_loss) # Pass loss for plateau detection
573
 
574
  if (epoch + 1) % 10 == 0:
575
  print(f"Epoch [{epoch+1}/{epochs}] - Loss: {avg_loss:.4f}")
@@ -642,9 +635,7 @@ class FeatureDrivenTrainer:
642
  print(f"Samples saved to {save_path}")
643
 
644
  def test_feature_consistency(self, features, save_path='results/consistency_test.png'):
645
- """
646
- Visualize that similar features produce similar images.
647
- """
648
  os.makedirs(os.path.dirname(save_path), exist_ok=True)
649
  self.model.eval()
650
 
@@ -689,7 +680,6 @@ class FeatureDrivenTrainer:
689
  print(f"Consistency test saved to {save_path}")
690
 
691
 
692
- # ===================== Utilities =====================
693
 
694
  def prepare_data_loaders(features, batch_size=32):
695
  dataset = CodeToImageDataset(features)
@@ -698,19 +688,17 @@ def prepare_data_loaders(features, batch_size=32):
698
  return loader
699
 
700
 
701
- # ===================== Main =====================
702
 
703
  def main():
704
  os.makedirs("models", exist_ok=True)
705
  os.makedirs("results", exist_ok=True)
706
 
707
  print("=" * 70)
708
- print("Feature-Driven Code-to-Image Generator (Final Version)")
709
- print("Features → Meaningful Visual Attributes")
710
  print("=" * 70)
711
 
712
  # Load features
713
- print("\n[1/5] Loading features...")
714
  features_path = "data/features.npy"
715
  if not os.path.exists(features_path):
716
  raise FileNotFoundError(f"{features_path} not found")
@@ -719,33 +707,31 @@ def main():
719
  print(f"Features shape: {features.shape}")
720
 
721
  # Prepare data
722
- print("\n[2/5] Preparing data loaders...")
723
  loader = prepare_data_loaders(features, batch_size=32)
724
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
725
  print(f"Device: {device}")
726
 
727
  # Initialize model
728
- print("\n[3/5] Initializing feature-driven CNN model...")
729
- model = ChaoticCoherentGenerator(input_dim=features.shape[1], image_size=32)
730
  param_count = sum(p.numel() for p in model.parameters())
731
  print(f"Model parameters: {param_count:,}")
732
 
733
  # Train model
734
- print("\n[4/5] Training model...")
735
- trainer = FeatureDrivenTrainer(model, device=device)
736
  trainer.train(loader, epochs=100, lr=1e-3)
737
 
738
  # Save results
739
  print("\n[5/5] Saving results...")
740
- trainer.save_model("models/chotic.pth")
741
- trainer.plot_losses("results/chotic_losses.png")
742
- trainer.generate_samples(features, n=25, save_path="results/chotic_samples.png")
743
- trainer.test_feature_consistency(features, save_path="results/chotic_consistency.png")
744
 
745
  print("\n" + "=" * 70)
746
- print("Complete! Check results/ folder.")
747
- print("=" * 70)
748
-
749
 
750
  if __name__ == "__main__":
751
  main()
 
7
  import matplotlib.pyplot as plt
8
 
9
 
 
10
 
11
  FEATURE_NAMES = [
12
  "lines_of_code", "num_functions", "num_classes", "num_loops",
 
19
  ]
20
 
21
 
 
22
 
23
  class CodeToImageDataset(Dataset):
24
  def __init__(self, features):
 
31
  return self.features[idx]
32
 
33
 
 
34
 
35
+ class Generator(nn.Module):
36
  """
37
+ Generates images with controlled chaos through:
38
+ Frequency-domain decomposition (low/mid/high)
39
+ Spatial composition layers
40
+ Feature-dependent color palettes
41
+ Texture injection at multiple scales
42
  """
43
  def __init__(self, input_dim=25, image_size=32):
44
  super().__init__()
45
  self.image_size = image_size
46
 
47
+ # Feature grouping
48
  self.structure_idx = [0, 1, 2, 10, 19, 20]
49
  self.control_idx = [3, 4, 9]
50
  self.operations_idx = [5, 7, 8, 6]
51
  self.style_idx = [14, 15, 24, 23]
52
  self.advanced_idx = [12, 13, 16, 17, 18, 11, 21, 22]
53
 
54
+ # Overall composition, large shapes (structure features)
55
  self.low_freq_encoder = nn.Sequential(
56
  nn.Linear(len(self.structure_idx), 256),
57
  nn.LayerNorm(256),
 
60
  nn.Linear(256, 512)
61
  )
62
 
63
+ # patterns, textures (control + operations)
64
  self.mid_freq_encoder = nn.Sequential(
65
  nn.Linear(len(self.control_idx) + len(self.operations_idx), 256),
66
  nn.LayerNorm(256),
 
69
  nn.Linear(256, 512)
70
  )
71
 
72
+ # details, noise (style + advanced)
73
  self.high_freq_encoder = nn.Sequential(
74
  nn.Linear(len(self.style_idx) + len(self.advanced_idx), 256),
75
  nn.LayerNorm(256),
 
85
  nn.Linear(128, 12) # 4 colors × 3 channels (HSV or RGB)
86
  )
87
 
88
+ # WHERE patterns appear (position, scale, rotation)
89
  self.composition_net = nn.Sequential(
90
  nn.Linear(512, 256),
91
  nn.LayerNorm(256),
 
93
  nn.Linear(256, 6) # [center_x, center_y, scale, rotation, flow_x, flow_y]
94
  )
95
 
96
+ # Projection to spatial maps
97
  self.low_to_spatial = nn.Linear(512, 8 * 8 * 128)
98
  self.mid_to_spatial = nn.Linear(512, 16 * 16 * 64)
99
  self.high_to_spatial = nn.Linear(512, 32 * 32 * 32)
100
 
101
+ # Low (8x8 -> 32x32)
102
  self.low_decoder = nn.Sequential(
103
  nn.ConvTranspose2d(128, 64, 4, 2, 1), # 8->16
104
  nn.GroupNorm(8, 64),
 
108
  nn.GELU()
109
  )
110
 
111
+ # Mid (16x16 -> 32x32)
112
  self.mid_decoder = nn.Sequential(
113
  nn.ConvTranspose2d(64, 32, 4, 2, 1), # 16->32
114
  nn.GroupNorm(8, 32),
115
  nn.GELU()
116
  )
117
 
118
+ # High path (32x32)
119
  self.high_decoder = nn.Sequential(
120
  nn.Conv2d(32, 32, 3, 1, 1),
121
  nn.GroupNorm(8, 32),
 
123
  )
124
 
125
  self.fusion = nn.Sequential(
126
+ nn.Conv2d(96, 64, 3, 1, 1),
127
  nn.GroupNorm(8, 64),
128
  nn.GELU(),
129
  nn.Conv2d(64, 32, 3, 1, 1),
 
149
  nn.init.constant_(m.bias, 0)
150
 
151
  def generate_perlin_noise(self, batch_size, size, device):
 
152
  grid_size = 4
153
  grid = torch.randn(batch_size, 2, grid_size, grid_size, device=device)
154
  noise = F.interpolate(grid, size=(size, size), mode='bicubic', align_corners=True)
155
+ return noise[:, 0:1] # single channel
156
 
157
  def apply_color_palette(self, grayscale, palette):
 
158
  batch_size = grayscale.size(0)
159
  palette = palette.view(batch_size, 4, 3) # 4 colors, 3 channels
160
 
 
167
  channel_colors = palette[:, :, i] # (batch, 4)
168
  stops = torch.linspace(0, 1, 4, device=grayscale.device)
169
 
 
170
  result = torch.zeros_like(gray_norm)
171
  for j in range(3):
172
  mask = (gray_norm >= stops[j]) & (gray_norm < stops[j+1])
 
185
  batch_size = x.size(0)
186
  device = x.device
187
 
188
+ # extract feature groups
189
  structure = x[:, self.structure_idx]
190
  control = x[:, self.control_idx]
191
  operations = x[:, self.operations_idx]
192
  style = x[:, self.style_idx]
193
  advanced = x[:, self.advanced_idx]
194
 
195
+ # encode to frequency bands
196
  low_freq = self.low_freq_encoder(structure)
197
  mid_freq = self.mid_freq_encoder(torch.cat([control, operations], dim=1))
198
  high_freq = self.high_freq_encoder(torch.cat([style, advanced], dim=1))
199
 
200
+ # generate color palette
201
  palette = self.color_palette_net(x)
202
 
203
+ # generate composition parameters
204
  composition_params = self.composition_net(low_freq)
205
  composition_params = torch.tanh(composition_params)
206
 
207
+ # Project to spatial maps
208
  low_spatial = self.low_to_spatial(low_freq).view(batch_size, 128, 8, 8)
209
  mid_spatial = self.mid_to_spatial(mid_freq).view(batch_size, 64, 16, 16)
210
  high_spatial = self.high_to_spatial(high_freq).view(batch_size, 32, 32, 32)
211
 
212
+ # Decode each frequency band
213
  low_decoded = self.low_decoder(low_spatial) # (batch, 32, 32, 32)
214
  mid_decoded = self.mid_decoder(mid_spatial) # (batch, 32, 32, 32)
215
  high_decoded = self.high_decoder(high_spatial) # (batch, 32, 32, 32)
216
 
217
+ # Apply spatial transformations from composition params
 
218
  theta = composition_params[:, 3:4] # rotation
219
  scale = 0.5 + composition_params[:, 2:3] # scale
220
  tx, ty = composition_params[:, 0:1], composition_params[:, 1:2]
 
232
  grid = F.affine_grid(affine_matrix, low_decoded.size(), align_corners=False)
233
  low_decoded = F.grid_sample(low_decoded, grid, align_corners=False, padding_mode='reflection')
234
 
 
235
  merged = torch.cat([low_decoded, mid_decoded, high_decoded], dim=1)
236
+ fused = self.fusion(merged)
237
 
238
+ # Add controlled Perlin noise
239
  noise_scale = self.noise_strength(x)
240
  perlin = self.generate_perlin_noise(batch_size, 32, device)
241
  perlin = perlin.expand(-1, 3, -1, -1)
242
  fused = fused + noise_scale.view(-1, 1, 1, 1) * perlin * 0.3
243
 
244
+ # Convert to grayscale
245
  grayscale = fused.mean(dim=1, keepdim=True)
246
 
 
247
  colored = self.apply_color_palette(grayscale, palette)
248
 
 
249
  output = torch.tanh(colored)
250
 
251
  return output
252
 
253
+ def feature_consistency_loss(features, images):
254
+ """Ensure feature similarity leads to image similarity"""
255
+ batch_size = features.size(0)
256
+ if batch_size < 4:
257
+ return torch.tensor(0.0, device=features.device)
258
+
259
+ feat_dists = torch.cdist(features, features)
260
+ img_flat = images.view(batch_size, -1)
261
+ img_dists = torch.cdist(img_flat, img_flat)
262
+
263
+ feat_dists = feat_dists / (feat_dists.max() + 1e-8)
264
+ img_dists = img_dists / (img_dists.max() + 1e-8)
265
+
266
+ consistency = F.mse_loss(feat_dists, img_dists)
267
+
268
+ return consistency
269
 
270
  def frequency_coherence_loss(images):
 
271
  # FFT decomposition
272
  fft = torch.fft.fft2(images)
273
  fft_shifted = torch.fft.fftshift(fft)
 
293
 
294
 
295
  def color_harmony_loss(images):
 
296
  batch_size = images.size(0)
297
 
298
  # convert to LAB-like space (approximation)
 
310
 
311
 
312
  def texture_diversity_loss(images):
 
313
  textures = []
314
  for scale in [1, 2, 4]:
315
  if scale > 1:
 
325
 
326
  return F.relu(0.1 - diversity) # Penalize if diversity < 0.1
327
 
328
+ def safe_item(x):
329
+ return x.item() if isinstance(x, torch.Tensor) else float(x)
330
 
331
  def combined_aesthetic_loss(features, images):
 
332
 
333
  consistency = feature_consistency_loss(features, images)
334
  variance = feature_variance_preservation_loss(features, images)
 
341
  smoothness = tv_h + tv_w
342
 
343
  total = (
344
+ 10.0 * consistency +
345
+ 0.05 * variance +
346
+ 0.5 * freq_coherence +
347
+ 0.2 * color_harmony +
348
+ 0.01 * texture_div +
349
+ 0.1 * smoothness
350
  )
351
 
352
  return total, {
353
+ 'consistency': safe_item(consistency),
354
+ 'variance': safe_item(variance),
355
+ 'freq_coherence': safe_item(freq_coherence),
356
+ 'color_harmony': safe_item(color_harmony),
357
+ 'texture_div': safe_item(texture_div),
358
+ 'smoothness': safe_item(smoothness)
359
  }
360
 
361
  def frequency_coherence_loss(images):
 
387
 
388
 
389
  def feature_variance_preservation_loss(features, images):
390
+ """this function ensure each feature dimension contributes to image variance"""
391
  batch_size = features.size(0)
392
  if batch_size < 4:
393
  return torch.tensor(0.0, device=features.device)
 
419
 
420
  # ===================== Trainer =====================
421
  def create_feature_explanation_grid(model, features, save_path):
422
+
 
 
 
423
  model.eval()
424
+ base_idx = 42 # fixed sample index
425
  base_features = features[base_idx].copy()
426
 
427
+ # 6 features to visualize
428
  feature_indices = [0, 3, 10, 14, 19, 24] # LOC, loops, nesting, comments, complexity, indentation
429
  feature_names = [FEATURE_NAMES[i] for i in feature_indices]
430
 
431
  fig, axes = plt.subplots(len(feature_indices), 5, figsize=(15, 3*len(feature_indices)))
432
 
433
  for row, feat_idx in enumerate(feature_indices):
434
+ # create variations: [-2σ, -1σ, 0, +1σ, +2σ]
435
  variations = []
436
  for multiplier in [-2, -1, 0, 1, 2]:
437
  varied = base_features.copy()
 
465
  # Normalize to [0, 1]
466
  f_norm = (features - features.min(0)) / (features.max(0) - features.min(0) + 1e-8)
467
 
 
468
  structure_score = f_norm[[0, 1, 2]].mean()
469
  base_hue = structure_score * 0.7 # 0 (red) to 0.7 (blue)
470
 
 
471
  control_score = f_norm[[3, 4]].mean()
472
  saturation = 0.3 + control_score * 0.6
473
 
 
474
  style_score = f_norm[[14, 15]].mean()
475
  value = 0.4 + style_score * 0.5
476
 
 
477
  palette = []
478
  for shift in [-0.1, 0, 0.1, 0.2]:
479
  hue = (base_hue + shift) % 1.0
480
  palette.append(hsv_to_rgb(hue, saturation, value))
481
 
482
+ return np.array(palette).flatten()
483
 
484
  def interpolate_code_features(model, features, idx1, idx2, steps=10):
 
485
  model.eval()
486
 
487
  f1 = features[idx1]
 
499
  return images
500
 
501
 
502
+ class Trainer:
503
  def __init__(self, model, device='cpu'):
504
  self.model = model.to(device)
505
  self.device = device
 
558
  min_lr=1e-6
559
  )
560
 
561
+ print(f"Training model for {epochs} epochs on {self.device}...")
 
562
 
563
  for epoch in range(epochs):
564
  avg_loss, components = self.train_epoch(train_loader, optimizer)
565
+ scheduler.step(avg_loss)
566
 
567
  if (epoch + 1) % 10 == 0:
568
  print(f"Epoch [{epoch+1}/{epochs}] - Loss: {avg_loss:.4f}")
 
635
  print(f"Samples saved to {save_path}")
636
 
637
  def test_feature_consistency(self, features, save_path='results/consistency_test.png'):
638
+
 
 
639
  os.makedirs(os.path.dirname(save_path), exist_ok=True)
640
  self.model.eval()
641
 
 
680
  print(f"Consistency test saved to {save_path}")
681
 
682
 
 
683
 
684
  def prepare_data_loaders(features, batch_size=32):
685
  dataset = CodeToImageDataset(features)
 
688
  return loader
689
 
690
 
 
691
 
692
  def main():
693
  os.makedirs("models", exist_ok=True)
694
  os.makedirs("results", exist_ok=True)
695
 
696
  print("=" * 70)
697
+ print(" Code-to-Image Generator (Final Version)")
 
698
  print("=" * 70)
699
 
700
  # Load features
701
+ print("\n Loading features...")
702
  features_path = "data/features.npy"
703
  if not os.path.exists(features_path):
704
  raise FileNotFoundError(f"{features_path} not found")
 
707
  print(f"Features shape: {features.shape}")
708
 
709
  # Prepare data
710
+ print("\nPreparing data loaders...")
711
  loader = prepare_data_loaders(features, batch_size=32)
712
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
713
  print(f"Device: {device}")
714
 
715
  # Initialize model
716
+ print("\n InitializingCNN model...")
717
+ model = Generator(input_dim=features.shape[1], image_size=32)
718
  param_count = sum(p.numel() for p in model.parameters())
719
  print(f"Model parameters: {param_count:,}")
720
 
721
  # Train model
722
+ print("\nTraining model...")
723
+ trainer = Trainer(model, device=device)
724
  trainer.train(loader, epochs=100, lr=1e-3)
725
 
726
  # Save results
727
  print("\n[5/5] Saving results...")
728
+ trainer.save_model("models/model_32.pth")
729
+ trainer.plot_losses("results/losses.png")
730
+ trainer.generate_samples(features, n=25, save_path="results/samples.png")
731
+ trainer.test_feature_consistency(features, save_path="results/consistency.png")
732
 
733
  print("\n" + "=" * 70)
734
+ print("Check results/ folder.")
 
 
735
 
736
  if __name__ == "__main__":
737
  main()
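The new feature_consistency_loss and the safe_item helper feed into combined_aesthetic_loss, which returns a scalar total plus a dict of plain-float components. A rough sketch of a single training step built on that interface (the optimizer choice and batch here are illustrative; the file's Trainer class wraps the same logic):

import torch
from model import Generator, combined_aesthetic_loss

model = Generator(input_dim=25, image_size=32)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

features = torch.randn(32, 25)      # one batch of code-feature vectors
images = model(features)            # (32, 3, 32, 32)

total, components = combined_aesthetic_loss(features, images)
optimizer.zero_grad()
total.backward()
optimizer.step()

print(f"total={total.item():.4f}", components)  # components are plain floats via safe_item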
models/{chotic.pth → model_32.pth} RENAMED
File without changes