Felixstro-dev committed on
Commit
e2135e4
·
verified ·
1 Parent(s): f13dc80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -30
app.py CHANGED
@@ -4,8 +4,14 @@ Minecraft Skin Generator – HuggingFace Spaces Demo
4
  Lädt model.pt (EMA-Gewichte) aus dem Repo und generiert Skins per Prompt.
5
  Benötigte Dateien im Space-Repo:
6
  app.py ← diese Datei
7
- model.pt ← dein exportiertes EMA-Modell
8
  requirements.txt
 
 
 
 
 
 
9
  """
10
 
11
  import math
@@ -111,7 +117,7 @@ def tags_to_vector(tags: list) -> torch.Tensor:
111
  if t in TAG2IDX: vec[TAG2IDX[t]] = 1.0
112
  return vec
113
 
114
- # ─── UV-Masken ────────────────────────────────────────────────────────────────
115
  SKIN_REGIONS = {
116
  "head": (0, 0, 32, 16),
117
  "body": (16, 16, 40, 32),
@@ -129,28 +135,30 @@ OVERLAY_REGIONS = {
129
  "leg_l_overlay": (0, 48, 16, 64),
130
  }
131
 
132
- def _build_base_mask(device):
133
- mask = torch.zeros(1, 1, IMG_SIZE, IMG_SIZE, device=device)
 
134
  for x1,y1,x2,y2 in SKIN_REGIONS.values():
135
  mask[0,0,y1:y2,x1:x2] = 1.0
136
  return mask
137
 
138
- def _build_overlay_mask(device):
139
- mask = torch.zeros(1, 1, IMG_SIZE, IMG_SIZE, device=device)
140
  for x1,y1,x2,y2 in OVERLAY_REGIONS.values():
141
  mask[0,0,y1:y2,x1:x2] = 1.0
142
  return mask
143
 
 
144
  def force_alpha_mask(img: torch.Tensor) -> torch.Tensor:
145
- base = _build_base_mask(img.device)
146
- overlay = _build_overlay_mask(img.device)
147
- outside = (1.0 - base - overlay).clamp(0, 1)
148
- alpha = (
149
- base * torch.ones_like(img[:, 3:4])
150
- + overlay * img[:, 3:4]
151
- + outside * torch.full_like(img[:, 3:4], -1.0)
152
  )
153
- return torch.cat([img[:, :3], alpha], dim=1)
154
 
155
  # ─── UNet (identisch mit train_diffusion.py) ──────────────────────────────────
156
  class SinusoidalPE(nn.Module):
@@ -159,9 +167,10 @@ class SinusoidalPE(nn.Module):
159
  self.dim = dim
160
 
161
  def forward(self, t):
162
- half = self.dim // 2
163
- freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device) / half)
164
- args = t[:, None].float() * freqs[None]
 
165
  return torch.cat([args.sin(), args.cos()], dim=-1)
166
 
167
 
@@ -308,8 +317,8 @@ class DiffusionSchedule:
308
  c2 = torch.cat([cond, null_cond])
309
 
310
  out = model(x2, t2, c2)
311
- n_cond, n_uncond = out.chunk(2)
312
- noise_pred = n_uncond + guidance_scale * (n_cond - n_uncond)
313
 
314
  alpha = self.alphas[t_idx]
315
  alpha_bar = self.alphas_cumprod[t_idx]
@@ -335,25 +344,44 @@ class DiffusionSchedule:
335
  device = "cuda" if torch.cuda.is_available() else "cpu"
336
  print(f"Device: {device}")
337
 
338
- ckpt = torch.load("model.pt", map_location=device, weights_only=False)
339
- base_ch = ckpt.get("base_ch", 96)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  if base_ch is None:
341
  for key in ("enc_in.weight", "_orig_mod.enc_in.weight"):
342
- sd_check = ckpt.get("model", ckpt)
343
- if key in sd_check:
344
- base_ch = sd_check[key].shape[0]
345
  break
346
- base_ch = base_ch or 96
 
 
347
 
348
  model = UNet(base_ch=base_ch).to(device)
349
- sd = ckpt.get("model", ckpt)
350
  model.load_state_dict(sd, strict=False)
351
  model.eval()
352
- try: torch._dynamo.disable(model)
353
- except Exception: pass
 
 
 
354
 
355
  schedule = DiffusionSchedule(device=device)
356
- print(f"Modell geladen: base_ch={base_ch}, {sum(p.numel() for p in model.parameters())/1e6:.1f}M Parameter")
 
357
 
358
 
359
  # ─── Generierungs-Funktion ────────────────────────────────────────────────────
@@ -415,7 +443,7 @@ Generiert 64×64 Minecraft Skins aus einem Text-Prompt. Trainiert mit DDPM auf ~
415
  seed = gr.Slider(label="Seed", minimum=0, maximum=2**31,step=1, value=42)
416
  rand_seed = gr.Checkbox(label="Seed zufällig", value=True)
417
 
418
- tag_info = gr.Text(label="Erkannte Tags", interactive=False)
419
  seed_out = gr.Number(label="Verwendeter Seed", interactive=False)
420
 
421
  with gr.Column(scale=3):
 
4
  Lädt model.pt (EMA-Gewichte) aus dem Repo und generiert Skins per Prompt.
5
  Benötigte Dateien im Space-Repo:
6
  app.py ← diese Datei
7
+ model.pt ← mit export_ema_model.py exportiert (EMA-Gewichte!)
8
  requirements.txt
9
+
10
+ FIXES gegenüber der alten app.py:
11
+ [FIX 1] EMA-Gewichte werden korrekt priorisiert (ckpt["ema"] vor ckpt["model"])
12
+ [FIX 2] base_ch Fallback-Kette ist identisch mit train_diffusion.py (Default 96 statt 128)
13
+ [FIX 3] _build_base_mask / _build_overlay_mask ohne device-Parameter (wie im Training)
14
+ [FIX 4] force_alpha_mask identisch mit train_diffusion.py
15
  """
