| """ |
| Prompt Sampler Module |
| |
| Provides stratified sampling of prompts for controlled experiments. |
| Ensures diversity across semantic categories while maintaining balance. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import random |
| from dataclasses import dataclass, field |
| from pathlib import Path |
| from typing import List, Dict, Any, Optional |
| from collections import defaultdict |
|
|
|
|
@dataclass
class PromptCategory:
    """One semantic bucket used to stratify prompts.

    Attributes:
        name: Identifier of the category; this is the value stored in
            ``SampledPrompt.category`` for prompts assigned to it.
        description: Human-readable summary of what the category covers.
        keywords: Terms associated with the category. ``PromptSampler``
            lower-cases them and uses them for keyword matching.
        examples: Ready-made prompt texts belonging to the category.
    """

    name: str
    description: str
    # default_factory gives every instance its own list instead of a shared one.
    keywords: List[str] = field(default_factory=list)
    examples: List[str] = field(default_factory=list)
|
|
|
|
| |
# Built-in categories used by PromptSampler when the caller supplies none.
# The order of the categories and of the keywords inside each one is kept
# stable because the sampler iterates them in this order.
DEFAULT_CATEGORIES: List[PromptCategory] = [
    # Outdoor / natural-world scenes.
    PromptCategory(
        name="nature",
        description="Natural environments, landscapes, wildlife",
        keywords=[
            "forest",
            "ocean",
            "mountain",
            "river",
            "wildlife",
            "sunset",
            "rain",
            "storm",
            "garden",
            "meadow",
        ],
        examples=[
            "A calm foggy forest at dawn with distant birds and soft wind",
            "Ocean waves crashing against rocky cliffs under a stormy sky",
            "A peaceful meadow with wildflowers swaying in gentle breeze",
        ],
    ),
    # City and street scenes.
    PromptCategory(
        name="urban",
        description="City scenes, street life, architecture",
        keywords=[
            "city",
            "street",
            "building",
            "traffic",
            "neon",
            "market",
            "crowd",
            "subway",
            "cafe",
            "downtown",
        ],
        examples=[
            "A rainy neon-lit city street at night with reflections on wet pavement",
            "Busy morning market with vendors and the sounds of commerce",
            "Empty subway platform with echoing announcements",
        ],
    ),
    # Moods, emotions and non-literal imagery.
    PromptCategory(
        name="abstract",
        description="Abstract concepts, emotions, moods",
        keywords=[
            "peaceful",
            "chaotic",
            "melancholy",
            "joyful",
            "mysterious",
            "ethereal",
            "surreal",
            "dreamy",
        ],
        examples=[
            "An ethereal dreamscape with floating lights and distant echoes",
            "The feeling of nostalgia on a quiet autumn afternoon",
            "A surreal landscape where gravity seems optional",
        ],
    ),
    # Scenes dominated by motion or activity.
    PromptCategory(
        name="action",
        description="Dynamic scenes with movement and activity",
        keywords=[
            "running",
            "flying",
            "dancing",
            "fighting",
            "racing",
            "jumping",
            "working",
            "playing",
        ],
        examples=[
            "A horse galloping across an open prairie",
            "Children playing in a park with laughter and shouts",
            "Fireworks exploding over a celebration crowd",
        ],
    ),
    # Indoor, everyday-life scenes.
    PromptCategory(
        name="domestic",
        description="Indoor scenes, everyday life, home settings",
        keywords=[
            "kitchen",
            "bedroom",
            "office",
            "library",
            "home",
            "cozy",
            "fireplace",
            "window",
        ],
        examples=[
            "A cozy reading nook with rain pattering on the window",
            "Morning kitchen scene with sizzling breakfast sounds",
            "A quiet home office with soft keyboard clicks",
        ],
    ),
]
|
|
|
|
@dataclass
class SampledPrompt:
    """A single prompt together with its category label and provenance.

    Attributes:
        text: The prompt text.
        category: Name of the category the prompt was assigned to.
        source: Where the prompt came from; defaults to ``"generated"``.
        metadata: Free-form extra information; a fresh dict per instance.
    """

    text: str
    category: str
    source: str = "generated"
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return the prompt as a plain dict keyed by field name.

        The ``metadata`` value is the instance's own dict, not a copy.
        """
        field_names = ("text", "category", "source", "metadata")
        return {name: getattr(self, name) for name in field_names}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SampledPrompt":
        """Build a prompt from a dict whose keys are constructor arguments.

        Raises:
            TypeError: If required keys are missing or unknown keys are present
                (propagated from the dataclass constructor).
        """
        return cls(**data)
