| """Image generation using SDXL + LoRA styles via fal.ai API. |
| |
| API counterpart to image_generator_hf.py (on-device diffusers). |
| Uses the fal-ai/lora endpoint which accepts HuggingFace LoRA repo IDs |
| directly, so styles.py works unchanged. |
| |
| Set FAL_KEY env var before use. |
| """ |
|
|
| import json |
| import time |
| from pathlib import Path |
| from typing import Optional |
|
|
| import requests |
| from dotenv import load_dotenv |
|
|
| from src.styles import get_style |
|
|
| load_dotenv() |
|
|
| |
| |
| |
|
|
# fal.ai endpoint that accepts HuggingFace LoRA repo IDs directly (see module docstring).
FAL_MODEL_ID = "fal-ai/lora"


# Base checkpoint the endpoint loads server-side before applying LoRAs.
BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"


# Portrait output resolution sent as image_size to the API.
WIDTH = 768
HEIGHT = 1344
NUM_STEPS = 30  # standard step count; no speed LoRA is used here (see _build_loras)
GUIDANCE_SCALE = 7.5  # classifier-free guidance scale passed to the endpoint
|
|
|
|
| def _build_loras(style: dict) -> list[dict]: |
| """Build the LoRA list for the fal.ai API from a style dict. |
| |
| Note: Hyper-SD speed LoRA is NOT used here (it's an on-device optimization |
| requiring specific scheduler config). fal.ai runs on fast GPUs so we use |
| standard settings (30 steps, DPM++ 2M Karras) instead. |
| """ |
| loras = [] |
|
|
| if style["source"] is not None: |
| |
| |
| loras.append({"path": style["source"], "scale": style["weight"]}) |
|
|
| return loras |
|
|
|
|
def _download_image(url: str, output_path: Path, retries: int = 3) -> Path:
    """Download an image from URL to a local file, retrying transient errors.

    Args:
        url: HTTP(S) URL of the image to fetch.
        output_path: Destination file path; parent directories are created.
        retries: Total number of attempts (must be >= 1).

    Returns:
        output_path on success.

    Raises:
        ValueError: If retries < 1.
        requests.exceptions.RequestException: If every attempt fails.
    """
    # Guard: the original silently returned None when retries < 1, violating
    # the declared -> Path contract.
    if retries < 1:
        raise ValueError("retries must be >= 1")
    output_path.parent.mkdir(parents=True, exist_ok=True)
    for attempt in range(retries):
        try:
            resp = requests.get(url, timeout=120)
            resp.raise_for_status()
            output_path.write_bytes(resp.content)
            return output_path
        except (
            requests.exceptions.SSLError,
            requests.exceptions.ConnectionError,
            # Timeouts are the classic transient failure; retry them too.
            requests.exceptions.Timeout,
        ) as e:
            if attempt < retries - 1:
                # Include the actual error (previously captured but unused).
                print(f" Download failed (attempt {attempt + 1}): {e}, retrying...")
                time.sleep(2 ** attempt)  # brief exponential backoff: 1s, 2s, ...
            else:
                raise
    raise AssertionError("unreachable")  # loop always returns or raises
|
|
|
|
def generate_image(
    prompt: str,
    negative_prompt: str = "",
    loras: list[dict] | None = None,
    seed: Optional[int] = None,
) -> dict:
    """Request a single image from the fal.ai endpoint.

    Args:
        prompt: SDXL prompt.
        negative_prompt: Negative prompt.
        loras: List of LoRA dicts with 'path' and 'scale'.
        seed: Random seed.

    Returns:
        API response dict with 'images' list and 'seed'.
    """
    # Imported lazily so the module can be imported without fal_client installed.
    import fal_client

    arguments = {
        "model_name": BASE_MODEL,
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "image_size": {"width": WIDTH, "height": HEIGHT},
        "num_inference_steps": NUM_STEPS,
        "guidance_scale": GUIDANCE_SCALE,
        "scheduler": "DPM++ 2M Karras",
        "num_images": 1,
        "image_format": "png",
        "enable_safety_checker": False,
    }
    # Only truthy LoRA lists are forwarded; a seed of None means "let the API pick".
    if loras:
        arguments["loras"] = loras
    if seed is not None:
        arguments["seed"] = seed

    return fal_client.subscribe(FAL_MODEL_ID, arguments=arguments)
|
|
|
|
def generate_all(
    segments: list[dict],
    output_dir: str | Path,
    style_name: str = "Warm Sunset",
    seed: int = 42,
    progress_callback=None,
) -> list[Path]:
    """Generate one image per segment through the fal.ai API.

    Existing files are skipped, so an interrupted run can be resumed.

    Args:
        segments: List of segment dicts (with 'prompt' and 'negative_prompt').
        output_dir: Directory to save images.
        style_name: Style from styles.py registry.
        seed: Base seed (incremented per segment).

    Returns:
        List of saved image paths.
    """
    style = get_style(style_name)
    lora_list = _build_loras(style)
    trigger = style["trigger"]
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    total = len(segments)
    saved: list[Path] = []
    for segment in segments:
        idx = segment["segment"]
        target = out_dir / f"segment_{idx:03d}.png"

        # Resume support: keep whatever is already on disk.
        if target.exists():
            print(f" Segment {idx}/{total}: already exists, skipping")
            saved.append(target)
            continue

        # Prefix the style trigger word when the style defines one.
        full_prompt = f"{trigger} style, {segment['prompt']}" if trigger else segment["prompt"]
        negative = segment.get("negative_prompt", "")

        print(f" Segment {idx}/{total}: generating image (fal.ai)...")
        started = time.time()
        result = generate_image(full_prompt, negative, loras=lora_list, seed=seed + idx)
        elapsed = time.time() - started  # measure API time only, not the download

        _download_image(result["images"][0]["url"], target)
        saved.append(target)
        print(f" Saved {target.name} ({elapsed:.1f}s)")
        if progress_callback:
            progress_callback(idx, total)

    return saved
|
|
|
|
def run(
    data_dir: str | Path,
    style_name: str = "Warm Sunset",
    seed: int = 42,
    progress_callback=None,
) -> list[Path]:
    """Full image generation pipeline: read segments, generate via API, save.

    Args:
        data_dir: Run directory containing segments.json.
        style_name: Style from the registry (see src/styles.py).
        seed: Base random seed.

    Returns:
        List of saved image paths.
    """
    base = Path(data_dir)

    # segments.json is produced by an earlier pipeline stage.
    segments = json.loads((base / "segments.json").read_text())

    paths = generate_all(segments, base / "images", style_name, seed, progress_callback)

    print(f"\nGenerated {len(paths)} images in {base / 'images'}")
    return paths
|
|
|
|
if __name__ == "__main__":
    import os
    import sys

    cli_args = sys.argv[1:]

    # No positional args at all -> usage help.
    if not cli_args:
        print("Usage: python -m src.image_generator_api <data_dir> [style_name]")
        print(' e.g. python -m src.image_generator_api data/Gone/run_001 "Warm Sunset"')
        print("\nRequires FAL_KEY environment variable.")
        sys.exit(1)

    # Fail fast with a helpful message before any API work is attempted.
    if not os.getenv("FAL_KEY"):
        print("Error: FAL_KEY environment variable not set.")
        print("Get your key at https://fal.ai/dashboard/keys")
        sys.exit(1)

    chosen_style = cli_args[1] if len(cli_args) > 1 else "Warm Sunset"
    run(cli_args[0], style_name=chosen_style)
|
|