Update README.md
Browse files
README.md
CHANGED
|
@@ -75,44 +75,59 @@ Limitations:
|
|
| 75 |
|
| 76 |
---
|
| 77 |
|
|
|
|
|
|
|
| 78 |
## π Usage
|
| 79 |
|
| 80 |
### Generate images
|
| 81 |
|
| 82 |
-
## THE INITIAL IDEA WAS A STUDENT U-NET DISTILLED FROM A TEACHER U-NET, BUT THIS WAS DISCONTINUED BECAUSE THE TEACHER WAS INITIALIZED WITH RANDOM WEIGHTS, WHICH WOULD HAVE KILLED THE STUDENT'S LEARNING
|
| 83 |
|
| 84 |
|
| 85 |
```python
|
| 86 |
|
| 87 |
-
import argparse
|
| 88 |
import torch
|
| 89 |
from pathlib import Path
|
|
|
|
| 90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
from train import StudentUNet, DDPMScheduler, Config
|
| 92 |
|
|
|
|
|
|
|
| 93 |
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
-
# π₯ Usa menos steps β muito mais rΓ‘pido
|
| 104 |
step_size = scheduler.T // steps
|
| 105 |
-
timesteps = list(range(0, scheduler.T, step_size))
|
| 106 |
-
timesteps = list(reversed(timesteps))
|
| 107 |
|
| 108 |
for t_val in timesteps:
|
| 109 |
-
t = torch.full((n,), t_val,
|
|
|
|
| 110 |
noise_pred = model(x, t)
|
| 111 |
|
| 112 |
if t_val > 0:
|
| 113 |
-
ab = scheduler.alpha_bar[t_val]
|
| 114 |
prev_t = max(t_val - step_size, 0)
|
| 115 |
-
ab_prev = scheduler.alpha_bar[prev_t]
|
| 116 |
|
| 117 |
beta_t = 1.0 - (ab / ab_prev)
|
| 118 |
alpha_t = 1.0 - beta_t
|
|
@@ -121,83 +136,28 @@ def generate_samples(model, scheduler, n=4, steps=50, device="cpu", dtype=torch.
|
|
| 121 |
x - (beta_t / (1.0 - ab).sqrt()) * noise_pred
|
| 122 |
)
|
| 123 |
|
| 124 |
-
|
| 125 |
-
x = mean + sigma * torch.randn_like(x)
|
| 126 |
else:
|
| 127 |
x = scheduler.predict_x0(x, noise_pred, t)
|
| 128 |
|
| 129 |
-
model.train()
|
| 130 |
return x.clamp(-1, 1)
|
| 131 |
|
|
|
|
| 132 |
|
| 133 |
-
#
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
def save_samples(samples, path: Path):
|
| 137 |
-
samples = (samples + 1) / 2
|
| 138 |
-
samples = (samples * 255).byte().permute(0, 2, 3, 1).cpu().numpy()
|
| 139 |
-
|
| 140 |
-
from PIL import Image
|
| 141 |
-
|
| 142 |
-
n = len(samples)
|
| 143 |
-
w = samples.shape[1]
|
| 144 |
|
| 145 |
-
|
| 146 |
|
| 147 |
-
|
| 148 |
-
|
| 149 |
|
| 150 |
-
|
| 151 |
-
grid.save(path)
|
| 152 |
|
|
|
|
| 153 |
|
| 154 |
-
# ----------------------------
|
| 155 |
-
# Main
|
| 156 |
-
# ----------------------------
|
| 157 |
-
def main():
|
| 158 |
-
parser = argparse.ArgumentParser()
|
| 159 |
-
parser.add_argument("--checkpoint", type=str, required=True)
|
| 160 |
-
parser.add_argument("--n_images", type=int, default=8)
|
| 161 |
-
parser.add_argument("--steps", type=int, default=50)
|
| 162 |
-
parser.add_argument("--seed", type=int, default=42)
|
| 163 |
-
parser.add_argument("--out", type=str, default="outputs/generated.png")
|
| 164 |
|
| 165 |
-
args = parser.parse_args()
|
| 166 |
-
|
| 167 |
-
# Seed
|
| 168 |
-
torch.manual_seed(args.seed)
|
| 169 |
-
|
| 170 |
-
# Load checkpoint
|
| 171 |
-
ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
| 172 |
-
global cfg
|
| 173 |
-
cfg = ckpt.get("config", Config())
|
| 174 |
-
|
| 175 |
-
# Model
|
| 176 |
-
model = StudentUNet(cfg)
|
| 177 |
-
model.load_state_dict(ckpt["model_state"])
|
| 178 |
-
model.eval()
|
| 179 |
-
|
| 180 |
-
# Scheduler
|
| 181 |
-
scheduler = DDPMScheduler(cfg.timesteps, cfg.beta_start, cfg.beta_end)
|
| 182 |
-
|
| 183 |
-
print(f"\nπ Generating {args.n_images} images")
|
| 184 |
-
print(f"βοΈ Steps: {args.steps} | Seed: {args.seed}")
|
| 185 |
-
|
| 186 |
-
samples = generate_samples(
|
| 187 |
-
model,
|
| 188 |
-
scheduler,
|
| 189 |
-
n=args.n_images,
|
| 190 |
-
steps=args.steps,
|
| 191 |
-
dtype=cfg.dtype
|
| 192 |
-
)
|
| 193 |
-
|
| 194 |
-
save_samples(samples, Path(args.out))
|
| 195 |
-
|
| 196 |
-
print(f"β
Saved to: {args.out}")
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
if __name__ == "__main__":
|
| 200 |
-
main()
|
| 201 |
```
|
| 202 |
|
| 203 |
```bash
|
|
@@ -206,6 +166,7 @@ python generate.py \
|
|
| 206 |
--n_images 8 \
|
| 207 |
--steps 50 \
|
| 208 |
--seed 42
|
|
|
|
| 209 |
|
| 210 |
π Output
|
| 211 |
|
|
|
|
| 75 |
|
| 76 |
---
|
| 77 |
|
| 78 |
+
## THE INITIAL IDEA WAS A STUDENT U-NET DISTILLED FROM A TEACHER U-NET, BUT THIS WAS DISCONTINUED BECAUSE THE TEACHER WAS INITIALIZED WITH RANDOM WEIGHTS, WHICH WOULD HAVE KILLED THE STUDENT'S LEARNING
|
| 79 |
+
|
| 80 |
## π Usage
|
| 81 |
|
| 82 |
### Generate images
|
| 83 |
|
|
|
|
| 84 |
|
| 85 |
|
| 86 |
```python
|
| 87 |
|
|
|
|
| 88 |
import torch
|
| 89 |
from pathlib import Path
|
| 90 |
+
from PIL import Image
|
| 91 |
|
| 92 |
+
# ===== CONFIG =====
|
| 93 |
+
CHECKPOINT = "model.pt"
|
| 94 |
+
N_IMAGES = 8
|
| 95 |
+
STEPS = 50
|
| 96 |
+
SEED = 42
|
| 97 |
+
OUT = "generated.png"
|
| 98 |
+
|
| 99 |
+
# ===== IMPORT MODEL =====
|
| 100 |
from train import StudentUNet, DDPMScheduler, Config
|
| 101 |
|
| 102 |
+
# ===== LOAD =====
|
| 103 |
+
torch.manual_seed(SEED)
|
| 104 |
|
| 105 |
+
ckpt = torch.load(CHECKPOINT, map_location="cpu")
|
| 106 |
+
cfg = ckpt.get("config", Config())
|
| 107 |
+
|
| 108 |
+
model = StudentUNet(cfg)
|
| 109 |
+
model.load_state_dict(ckpt["model_state"])
|
| 110 |
+
model.eval()
|
| 111 |
|
| 112 |
+
scheduler = DDPMScheduler(cfg.timesteps, cfg.beta_start, cfg.beta_end)
|
| 113 |
+
|
| 114 |
+
# ===== SAMPLING =====
|
| 115 |
+
@torch.no_grad()
|
| 116 |
+
def sample(model, scheduler, n, steps):
|
| 117 |
+
x = torch.randn(n, 3, cfg.image_size, cfg.image_size)
|
| 118 |
|
|
|
|
| 119 |
step_size = scheduler.T // steps
|
| 120 |
+
timesteps = list(range(0, scheduler.T, step_size))[::-1]
|
|
|
|
| 121 |
|
| 122 |
for t_val in timesteps:
|
| 123 |
+
t = torch.full((n,), t_val, dtype=torch.long)
|
| 124 |
+
|
| 125 |
noise_pred = model(x, t)
|
| 126 |
|
| 127 |
if t_val > 0:
|
| 128 |
+
ab = scheduler.alpha_bar[t_val]
|
| 129 |
prev_t = max(t_val - step_size, 0)
|
| 130 |
+
ab_prev = scheduler.alpha_bar[prev_t]
|
| 131 |
|
| 132 |
beta_t = 1.0 - (ab / ab_prev)
|
| 133 |
alpha_t = 1.0 - beta_t
|
|
|
|
| 136 |
x - (beta_t / (1.0 - ab).sqrt()) * noise_pred
|
| 137 |
)
|
| 138 |
|
| 139 |
+
x = mean + beta_t.sqrt() * torch.randn_like(x)
|
|
|
|
| 140 |
else:
|
| 141 |
x = scheduler.predict_x0(x, noise_pred, t)
|
| 142 |
|
|
|
|
| 143 |
return x.clamp(-1, 1)
|
| 144 |
|
| 145 |
+
samples = sample(model, scheduler, N_IMAGES, STEPS)
|
| 146 |
|
| 147 |
+
# ===== SAVE =====
|
| 148 |
+
samples = (samples + 1) / 2
|
| 149 |
+
samples = (samples * 255).byte().permute(0, 2, 3, 1).numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
+
grid = Image.new("RGB", (cfg.image_size * N_IMAGES, cfg.image_size))
|
| 152 |
|
| 153 |
+
for i, img in enumerate(samples):
|
| 154 |
+
grid.paste(Image.fromarray(img), (i * cfg.image_size, 0))
|
| 155 |
|
| 156 |
+
grid.save(OUT)
|
|
|
|
| 157 |
|
| 158 |
+
print(f"β
Saved to {OUT}")
|
| 159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
```
|
| 162 |
|
| 163 |
```bash
|
|
|
|
| 166 |
--n_images 8 \
|
| 167 |
--steps 50 \
|
| 168 |
--seed 42
|
| 169 |
+
```
|
| 170 |
|
| 171 |
π Output
|
| 172 |
|