# Copied from a Hugging Face Space file view (Space status: "Running on Zero").
# Original file: 6,368 bytes, 164 lines, commit 562c833.
from __future__ import annotations
import datetime
import json
import time
import uuid
from collections import OrderedDict
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import gradio
import gradio as gr
import huggingface_hub
from gradio import FlaggingCallback
from gradio_client import utils as client_utils
class HuggingFaceDatasetSaver(gradio.HuggingFaceDatasetSaver):
    """Hub flagging callback that records the flagging user's name.

    Overrides ``flag`` and ``_deserialize_components`` of
    ``gradio.HuggingFaceDatasetSaver``. Writing and uploading the data is
    still delegated to the inherited ``_flag_in_dir``.
    """

    def flag(
        self,
        flag_data: list[Any],
        flag_option: str = "",
        username: str | None = None,
    ) -> int:
        """Save one flagged sample locally and hand it over for upload.

        Args:
            flag_data: One value per component in ``self.components``.
            flag_option: Value associated with the flag button that was used.
            username: Name of the flagging user; ``None`` or ``""`` when anonymous.

        Returns:
            The value returned by the inherited ``_flag_in_dir``.
        """
        if self.separate_dirs:
            # One sub-folder per sample holding a JSONL file, so the dataset
            # preview on the Hub works.
            now = datetime.now(timezone.utc)
            milliseconds = now.microsecond // 1000
            unique_id = f"{now.strftime('%Y-%m-%dT%H:%M:%S')}.{milliseconds:03}Z"
            if username not in (None, ""):
                unique_id += f"_U_{username}"
            else:
                # Anonymous samples get a random suffix to avoid name clashes.
                unique_id += f"_{str(uuid.uuid4())[:8]}"
            components_dir = self.dataset_dir / unique_id
            data_file = components_dir / "metadata.jsonl"
            path_in_repo = unique_id  # upload in sub folder (safer for concurrency)
        else:
            # Single CSV file shared by all samples.
            components_dir = self.dataset_dir
            data_file = components_dir / "data.csv"
            path_in_repo = None  # upload at root level
        return self._flag_in_dir(
            data_file=data_file,
            components_dir=components_dir,
            path_in_repo=path_in_repo,
            flag_data=flag_data,
            flag_option=flag_option,
            username=username or "",
        )

    def _hub_preview_url(self, deserialized: Any) -> str:
        """Return the Hub URL of a saved file, or "" when there is nothing to link.

        ``deserialized`` is expected to be a path under ``self.dataset_dir``.
        Any other non-empty value (for example a path outside that folder)
        yields "" instead of raising ``ValueError``, so the preview cell simply
        stays empty rather than aborting the whole flag.
        """
        if not deserialized:
            return ""
        try:
            # The saved path is absolute; the URL needs it relative to the dataset root.
            relative = Path(deserialized).relative_to(self.dataset_dir)
        except ValueError:
            return ""
        return huggingface_hub.hf_hub_url(
            repo_id=self.dataset_id,
            filename=str(relative).replace("\\", "/"),
            repo_type="dataset",
        )

    def _deserialize_components(
        self,
        data_dir: Path,
        flag_data: list[Any],
        flag_option: str = "",
        username: str = "",
    ) -> tuple[dict[Any, Any], list[Any]]:
        """Deserialize components and return the corresponding row for the flagged sample.

        Images/audio are saved to disk as individual files under ``data_dir``.

        Args:
            data_dir: Folder where this sample's files are written (one
                sub-folder per component label).
            flag_data: One value per component in ``self.components``.
            flag_option: Stored in the "flag" column.
            username: Stored in the "username" column.

        Returns:
            ``(features, row)``. ``features`` maps column names to Hub dataset
            feature specs. ``row`` holds the matching values in the same order.
        """
        # Components that can have a preview on dataset repos
        file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}

        features = OrderedDict()
        row = []
        for component, sample in zip(self.components, flag_data):
            # Let the component save its value (writes files for file-like components).
            label = component.label or ""
            save_dir = data_dir / client_utils.strip_invalid_filename_characters(label)
            save_dir.mkdir(exist_ok=True, parents=True)
            deserialized = component.flag(sample, save_dir)
            # For components with a data model, ``flag`` returns that model as JSON;
            # for file payloads keep only the path of the saved file.
            if component.data_model:
                data = component.data_model.from_json(json.loads(deserialized))
                if isinstance(data, gr.data_classes.FileData):
                    deserialized = data.path

            # Main column: the path relative to the dataset root when it is a
            # saved file, otherwise the value as a string ("" for None).
            features[label] = {"dtype": "string", "_type": "Value"}
            try:
                deserialized_path = Path(deserialized)
                if not deserialized_path.exists():
                    raise FileNotFoundError(f"File {deserialized} not found")
                row.append(str(deserialized_path.relative_to(self.dataset_dir)))
            except (FileNotFoundError, TypeError, ValueError):
                deserialized = "" if deserialized is None else str(deserialized)
                row.append(deserialized)

            # Audio/Image components get an extra column with the file's Hub URL
            # so the dataset viewer can preview it. The value may be empty.
            preview_type = next(
                (
                    type_name
                    for component_cls, type_name in file_preview_types.items()
                    if isinstance(component, component_cls)
                ),
                None,
            )
            if preview_type is not None:
                features[label + " file"] = {"_type": preview_type}
                row.append(self._hub_preview_url(deserialized))

        features["flag"] = {"dtype": "string", "_type": "Value"}
        features["username"] = {"dtype": "string", "_type": "Value"}
        row.append(flag_option)
        row.append(username)
        return features, row
class FlagMethod:
    """Callable attached to a flag button that forwards inputs to a flagging callback.

    An instance is invoked with the request, the OAuth profile of the signed-in
    user (or ``None``), and the values to flag. When ``visual_feedback`` is
    enabled, it returns a disabled ``gr.Button`` that reports the outcome.
    """

    def __init__(
        self,
        flagging_callback: FlaggingCallback,
        label: str,
        value: str,
        visual_feedback: bool = True,
    ):
        # Instances do not get a __name__ automatically. Presumably one is set
        # because the event machinery expects callables to carry it — TODO confirm.
        self.__name__ = "Flag"
        self.flagging_callback = flagging_callback
        self.label = label
        self.value = value
        self.visual_feedback = visual_feedback

    def __call__(
        self,
        request: gr.Request,
        profile: gr.OAuthProfile | None,
        *flag_data,
    ):
        """Flag ``flag_data`` under ``self.value``, attributed to the signed-in user if any.

        ``request`` is accepted but not used. Returns ``None`` when
        ``visual_feedback`` is off. Otherwise it returns a disabled button
        labelled "Sharing complete", or "Sharing error" if the callback raised.
        """
        username = None if profile is None else profile.username

        shared_ok = True
        try:
            self.flagging_callback.flag(
                list(flag_data), flag_option=self.value, username=username
            )
        except Exception as exc:  # best effort: report the failure instead of raising
            print(f"Error while sharing: {exc}")
            shared_ok = False

        if not self.visual_feedback:
            return None
        if not shared_ok:
            return gr.Button(value="Sharing error", interactive=False)
        # Short pause so the user can notice the button change.
        time.sleep(0.8)
        return gr.Button(value="Sharing complete", interactive=False)