File size: 6,368 Bytes
562c833
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from __future__ import annotations

import datetime
import json
import time
import uuid
from collections import OrderedDict
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

import gradio
import gradio as gr
import huggingface_hub
from gradio import FlaggingCallback
from gradio_client import utils as client_utils


class HuggingFaceDatasetSaver(gradio.HuggingFaceDatasetSaver):
    """``gradio.HuggingFaceDatasetSaver`` variant that records who flagged each sample.

    Behavior implemented in this subclass:

    * ``_deserialize_components`` appends a ``username`` column after the
      ``flag`` column of every row.
    * In ``separate_dirs`` mode, ``flag`` names each sample's folder after the
      current UTC time plus ``_U_<username>`` when a username is given, or a
      short random suffix otherwise.
    """

    def flag(
        self,
        flag_data: list[Any],
        flag_option: str = "",
        username: str | None = None,
    ) -> int:
        """Store one flagged sample by delegating to ``self._flag_in_dir``.

        Args:
            flag_data: One value per component in ``self.components``.
            flag_option: Flag label chosen by the user; stored in the ``flag`` column.
            username: Username of the flagging user, or None/"" when unknown.

        Returns:
            The value returned by ``self._flag_in_dir``.
        """
        if self.separate_dirs:
            # JSONL files to support dataset preview on the Hub
            current_utc_time = datetime.now(timezone.utc)
            iso_format_without_microseconds = current_utc_time.strftime(
                "%Y-%m-%dT%H:%M:%S"
            )
            milliseconds = current_utc_time.microsecond // 1000
            unique_id = f"{iso_format_without_microseconds}.{milliseconds:03}Z"
            if username:
                unique_id += f"_U_{username}"
            else:
                # No username to disambiguate concurrent flags: add a random suffix.
                unique_id += f"_{str(uuid.uuid4())[:8]}"
            components_dir = self.dataset_dir / unique_id
            data_file = components_dir / "metadata.jsonl"
            path_in_repo = unique_id  # upload in sub folder (safer for concurrency)
        else:
            # Unique CSV file
            components_dir = self.dataset_dir
            data_file = components_dir / "data.csv"
            path_in_repo = None  # upload at root level

        return self._flag_in_dir(
            data_file=data_file,
            components_dir=components_dir,
            path_in_repo=path_in_repo,
            flag_data=flag_data,
            flag_option=flag_option,
            username=username or "",
        )

    def _deserialize_components(
        self,
        data_dir: Path,
        flag_data: list[Any],
        flag_option: str = "",
        username: str = "",
    ) -> tuple[dict[Any, Any], list[Any]]:
        """Deserialize components and return the corresponding row for the flagged sample.

        Each component's value is saved under ``data_dir/<label>/`` by
        ``component.flag``; images/audio are saved to disk as individual files.

        Columns produced, in order:

        * ``<label>`` for every component. It holds the saved path relative to
          ``self.dataset_dir`` when the value is an existing path inside it.
          Otherwise it holds the value as a string ("" for None).
        * ``<label> file`` right after it, for ``gr.Audio`` / ``gr.Image`` only.
          It holds the Hub URL of the saved file, or "" when there is no such file.
        * ``flag`` (``flag_option``) and ``username`` (``username``).

        Returns:
            ``(features, row)``: an ordered mapping of column name to Hub dataset
            feature spec, and the matching row values.
        """
        # Components that can have a preview on dataset repos
        file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}

        features = OrderedDict()
        row = []
        for component, sample in zip(self.components, flag_data):
            # Get deserialized object (will save sample to disk if applicable -file, audio, image,...-)
            label = component.label or ""
            save_dir = data_dir / client_utils.strip_invalid_filename_characters(label)
            save_dir.mkdir(exist_ok=True, parents=True)
            deserialized = component.flag(sample, save_dir)

            # Base component .flag method returns JSON; extract path from it when it is FileData
            if component.data_model:
                data = component.data_model.from_json(json.loads(deserialized))
                if component.data_model == gr.data_classes.FileData:
                    deserialized = data.path

            # Add deserialized object to row. Compute the dataset-relative path once.
            # The preview URL below reuses it, so a value that is not a saved file
            # cannot crash the URL computation.
            features[label] = {"dtype": "string", "_type": "Value"}
            relative_path = self._relative_dataset_path(deserialized)
            if relative_path is not None:
                row.append(str(relative_path))
            else:
                row.append("" if deserialized is None else str(deserialized))

            # If component is eligible for a preview, add the URL of the file.
            # Images and audio can be None, and only a file saved inside the
            # dataset folder can be linked.
            preview_type = next(
                (
                    _type
                    for _component, _type in file_preview_types.items()
                    if isinstance(component, _component)
                ),
                None,
            )
            if preview_type is not None:
                features[label + " file"] = {"_type": preview_type}
                if relative_path is not None:
                    row.append(
                        huggingface_hub.hf_hub_url(
                            repo_id=self.dataset_id,
                            # URL paths need forward slashes, also on Windows
                            filename=relative_path.as_posix(),
                            repo_type="dataset",
                        )
                    )
                else:
                    row.append("")
        features["flag"] = {"dtype": "string", "_type": "Value"}
        features["username"] = {"dtype": "string", "_type": "Value"}
        row.append(flag_option)
        row.append(username)
        return features, row

    def _relative_dataset_path(self, value: Any) -> Path | None:
        """Return ``value`` relative to ``self.dataset_dir`` if it names an existing path there.

        Returns None when any of these hold:

        * ``value`` is not path-like (e.g. None).
        * ``value`` does not exist on disk.
        * ``value`` cannot be stat-ed (e.g. a JSON payload too long to be a file name).
        * ``value`` lies outside ``self.dataset_dir``.
        """
        try:
            path = Path(value)
            if not path.exists():
                return None
            return path.relative_to(self.dataset_dir)
        except (OSError, TypeError, ValueError):
            # TypeError: not path-like; OSError: stat() failure other than "missing";
            # ValueError: embedded NUL byte, or path not under dataset_dir.
            return None


