File size: 1,835 Bytes
268c7f9
3556e6f
 
 
268c7f9
 
 
1cd5053
 
 
 
268c7f9
1cd5053
 
 
 
 
 
268c7f9
 
 
da82b2b
1cd5053
 
 
 
 
268c7f9
 
 
 
 
1cd5053
 
 
da82b2b
1cd5053
 
 
 
 
 
 
 
 
 
 
268c7f9
da82b2b
268c7f9
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import random
from typing import List, Optional

from datasets import load_dataset


class PromptLoader:
    """
    Lazily loads a prompt dataset and returns all of it or a seeded random sample.

    The prompts are loaded on the first `load_data()` call and cached on the
    instance in `data`; later calls reuse the cached list.
    """

    # Arguments used with `datasets.load_dataset`. They are class attributes so
    # a subclass can point the loader at a different dataset/split/column.
    DATASET_NAME: str = "daspartho/stable-diffusion-prompts"
    DATASET_SPLIT: str = "train"
    PROMPT_COLUMN: str = "prompt"

    def __init__(self, seed: int = 42) -> None:
        """
        Initializes the PromptLoader with a specified seed for random sampling.

        Args:
        seed (int): The seed value for the random number generator. Default is 42.
        """
        # A private `random.Random` instance keeps sampling reproducible and
        # independent of the global `random` module state.
        self.randomizer = random.Random(seed)
        # `None` means "not loaded yet"; an empty list is a valid loaded state.
        self.data: Optional[List[str]] = None

    def _get_data(self) -> None:
        """
        Loads the dataset of prompts and stores them in the `data` attribute.

        This method uses the `datasets` library to load `DATASET_NAME` and extracts
        the `PROMPT_COLUMN` column from the `DATASET_SPLIT` split.
        """
        dataset = load_dataset(self.DATASET_NAME)
        # Materialize into a plain list so `data` really is a `List[str]`
        # (column access may return a lazy column object rather than a list)
        # and so `random.Random.sample` always receives a true sequence.
        self.data = list(dataset[self.DATASET_SPLIT][self.PROMPT_COLUMN])

    def load_data(self, size: Optional[int] = None) -> List[str]:
        """
        Loads and samples prompts from the dataset.

        If the dataset is not already loaded, it calls `_get_data()` to load it.

        Args:
        size (Optional[int]): The number of prompts to sample. If `None` (the default),
            all loaded prompts are returned. `0` is a valid size and yields an empty list.

        Returns:
        List[str]: A new list on every call. If `size` is given, a random sample of that
            size drawn without replacement; otherwise a copy of all loaded prompts.

        Raises:
        ValueError: If `size` is negative or greater than the number of available prompts.
        """
        # Compare against None: an empty (but loaded) dataset must not trigger a reload.
        if self.data is None:
            self._get_data()
        prompts = self.data

        # Only `None` means "no sampling"; `size=0` must not fall through to "return all".
        if size is None:
            # Hand out a copy so callers cannot mutate the cached prompts.
            return list(prompts)

        if size < 0:
            raise ValueError("Sample size must be non-negative!")
        if size > len(prompts):
            raise ValueError("Not enough samples available!")
        return self.randomizer.sample(prompts, size)