|
|
|
|
class PromptSampler:
    """
    Stratified prompt sampler for controlled experiments.

    Groups prompts by semantic category (via keyword matching) and draws a
    bounded number from each category so that no single category dominates.
    Prompts can come from the categories' built-in examples, from
    caller-supplied custom prompts, or from JSON / plain-text files.

    All shuffling goes through ``self.rng`` (a ``random.Random`` seeded in
    ``__init__``), so a given seed and call sequence is reproducible.
    """

    def __init__(
        self,
        categories: Optional[List["PromptCategory"]] = None,
        seed: int = 42,
    ):
        """
        Initialize sampler.

        Args:
            categories: List of semantic categories. ``None`` or an empty
                list falls back to ``DEFAULT_CATEGORIES``.
            seed: Random seed for reproducibility
        """
        self.categories = categories or DEFAULT_CATEGORIES
        self.category_names = [c.name for c in self.categories]
        self.rng = random.Random(seed)

        # Lower-cased keyword -> category name. If two categories share a
        # keyword, the category listed later wins.
        self._keyword_to_category: Dict[str, str] = {}
        for cat in self.categories:
            for keyword in cat.keywords:
                self._keyword_to_category[keyword.lower()] = cat.name

    def categorize_prompt(self, prompt: str) -> str:
        """
        Assign a category to a prompt based on keyword matching.

        Matching is a case-insensitive *substring* test, so a keyword such as
        "rain" also matches inside "training". Each matching keyword adds one
        point to its category and the highest-scoring category wins; on a tie
        the category that scored first (in keyword-map order) is returned.

        Args:
            prompt: The prompt text

        Returns:
            Category name (or "other" if no match)
        """
        prompt_lower = prompt.lower()

        scores: Dict[str, int] = defaultdict(int)
        for keyword, category in self._keyword_to_category.items():
            if keyword in prompt_lower:
                scores[category] += 1

        if not scores:
            return "other"
        return max(scores, key=scores.__getitem__)

    @staticmethod
    def _per_category_quota(
        n_total: int,
        n_categories: int,
        override: Optional[int],
    ) -> int:
        """Return how many prompts to draw per category.

        A truthy ``override`` is used as-is. Otherwise ``n_total`` is divided
        by ``n_categories`` rounding *up*, so that a total which is not a
        multiple of the category count is not silently rounded down (to zero
        when ``n_total < n_categories``).
        """
        if override:
            return override
        return (n_total + n_categories - 1) // n_categories

    def sample_stratified(
        self,
        n_total: int,
        prompts_per_category: Optional[int] = None,
        custom_prompts: Optional[List[str]] = None,
        include_examples: bool = True,
    ) -> List["SampledPrompt"]:
        """
        Sample prompts with stratified distribution across categories.

        Args:
            n_total: Total number of prompts to sample (upper bound)
            prompts_per_category: Override for prompts per category;
                ``None`` or ``0`` means "derive from ``n_total``"
            custom_prompts: Additional prompts to include (will be categorized)
            include_examples: Whether to include category example prompts

        Returns:
            List of SampledPrompt objects in shuffled order. It holds at most
            ``n_total`` items and fewer when the available pool is smaller.

        Raises:
            ValueError: If ``n_total`` or ``prompts_per_category`` is negative.
        """
        if n_total < 0:
            raise ValueError(f"n_total must be non-negative, got {n_total}")
        if prompts_per_category is not None and prompts_per_category < 0:
            raise ValueError(
                f"prompts_per_category must be non-negative, got {prompts_per_category}"
            )

        per_category = self._per_category_quota(
            n_total, len(self.categories), prompts_per_category
        )

        # Bucket the custom prompts by their keyword-derived category.
        custom_by_category: Dict[str, List[str]] = defaultdict(list)
        for prompt in custom_prompts or []:
            custom_by_category[self.categorize_prompt(prompt)].append(prompt)

        sampled: List["SampledPrompt"] = []

        for cat in self.categories:
            # Carry the source with each text so that a custom prompt whose
            # text equals an example is still labelled "custom".
            pool = []
            if include_examples:
                pool.extend((text, "example") for text in cat.examples)
            pool.extend(
                (text, "custom") for text in custom_by_category.get(cat.name, [])
            )

            self.rng.shuffle(pool)
            for text, source in pool[:per_category]:
                sampled.append(SampledPrompt(
                    text=text,
                    category=cat.name,
                    source=source,
                ))

        # Custom prompts that matched no keyword. Skipped when a real category
        # is itself named "other", because that bucket was already sampled in
        # the loop above and would otherwise be added twice.
        has_other_category = any(cat.name == "other" for cat in self.categories)
        other_prompts = custom_by_category.get("other", [])
        if other_prompts and not has_other_category:
            self.rng.shuffle(other_prompts)
            for text in other_prompts[:per_category]:
                sampled.append(SampledPrompt(
                    text=text,
                    category="other",
                    source="custom",
                ))

        # Mix the categories together, then enforce the requested total.
        self.rng.shuffle(sampled)
        return sampled[:n_total]

    @staticmethod
    def _read_raw_prompts(path: Path) -> List[Any]:
        """Read the raw, uncategorized prompt entries from ``path``.

        A ``.json`` suffix (any letter case) is parsed as JSON and must be a
        list or an object whose "prompts" value is a list. Any other suffix is
        read as plain text, one prompt per non-blank line.

        Raises:
            ValueError: If the JSON content has neither accepted shape.
        """
        if path.suffix.lower() != ".json":
            with path.open("r", encoding="utf-8") as f:
                return [line.strip() for line in f if line.strip()]

        with path.open("r", encoding="utf-8") as f:
            data = json.load(f)

        if isinstance(data, dict) and "prompts" in data:
            data = data["prompts"]
        if not isinstance(data, list):
            raise ValueError(f"Unexpected JSON structure in {path}")
        return data

    def _parse_entry(self, entry: Any) -> Optional[tuple]:
        """Turn one raw entry into ``(text, category)``, or ``None`` to skip it.

        Dict entries take their text from "text" (falling back to "prompt")
        and may carry an explicit "category"; any other entry is stringified.
        ``None`` entries and entries with empty or whitespace-only text are
        skipped, matching how blank lines are dropped from plain-text files.
        """
        if entry is None:
            return None

        if isinstance(entry, dict):
            text = str(entry.get("text") or entry.get("prompt") or "")
            explicit_category = entry.get("category")
        else:
            text = str(entry)
            explicit_category = None

        if not text.strip():
            return None
        return text, explicit_category or self.categorize_prompt(text)

    def load_from_file(
        self,
        path: Path,
        n_samples: Optional[int] = None,
    ) -> List["SampledPrompt"]:
        """
        Load and categorize prompts from a file.

        Supports JSON (list or object with "prompts" key) and plain text (one per line).
        JSON entries may be strings or objects with a "text" / "prompt" key
        and an optional "category" key.

        Args:
            path: Path to prompts file
            n_samples: Optional limit on number of prompts; ``None`` means no
                limit and ``0`` yields an empty list

        Returns:
            List of SampledPrompt objects in shuffled order, each with
            ``source`` set to the file name

        Raises:
            ValueError: If ``n_samples`` is negative or the JSON structure is
                not one of the supported shapes.
        """
        if n_samples is not None and n_samples < 0:
            raise ValueError(f"n_samples must be non-negative, got {n_samples}")

        path = Path(path)

        sampled: List["SampledPrompt"] = []
        for entry in self._read_raw_prompts(path):
            parsed = self._parse_entry(entry)
            if parsed is None:
                continue
            text, category = parsed
            sampled.append(SampledPrompt(
                text=text,
                category=category,
                source=str(path.name),
            ))

        # Shuffle before truncating so the kept subset is a random sample.
        self.rng.shuffle(sampled)
        if n_samples is not None:
            sampled = sampled[:n_samples]

        return sampled

    def load_from_laion(
        self,
        laion_dir: Path = Path("data/laion"),
        n_samples: int = 50,
    ) -> List["SampledPrompt"]:
        """
        Load prompts from LAION subset.

        Looks for "prompts_500.json", then "prompts.json", then
        "captions.json" inside ``laion_dir`` and loads the first that exists.

        Args:
            laion_dir: Path to LAION data directory (``str`` is accepted too)
            n_samples: Number of prompts to sample

        Returns:
            List of SampledPrompt objects

        Raises:
            FileNotFoundError: If none of the candidate files exists.
        """
        laion_dir = Path(laion_dir)

        for filename in ("prompts_500.json", "prompts.json", "captions.json"):
            prompts_file = laion_dir / filename
            if prompts_file.exists():
                return self.load_from_file(prompts_file, n_samples)

        raise FileNotFoundError(f"No prompts file found in {laion_dir}")

    def load_from_audiocaps(
        self,
        audiocaps_dir: Path = Path("data/audiocaps"),
        n_samples: int = 50,
    ) -> List["SampledPrompt"]:
        """
        Load prompts from AudioCaps dataset.

        Looks for "captions.json", then "audiocaps_subset.json", then
        "prompts.json" inside ``audiocaps_dir`` and loads the first that exists.

        Args:
            audiocaps_dir: Path to AudioCaps data directory (``str`` is
                accepted too)
            n_samples: Number of prompts to sample

        Returns:
            List of SampledPrompt objects

        Raises:
            FileNotFoundError: If none of the candidate files exists.
        """
        audiocaps_dir = Path(audiocaps_dir)

        for filename in ("captions.json", "audiocaps_subset.json", "prompts.json"):
            captions_file = audiocaps_dir / filename
            if captions_file.exists():
                return self.load_from_file(captions_file, n_samples)

        raise FileNotFoundError(f"No captions file found in {audiocaps_dir}")

    def get_category_distribution(
        self,
        prompts: List["SampledPrompt"],
    ) -> Dict[str, int]:
        """
        Get distribution of prompts across categories.

        Args:
            prompts: List of sampled prompts

        Returns:
            Dictionary mapping category names to counts; categories with no
            prompts are absent
        """
        distribution: Dict[str, int] = defaultdict(int)
        for prompt in prompts:
            distribution[prompt.category] += 1
        # Hand back a plain dict so callers do not create keys by lookup.
        return dict(distribution)

    def save_sample_set(
        self,
        prompts: List["SampledPrompt"],
        path: Path,
    ) -> None:
        """
        Save a sample set to JSON file.

        The file holds "n_prompts", the per-category "distribution" and the
        serialized "prompts". Missing parent directories are created.

        Args:
            prompts: List of sampled prompts
            path: Output path
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        data = {
            "n_prompts": len(prompts),
            "distribution": self.get_category_distribution(prompts),
            "prompts": [p.to_dict() for p in prompts],
        }

        with path.open("w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

    @classmethod
    def load_sample_set(cls, path: Path) -> List["SampledPrompt"]:
        """
        Load a previously saved sample set.

        Only the "prompts" entry of the file is read.

        Args:
            path: Path to JSON file

        Returns:
            List of SampledPrompt objects
        """
        with Path(path).open("r", encoding="utf-8") as f:
            data = json.load(f)

        return [SampledPrompt.from_dict(p) for p in data["prompts"]]
|
|
|
|
def create_experiment_prompts(
    n_prompts: int = 50,
    output_path: Optional[Path] = None,
    seed: int = 42,
    custom_prompts: Optional[List[str]] = None,
) -> List[SampledPrompt]:
    """
    Build a stratified prompt set in one call, optionally persisting it.

    Creates a ``PromptSampler`` with the default categories and the given
    seed, then draws prompts with ``sample_stratified`` (category examples
    included).

    Args:
        n_prompts: Total number of prompts requested
        output_path: Where to save the prompt set as JSON; nothing is
            written when this is ``None`` (or otherwise falsy)
        seed: Random seed passed to the sampler
        custom_prompts: Optional custom prompts to mix in

    Returns:
        List of SampledPrompt objects produced by the sampler; it may hold
        fewer than ``n_prompts`` items if the sampler has fewer available
    """
    sampler = PromptSampler(seed=seed)
    selection = sampler.sample_stratified(
        n_total=n_prompts,
        custom_prompts=custom_prompts,
        include_examples=True,
    )

    # Nothing to persist: hand the selection straight back.
    if not output_path:
        return selection

    sampler.save_sample_set(selection, output_path)
    return selection
|
|