Spaces:
Sleeping
Sleeping
| import random | |
| from typing import List, Optional | |
| from datasets import load_dataset | |
class PromptLoader:
    """
    A class for loading and sampling prompts from a Hugging Face dataset.

    The dataset is fetched lazily on the first call to `load_data()` and cached
    on the instance in `data`, so later calls reuse the same prompts.
    """

    def __init__(
        self,
        seed: int = 42,
        *,
        dataset_name: str = "daspartho/stable-diffusion-prompts",
        split: str = "train",
        column: str = "prompt",
    ) -> None:
        """
        Initializes the PromptLoader with a specified seed for random sampling.

        Args:
            seed (int): The seed value for the random number generator. Default is 42.
            dataset_name (str): Dataset identifier passed to `load_dataset`.
                Default is "daspartho/stable-diffusion-prompts".
            split (str): Split to read prompts from. Default is "train".
            column (str): Column holding the prompt strings. Default is "prompt".
        """
        # A private RNG makes sampling reproducible for a given seed without
        # touching the global `random` module state.
        self.randomizer = random.Random(seed)
        self.dataset_name = dataset_name
        self.split = split
        self.column = column
        # None means "not loaded yet"; filled in by `_get_data()`.
        self.data: Optional[List[str]] = None

    def _get_data(self) -> None:
        """
        Loads the dataset of prompts and stores them in the `data` attribute.

        Calls `load_dataset(self.dataset_name)` and extracts the `self.column`
        values from the `self.split` split.
        """
        dataset = load_dataset(self.dataset_name)
        # Materialize as a plain list: column access may return a lazy column
        # object rather than a list depending on the `datasets` version, and
        # `random.sample` needs a real sequence.
        self.data = list(dataset[self.split][self.column])

    def load_data(self, size: Optional[int] = None) -> List[str]:
        """
        Loads and samples prompts from the dataset.

        If the dataset has not been loaded yet, `_get_data()` is called first.

        Args:
            size (Optional[int]): The number of prompts to sample. If None, all
                loaded prompts are returned. A size of 0 returns an empty list.

        Returns:
            List[str]: A new list of prompts. If `size` is given, a random sample
            (without replacement) of that size; otherwise a copy of all loaded
            prompts.

        Raises:
            ValueError: If `size` is negative or greater than the number of
                available prompts.
        """
        # Test against None (not truthiness) so an already-loaded but empty
        # dataset is not re-downloaded on every call.
        if self.data is None:
            self._get_data()
        prompts = self.data  # `_get_data()` always assigns a list

        # `size is None` rather than `not size`: 0 is an explicit request for
        # zero prompts, not "return everything".
        if size is None:
            # Return a copy so callers cannot mutate the cached prompts.
            return list(prompts)
        if size < 0:
            raise ValueError("Sample size must be non-negative!")
        if size > len(prompts):
            raise ValueError("Not enough samples available!")
        return self.randomizer.sample(prompts, size)