Spaces:
Sleeping
Sleeping
import random | |
from typing import List, Optional | |
from datasets import load_dataset | |
class PromptLoader:
    """
    A class for loading and sampling prompts from a dataset.

    The dataset is loaded lazily on the first call to `load_data()` and cached
    on the instance in `data`; later calls reuse the cached prompts.
    """

    def __init__(
        self,
        seed: int = 42,
        *,
        dataset_name: str = "daspartho/stable-diffusion-prompts",
        split: str = "train",
        column: str = "prompt",
    ) -> None:
        """
        Initializes the PromptLoader with a specified seed for random sampling.

        Args:
            seed (int): The seed value for the random number generator. Default is 42.
            dataset_name (str): Dataset identifier passed to `load_dataset`.
                Default is "daspartho/stable-diffusion-prompts".
            split (str): Split to read prompts from. Default is "train".
            column (str): Column containing the prompt text. Default is "prompt".
        """
        # A per-instance Random keeps sampling reproducible for a given seed
        # without touching the global `random` module state.
        self.randomizer = random.Random(seed)
        self.dataset_name = dataset_name
        self.split = split
        self.column = column
        # None means "not loaded yet"; an empty list is a valid loaded state.
        self.data: Optional[List[str]] = None

    def _get_data(self) -> None:
        """
        Loads the dataset of prompts and stores them in the `data` attribute.

        Uses the `datasets` library to load `self.dataset_name` and extracts
        the `self.column` values from the `self.split` split.
        """
        dataset = load_dataset(self.dataset_name)[self.split]
        # Materialize into a plain list so `random.Random.sample` (which
        # requires a sequence) works and the `List[str]` annotation holds,
        # whatever column object the installed `datasets` version returns.
        self.data = list(dataset[self.column])

    def load_data(self, size: Optional[int] = None) -> List[str]:
        """
        Loads and samples prompts from the dataset.

        If the dataset has not been loaded yet, `_get_data()` is called first.
        Each call with a `size` draws a fresh sample from this instance's
        seeded random generator, so successive calls generally differ, but
        the sequence of samples is reproducible for a given seed.

        Args:
            size (Optional[int]): The number of prompts to sample. If None,
                all loaded prompts are returned. A size of 0 returns an
                empty list.

        Returns:
            List[str]: A new list of prompts. This is either a random sample
                of `size` distinct positions, or a copy of all loaded prompts
                when `size` is None. The internal cache is never returned
                directly, so callers may mutate the result safely.

        Raises:
            ValueError: If `size` is negative or greater than the number of
                available prompts.
        """
        # Compare against None so an empty-but-loaded dataset is not reloaded.
        if self.data is None:
            self._get_data()
        prompts = self.data

        # `size is None` (not falsiness) so that size=0 means "zero prompts".
        if size is None:
            return list(prompts)
        if size < 0:
            raise ValueError("size must be non-negative")
        if size > len(prompts):
            raise ValueError("Not enough samples available!")
        return self.randomizer.sample(prompts, size)