16
 
17
  import math
 
117
  if t in TAG2IDX: vec[TAG2IDX[t]] = 1.0
118
  return vec
119
 
120
+ # ─── UV-Masken (identisch mit train_diffusion.py) ─────────────────────────────
121
  SKIN_REGIONS = {
122
  "head": (0, 0, 32, 16),
123
  "body": (16, 16, 40, 32),
 
135
  "leg_l_overlay": (0, 48, 16, 64),
136
  }
137
 
138
+ # [FIX 3] Keine device-Parameter – identisch mit train_diffusion.py
139
+ def _build_base_mask():
140
+ mask = torch.zeros(1, 1, IMG_SIZE, IMG_SIZE)
141
  for x1,y1,x2,y2 in SKIN_REGIONS.values():
142
  mask[0,0,y1:y2,x1:x2] = 1.0
143
  return mask
144
 
145
+ def _build_overlay_mask():
146
+ mask = torch.zeros(1, 1, IMG_SIZE, IMG_SIZE)
147
  for x1,y1,x2,y2 in OVERLAY_REGIONS.values():
148
  mask[0,0,y1:y2,x1:x2] = 1.0
149
  return mask
150
 
151
+ # [FIX 4] force_alpha_mask identisch mit train_diffusion.py (device über .to())
152
  def force_alpha_mask(img: torch.Tensor) -> torch.Tensor:
153
+ base_mask = _build_base_mask().to(img.device)
154
+ overlay_mask = _build_overlay_mask().to(img.device)
155
+ outside_mask = (1.0 - base_mask - overlay_mask).clamp(0, 1)
156
+ alpha_new = (
157
+ base_mask * torch.ones_like(img[:, 3:4])
158
+ + overlay_mask * img[:, 3:4]
159
+ + outside_mask * torch.full_like(img[:, 3:4], -1.0)
160
  )
