| | """ |
| | Skybox generator: text → 2:1 equirectangular image (Stable Diffusion, local). |
| | Uses FP16 to reduce VRAM. Output 1024x512 or 2048x1024. |
| | """ |
| |
|
| | import os |
| | import time |
| | from pathlib import Path |
| |
|
| | import torch |
| |
|
| | |
# Hub id used when no local weights directory or env override is found.
DEFAULT_MODEL_ID = "runwayml/stable-diffusion-v1-5"
# NOTE(review): fallback is currently identical to the default, so the second
# load attempt retries the same model id — confirm whether a distinct
# fallback (e.g. SD 2.1 base) was intended.
FALLBACK_MODEL_ID = "runwayml/stable-diffusion-v1-5"
| |
|
| |
|
def get_device() -> str:
    """Return the torch device string: "cuda" when available, else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
| |
|
| |
|
| | def _is_complete_sd_dir(path: Path) -> bool: |
| | """True if path looks like a complete Stable Diffusion pipeline (has unet weights).""" |
| | if not path.is_dir(): |
| | return False |
| | unet = path / "unet" |
| | if not unet.is_dir(): |
| | return False |
| | return any( |
| | (unet / f).exists() |
| | for f in ("diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.bin") |
| | ) |
| |
|
| |
|
def _default_local_weights_dir() -> str | None:
    """First complete SD folder under weights/ (sd-v1-5 or stable-diffusion-2-1-base)."""
    candidate_names = ("sd-v1-5", "stable-diffusion-2-1-base")
    try:
        weights_root = Path(__file__).resolve().parent.parent / "weights"
        for name in candidate_names:
            candidate = weights_root / name
            if _is_complete_sd_dir(candidate):
                return str(candidate)
    except Exception:
        # Best-effort probe: any filesystem/path oddity just means "no local weights".
        pass
    return None
| |
|
| |
|
def _resolve_model_path_and_token():
    """Resolve where to load the model from and which auth token to use.

    Precedence: explicit SD_MODEL_PATH dir, then an auto-discovered weights/
    folder, then a Hub id (SD_MODEL_ID or the default). The token is HF_TOKEN
    when set, else ``True`` (huggingface-cli cached login).
    Returns (model_path_or_id, token_or_None).
    """
    explicit_local = os.environ.get("SD_MODEL_PATH", "").strip()
    if explicit_local and os.path.isdir(explicit_local):
        return explicit_local, None

    discovered_local = _default_local_weights_dir()
    if discovered_local is not None:
        return discovered_local, None

    hub_id = os.environ.get("SD_MODEL_ID", DEFAULT_MODEL_ID)
    hub_token = os.environ.get("HF_TOKEN") or True
    return hub_id, hub_token
| |
|
| |
|
def generate_skybox(
    prompt: str,
    output_dir: str = "outputs",
    width: int = 1024,
    height: int = 512,
    seed: int | None = None,
    model_id: str | None = None,
    num_inference_steps: int = 50,
) -> tuple[str, float, float]:
    """
    Generate a 2:1 equirectangular skybox image from a text prompt.

    Args:
        prompt: Text description of the scene; also used to build the filename.
        output_dir: Directory for the PNG (created if missing).
        width: Output width in pixels; callers should keep a 2:1 width:height ratio.
        height: Output height in pixels.
        seed: Optional seed for reproducible generation.
        model_id: Optional explicit model path/Hub id, overriding env/auto-resolution.
        num_inference_steps: Diffusion denoising steps (more = slower, higher
            quality). Defaults to 50, the previous hard-coded value.

    Returns:
        (path_to_image, inference_time_sec, peak_vram_mb); peak_vram_mb is 0.0 on CPU.

    Raises:
        RuntimeError: if no Stable Diffusion pipeline could be loaded.
    """
    # Imported lazily so merely importing this module does not pull in diffusers.
    from diffusers import StableDiffusionPipeline

    device = get_device()
    # FP16 halves VRAM on GPU; CPU inference needs FP32.
    dtype = torch.float16 if device == "cuda" else torch.float32

    Path(output_dir).mkdir(parents=True, exist_ok=True)

    pretrained, token = _resolve_model_path_and_token()
    load_id = model_id or pretrained
    local_only = os.path.isdir(load_id)
    pipe = None
    last_error = None

    def _load(pid: str, local: bool) -> bool:
        # One load attempt; records the failure instead of raising so a
        # fallback can be tried before reporting the error.
        nonlocal pipe, last_error
        try:
            pipe = StableDiffusionPipeline.from_pretrained(
                pid,
                torch_dtype=dtype,
                safety_checker=None,  # skip NSFW checker: saves VRAM/time
                token=None if local else (token or True),
                local_files_only=local,
            )
            return True
        except Exception as err:
            last_error = err
            return False

    # Try the resolved model first; fall back to the Hub only when the first
    # attempt was itself a remote load (a broken local dir surfaces its error).
    if not _load(load_id, local_only) and not local_only:
        _load(FALLBACK_MODEL_ID, False)
    if pipe is None:
        raise RuntimeError(
            "Could not load Stable Diffusion. Need internet to download the model (first run).\n"
            " - Set HF_TOKEN=your_token if behind firewall (huggingface.co/settings/tokens)\n"
            " - Or download once: huggingface-cli download runwayml/stable-diffusion-v1-5 --local-dir ./weights/sd-v1-5"
        ) from last_error

    pipe = pipe.to(device)

    if device == "cuda":
        # Reset the counter so the peak reflects this generation only.
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()

    generator = None
    if seed is not None:
        generator = torch.Generator(device=device).manual_seed(seed)

    t0 = time.perf_counter()
    image = pipe(
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        generator=generator,
    ).images[0]

    if device == "cuda":
        # CUDA kernels are async; wait so the timing covers the full run.
        torch.cuda.synchronize()
    t1 = time.perf_counter()
    inference_time = t1 - t0
    peak_vram_mb = (
        torch.cuda.max_memory_allocated() / 1024 / 1024
        if device == "cuda"
        else 0.0
    )

    # Filesystem-safe filename from the prompt. NOTE: the same prompt always
    # maps to the same path, so a rerun overwrites the previous output.
    safe_name = "".join(c if c.isalnum() or c in " -_" else "_" for c in prompt)[:60]
    out_path = os.path.join(output_dir, f"skybox_{safe_name.strip()}.png")
    image.save(out_path)

    return out_path, inference_time, peak_vram_mb
| |
|