Spaces:
Sleeping
Sleeping
File size: 1,835 Bytes
268c7f9 3556e6f 268c7f9 1cd5053 268c7f9 1cd5053 268c7f9 da82b2b 1cd5053 268c7f9 1cd5053 da82b2b 1cd5053 268c7f9 da82b2b 268c7f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import random
from typing import List, Optional
from datasets import load_dataset
class PromptLoader:
    """
    A class for loading and sampling prompts from a dataset.

    The prompts are fetched lazily on the first call to `load_data()` and then
    cached on the instance in the `data` attribute.
    """

    def __init__(self, seed: int = 42) -> None:
        """
        Initializes the PromptLoader with a specified seed for random sampling.

        Args:
            seed (int): The seed value for the random number generator. Default is 42.
        """
        # A dedicated Random instance makes sampling reproducible for a given
        # seed without touching the global `random` module state.
        self.randomizer = random.Random(seed)
        # None means "not loaded yet"; an empty list is a valid loaded state.
        self.data: Optional[List[str]] = None

    def _get_data(self) -> None:
        """
        Loads the dataset of prompts and stores them in the `data` attribute.

        This method uses the `datasets` library to load the dataset and extract
        the "prompt" column from the "train" split.
        """
        prompts = load_dataset("daspartho/stable-diffusion-prompts")["train"][
            "prompt"
        ]
        # Materialise the column as a plain list so it matches the declared
        # List[str] type and is a real sequence for `random.Random.sample`.
        self.data = list(prompts)

    def load_data(self, size: Optional[int] = None) -> List[str]:
        """
        Loads and samples prompts from the dataset.

        If the dataset is not already loaded, it calls `_get_data()` to load it.

        Args:
            size (Optional[int]): The number of prompts to sample. If None
                (the default), all loaded prompts are returned. A size of 0
                yields an empty list.

        Returns:
            List[str]: If `size` is given, a random sample (without
                replacement) of that many prompts. If `size` is None, a copy
                of all loaded prompts.

        Raises:
            ValueError: If `size` is negative or greater than the number of
                available prompts.
        """
        # Compare against None explicitly: an empty (but loaded) dataset must
        # not trigger a reload on every call.
        if self.data is None:
            self._get_data()

        # `size is None` (not falsiness) so that size=0 means "zero prompts"
        # rather than silently returning the whole dataset.
        if size is None:
            # Return a copy so callers cannot mutate the cached prompts.
            return list(self.data)

        if size < 0:
            raise ValueError("Sample size must be non-negative!")
        if size > len(self.data):
            raise ValueError("Not enough samples available!")
        return self.randomizer.sample(self.data, size)
|