Spaces:
Runtime error
Runtime error
File size: 1,463 Bytes
e4f9cbe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
"""Loads reddit data from Huggingface."""
from typing import Iterable, Optional
from pydantic import Field as PydanticField
from typing_extensions import override
from ...schema import Item
from .huggingface_source import HuggingFaceDataset
from .source import Source, SourceSchema
# Dataset name handed to HuggingFaceDataset when the reddit source is set up.
HF_REDDIT_DATASET_NAME = 'reddit'
# Key of the item field read to decide which subreddit an item belongs to.
HF_SUBREDDIT_COL = 'subreddit'
class RedditDataset(Source):
    """Reddit data loader, using Huggingface.

    Loads data from [huggingface.co/datasets/reddit](https://huggingface.co/datasets/reddit).
    """ # noqa: D415, D400
    name = 'reddit'

    # Optional filter; matching against the item's subreddit is
    # case-insensitive (both sides are lower-cased in `process`).
    subreddits: Optional[list[str]] = PydanticField(
        required=False,
        description='If defined, only loads the subset of reddit data in these subreddit.',
    )

    # Underlying loader that all work is delegated to; created in `setup`.
    _hf_dataset: HuggingFaceDataset

    @override
    def setup(self) -> None:
        """Create the wrapped HuggingFaceDataset for the reddit dataset and set it up."""
        self._hf_dataset = HuggingFaceDataset(dataset_name=HF_REDDIT_DATASET_NAME)
        self._hf_dataset.setup()

    @override
    def source_schema(self) -> SourceSchema:
        """Return the schema reported by the wrapped HuggingFaceDataset."""
        return self._hf_dataset.source_schema()

    @override
    def process(self) -> Iterable[Item]:
        """Yield items from the wrapped dataset, optionally filtered by subreddit.

        When `subreddits` is unset or empty, every item produced by the wrapped
        dataset is yielded unchanged. Otherwise, items whose subreddit (compared
        case-insensitively) is not in `subreddits` are replaced by `None`, so the
        number of yielded values still matches the number of source items.
        """
        items = self._hf_dataset.process()

        if not self.subreddits:
            # This method is a generator (it contains `yield`), so a plain
            # `return items` would finish immediately without emitting a single
            # item. Delegate to the underlying iterable instead.
            yield from items
            return

        # Build the lookup once; a set gives O(1) membership tests per item.
        lower_subreddits = {subreddit.lower() for subreddit in self.subreddits}

        for item in items:
            item_subreddit = item[HF_SUBREDDIT_COL]
            if item_subreddit.lower() not in lower_subreddits:
                # Yield None so that the progress bar is accurate.
                yield None
                continue
            yield item
|