import random
from typing import List, Optional

from datasets import load_dataset


class PromptLoader:
    """Load prompts from a dataset and draw reproducible random samples.

    The prompts are fetched lazily on the first call to `load_data()` and
    cached on the instance in the `data` attribute. Sampling uses a private
    `random.Random` instance seeded at construction time, so it does not
    touch the global `random` module state.

    Subclasses may override the class-level constants below to read from a
    different dataset, split, or column.
    """

    # Identifier passed to `load_dataset`, the split read from the result,
    # and the column extracted from that split.
    DATASET_NAME = "daspartho/stable-diffusion-prompts"
    DATASET_SPLIT = "train"
    PROMPT_COLUMN = "prompt"

    def __init__(self, seed: int = 42) -> None:
        """Create a loader with its own seeded random number generator.

        Args:
            seed (int): Seed for the instance's random number generator.
                Default is 42.
        """
        self.randomizer = random.Random(seed)
        # None means "not loaded yet"; an empty list is a valid loaded state.
        self.data: Optional[List[str]] = None

    def _get_data(self) -> None:
        """Load the prompt column of the dataset into the `data` attribute.

        Uses `load_dataset` from the `datasets` library and reads the
        `PROMPT_COLUMN` column of the `DATASET_SPLIT` split.
        """
        dataset = load_dataset(self.DATASET_NAME)
        # Materialize as a plain list so `data` matches its annotation and
        # works with `len()` / `random.sample()` whatever sequence-like type
        # the column accessor returns.
        self.data = list(dataset[self.DATASET_SPLIT][self.PROMPT_COLUMN])

    def load_data(self, size: Optional[int] = None) -> List[str]:
        """Return all prompts, or a random sample of them.

        If the prompts are not loaded yet, `_get_data()` is called first.

        Args:
            size (Optional[int]): Number of prompts to sample without
                replacement. If None (the default), all loaded prompts are
                returned. A size of 0 returns an empty list.

        Returns:
            List[str]: A new list containing either every loaded prompt (in
            dataset order) or `size` randomly sampled prompts. The returned
            list is independent of the internal cache, so callers may mutate
            it freely.

        Raises:
            ValueError: If `size` is negative or greater than the number of
                available prompts.
        """
        # Compare against None rather than truthiness: an empty dataset is
        # still "loaded" and must not trigger a reload on every call.
        if self.data is None:
            self._get_data()
        prompts = self.data

        # Only None means "not specified"; 0 is a legitimate sample size.
        if size is None:
            # Hand out a copy so callers cannot corrupt the cached prompts.
            return list(prompts)

        if size < 0:
            raise ValueError("Sample size must be non-negative!")
        if size > len(prompts):
            raise ValueError("Not enough samples available!")
        return self.randomizer.sample(prompts, size)