161
+ return torch.cat([img[:, :3], alpha_new], dim=1)
162
 
163
  # ─── UNet (identisch mit train_diffusion.py) ──────────────────────────────────
164
  class SinusoidalPE(nn.Module):
 
167
  self.dim = dim
168
 
169
  def forward(self, t):
170
+ device = t.device
171
+ half = self.dim // 2
172
+ freqs = torch.exp(-math.log(10000) * torch.arange(half, device=device) / half)
173
+ args = t[:, None].float() * freqs[None]
174
  return torch.cat([args.sin(), args.cos()], dim=-1)
175
 
176
 
 
317
  c2 = torch.cat([cond, null_cond])
318
 
319
  out = model(x2, t2, c2)
320
+ noise_cond, noise_uncond = out.chunk(2)
321
+ noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
322
 
323
  alpha = self.alphas[t_idx]
324
  alpha_bar = self.alphas_cumprod[t_idx]
 
344
  device = "cuda" if torch.cuda.is_available() else "cpu"
345
  print(f"Device: {device}")
346
 
347
+ ckpt = torch.load("model.pt", map_location=device, weights_only=False)
348
+ print(f"Checkpoint Keys: {list(ckpt.keys())}")
349
+
350
+ # [FIX 1] EMA-Gewichte priorisieren – das ist der Hauptfehler der alten app.py!
351
+ # "ema" Key = EMA-Gewichte (beste Qualität, geglättet)
352
+ # "model" Key = je nach Datei entweder EMA (bei latest.pt) oder rohe Gewichte (bei ep*.pt)
353
+ sd = ckpt.get("ema") or ckpt.get("model") or ckpt
354
+ if "ema" in ckpt:
355
+ print("✅ Verwende EMA-Gewichte ('ema' Key) – beste Qualität")
356
+ elif "model" in ckpt:
357
+ print("ℹ️ Verwende 'model' Key (kein 'ema' Key gefunden)")
358
+ else:
359
+ print("⚠️ Kein 'ema' oder 'model' Key – versuche direktes Laden")
360
+
361
+ # [FIX 2] base_ch Fallback identisch mit train_diffusion.py
362
+ base_ch = ckpt.get("base_ch", None)
363
  if base_ch is None:
364
  for key in ("enc_in.weight", "_orig_mod.enc_in.weight"):
365
+ if key in sd:
366
+ base_ch = sd[key].shape[0]
367
+ print(f"base_ch aus state_dict ermittelt: {base_ch}")
368
  break
369
+ if base_ch is None:
370
+ base_ch = 96 # train_diffusion.py Default ist 96, nicht 128!
371
+ print(f"⚠️ base_ch nicht gefunden, verwende Default: {base_ch}")
372
 
373
  model = UNet(base_ch=base_ch).to(device)
 
374
  model.load_state_dict(sd, strict=False)
375
  model.eval()
376
+
377
+ try:
378
+ torch._dynamo.disable(model)
379
+ except Exception:
380
+ pass
381
 
382
  schedule = DiffusionSchedule(device=device)
383
+ num_params = sum(p.numel() for p in model.parameters()) / 1e6
384
+ print(f"Modell geladen: base_ch={base_ch}, {num_params:.1f}M Parameter")
385
 
386
 
387
  # ─── Generierungs-Funktion ────────────────────────────────────────────────────
 
443
  seed = gr.Slider(label="Seed", minimum=0, maximum=2**31,step=1, value=42)
444
  rand_seed = gr.Checkbox(label="Seed zufällig", value=True)
445
 
446
+ tag_info = gr.Text(label="Erkannte Tags", interactive=False)
447
  seed_out = gr.Number(label="Verwendeter Seed", interactive=False)
448
 
449
  with gr.Column(scale=3):