class FlagMethod:
    """Callable event handler that sends the current inputs to a flagging callback.

    Each call forwards the flagged values to ``flagging_callback.flag``. It passes
    this handler's ``value`` as the flag option and the logged-in user's
    username, if any. When ``visual_feedback`` is enabled, the call returns a
    disabled button whose text tells the user whether sharing succeeded.
    """

    def __init__(
        self,
        flagging_callback: FlaggingCallback,
        label: str,
        value: str,
        visual_feedback: bool = True,
    ):
        # Callback that actually stores the flagged sample.
        self.flagging_callback = flagging_callback
        # Display label and the flag option forwarded to the callback.
        self.label = label
        self.value = value
        # Whether __call__ returns an updated button for the user to see.
        self.visual_feedback = visual_feedback
        # Function-style name for this callable instance (presumably read by
        # Gradio when the instance is registered as a handler — confirm).
        self.__name__ = "Flag"

    def __call__(
        self,
        request: gr.Request,
        profile: gr.OAuthProfile | None,
        *flag_data,
    ):
        """Flag ``flag_data`` and optionally report the outcome.

        Args:
            request: Not read here; presumably declared so Gradio injects it.
            profile: Profile of the logged-in user, or None when not logged in.
            *flag_data: Values of the flagged components, in order.

        Returns:
            None when ``visual_feedback`` is off. Otherwise a non-interactive
            ``gr.Button`` reading "Sharing complete" on success or
            "Sharing error" on failure.
        """
        username = profile.username if profile is not None else None
        try:
            self.flagging_callback.flag(
                list(flag_data), flag_option=self.value, username=username
            )
        except Exception as err:
            # UI boundary: report the failure instead of crashing the event handler.
            print(f"Error while sharing: {err}")
            if self.visual_feedback:
                return gr.Button(value="Sharing error", interactive=False)
            return None
        if self.visual_feedback:
            # Pause so the user has time to notice the button change.
            time.sleep(0.8)
            return gr.Button(value="Sharing complete", interactive=False)
        return None