AxionLab-official committed on
Commit
19267db
·
verified ·
1 Parent(s): fbce1fc

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +41 -80
README.md CHANGED
@@ -75,44 +75,59 @@ Limitations:
75
 
76
  ---
77
 
 
 
78
  ## πŸš€ Usage
79
 
80
  ### Generate images
81
 
82
- ## THE INITIAL IDEA WAS A STUDENT U-NET FROM A TEACHER U-NET, BUT THIS WAS DISCONTINUED BECAUSE THE TEACHER WAS INITIALIZATED WITH RANDOM WEIGHTS, THAT WOULD KILL THE STUDENT LEARNING
83
 
84
 
85
  ```python
86
 
87
- import argparse
88
  import torch
89
  from pathlib import Path
 
90
 
 
 
 
 
 
 
 
 
91
  from train import StudentUNet, DDPMScheduler, Config
92
 
 
 
93
 
94
- # ----------------------------
95
- # Sampling (com controle de steps)
96
- # ----------------------------
97
- @torch.no_grad()
98
- def generate_samples(model, scheduler, n=4, steps=50, device="cpu", dtype=torch.float32):
99
- model.eval()
100
 
101
- x = torch.randn(n, 3, cfg.image_size, cfg.image_size, device=device, dtype=dtype)
 
 
 
 
 
102
 
103
- # πŸ”₯ Usa menos steps β†’ muito mais rΓ‘pido
104
  step_size = scheduler.T // steps
105
- timesteps = list(range(0, scheduler.T, step_size))
106
- timesteps = list(reversed(timesteps))
107
 
108
  for t_val in timesteps:
109
- t = torch.full((n,), t_val, device=device, dtype=torch.long)
 
110
  noise_pred = model(x, t)
111
 
112
  if t_val > 0:
113
- ab = scheduler.alpha_bar[t_val].to(x.dtype)
114
  prev_t = max(t_val - step_size, 0)
115
- ab_prev = scheduler.alpha_bar[prev_t].to(x.dtype)
116
 
117
  beta_t = 1.0 - (ab / ab_prev)
118
  alpha_t = 1.0 - beta_t
@@ -121,83 +136,28 @@ def generate_samples(model, scheduler, n=4, steps=50, device="cpu", dtype=torch.
121
  x - (beta_t / (1.0 - ab).sqrt()) * noise_pred
122
  )
123
 
124
- sigma = beta_t.sqrt()
125
- x = mean + sigma * torch.randn_like(x)
126
  else:
127
  x = scheduler.predict_x0(x, noise_pred, t)
128
 
129
- model.train()
130
  return x.clamp(-1, 1)
131
 
 
132
 
133
- # ----------------------------
134
- # Save
135
- # ----------------------------
136
- def save_samples(samples, path: Path):
137
- samples = (samples + 1) / 2
138
- samples = (samples * 255).byte().permute(0, 2, 3, 1).cpu().numpy()
139
-
140
- from PIL import Image
141
-
142
- n = len(samples)
143
- w = samples.shape[1]
144
 
145
- grid = Image.new("RGB", (n * w, w))
146
 
147
- for i, s in enumerate(samples):
148
- grid.paste(Image.fromarray(s), (i * w, 0))
149
 
150
- path.parent.mkdir(parents=True, exist_ok=True)
151
- grid.save(path)
152
 
 
153
 
154
- # ----------------------------
155
- # Main
156
- # ----------------------------
157
- def main():
158
- parser = argparse.ArgumentParser()
159
- parser.add_argument("--checkpoint", type=str, required=True)
160
- parser.add_argument("--n_images", type=int, default=8)
161
- parser.add_argument("--steps", type=int, default=50)
162
- parser.add_argument("--seed", type=int, default=42)
163
- parser.add_argument("--out", type=str, default="outputs/generated.png")
164
 
165
- args = parser.parse_args()
166
-
167
- # Seed
168
- torch.manual_seed(args.seed)
169
-
170
- # Load checkpoint
171
- ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
172
- global cfg
173
- cfg = ckpt.get("config", Config())
174
-
175
- # Model
176
- model = StudentUNet(cfg)
177
- model.load_state_dict(ckpt["model_state"])
178
- model.eval()
179
-
180
- # Scheduler
181
- scheduler = DDPMScheduler(cfg.timesteps, cfg.beta_start, cfg.beta_end)
182
-
183
- print(f"\nπŸš€ Generating {args.n_images} images")
184
- print(f"βš™οΈ Steps: {args.steps} | Seed: {args.seed}")
185
-
186
- samples = generate_samples(
187
- model,
188
- scheduler,
189
- n=args.n_images,
190
- steps=args.steps,
191
- dtype=cfg.dtype
192
- )
193
-
194
- save_samples(samples, Path(args.out))
195
-
196
- print(f"βœ… Saved to: {args.out}")
197
-
198
-
199
- if __name__ == "__main__":
200
- main()
201
  ```
