"""Loads reddit data from Huggingface."""
from typing import Iterable, Optional

from pydantic import Field as PydanticField
from typing_extensions import override

from ...schema import Item
from .huggingface_source import HuggingFaceDataset
from .source import Source, SourceSchema

HF_REDDIT_DATASET_NAME = 'reddit'
HF_SUBREDDIT_COL = 'subreddit'


class RedditDataset(Source):
  """Reddit data loader, using Huggingface.

  Loads data from [huggingface.co/datasets/reddit](https://huggingface.co/datasets/reddit).
  """ # noqa: D415, D400
  name = 'reddit'

  subreddits: Optional[list[str]] = PydanticField(
    default=None,
    description='If defined, only loads the subset of reddit data in these subreddits.',
  )

  _hf_dataset: HuggingFaceDataset

  @override
  def setup(self) -> None:
    self._hf_dataset = HuggingFaceDataset(dataset_name=HF_REDDIT_DATASET_NAME)
    self._hf_dataset.setup()

  @override
  def source_schema(self) -> SourceSchema:
    return self._hf_dataset.source_schema()

  @override
  def process(self) -> Iterable[Item]:
    items = self._hf_dataset.process()

    if not self.subreddits:
      # No subreddit filter was given, so pass every item straight through.
      yield from items
      return

    lower_subreddits = [subreddit.lower() for subreddit in self.subreddits]

    for item in items:
      item_subreddit = item[HF_SUBREDDIT_COL]
      if item_subreddit.lower() not in lower_subreddits:
        # Yield None so that the progress bar is accurate.
        yield None
        continue
      yield item
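
# Usage sketch (not part of this module's original code; assumes the Source API
# above, where `setup()` is called once before `process()` and filtered-out rows
# are yielded as None so progress reporting stays accurate). The subreddit names
# are illustrative:
#
#   source = RedditDataset(subreddits=['askscience', 'MachineLearning'])
#   source.setup()
#   for item in source.process():
#     if item is None:
#       continue
#     print(item[HF_SUBREDDIT_COL])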