202
 
203
  ```bash
@@ -206,6 +166,7 @@ python generate.py \
206
  --n_images 8 \
207
  --steps 50 \
208
  --seed 42
 
209
 
210
  πŸ“ Output
211
 
 
75
 
76
  ---
77
 
78
+ ## THE INITIAL IDEA WAS A STUDENT U-NET DISTILLED FROM A TEACHER U-NET, BUT THIS WAS DISCONTINUED BECAUSE THE TEACHER WAS INITIALIZED WITH RANDOM WEIGHTS, WHICH WOULD HAVE KILLED THE STUDENT'S LEARNING
79
+
80
  ## πŸš€ Usage
81
 
82
  ### Generate images
83
 
 
84
 
85
 
86
  ```python
87
 
 
88
  import torch
89
  from pathlib import Path
90
+ from PIL import Image
91
 
92
+ # ===== CONFIG =====
93
+ CHECKPOINT = "model.pt"
94
+ N_IMAGES = 8
95
+ STEPS = 50
96
+ SEED = 42
97
+ OUT = "generated.png"
98
+
99
+ # ===== IMPORT MODEL =====
100
  from train import StudentUNet, DDPMScheduler, Config
101
 
102
+ # ===== LOAD =====
103
+ torch.manual_seed(SEED)
104
 
105
+ ckpt = torch.load(CHECKPOINT, map_location="cpu")
106
+ cfg = ckpt.get("config", Config())
107
+
108
+ model = StudentUNet(cfg)
109
+ model.load_state_dict(ckpt["model_state"])
110
+ model.eval()
111
 
112
+ scheduler = DDPMScheduler(cfg.timesteps, cfg.beta_start, cfg.beta_end)
113
+
114
+ # ===== SAMPLING =====
115
+ @torch.no_grad()
116
+ def sample(model, scheduler, n, steps):
117
+ x = torch.randn(n, 3, cfg.image_size, cfg.image_size)
118
 
 
119
  step_size = scheduler.T // steps
120
+ timesteps = list(range(0, scheduler.T, step_size))[::-1]
 
121
 
122
  for t_val in timesteps:
123
+ t = torch.full((n,), t_val, dtype=torch.long)
124
+
125
  noise_pred = model(x, t)
126
 
127
  if t_val > 0:
128
+ ab = scheduler.alpha_bar[t_val]
129
  prev_t = max(t_val - step_size, 0)
130
+ ab_prev = scheduler.alpha_bar[prev_t]
131
 
132
  beta_t = 1.0 - (ab / ab_prev)
133
  alpha_t = 1.0 - beta_t
 
136
  x - (beta_t / (1.0 - ab).sqrt()) * noise_pred
137
  )
138
 
139
+ x = mean + beta_t.sqrt() * torch.randn_like(x)
 
140
  else:
141
  x = scheduler.predict_x0(x, noise_pred, t)
142
 
 
143
  return x.clamp(-1, 1)
144
 
145
+ samples = sample(model, scheduler, N_IMAGES, STEPS)
146
 
147
+ # ===== SAVE =====
148
+ samples = (samples + 1) / 2
149
+ samples = (samples * 255).byte().permute(0, 2, 3, 1).numpy()
 
 
 
 
 
 
 
 
150
 
151
+ grid = Image.new("RGB", (cfg.image_size * N_IMAGES, cfg.image_size))
152
 
153
+ for i, img in enumerate(samples):
154
+ grid.paste(Image.fromarray(img), (i * cfg.image_size, 0))
155
 
156
+ grid.save(OUT)
 
157
 
158
+ print(f"βœ… Saved to {OUT}")
159
 
 
 
 
 
 
 
 
 
 
 
160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  ```
162
 
163
  ```bash
 
166
  --n_images 8 \
167
  --steps 50 \
168
  --seed 42
169
+ ```
170
 
171
  πŸ“ Output
172