guangkaixu's picture
upload
562c833
raw history blame
No virus
6.37 kB
from __future__ import annotations
import datetime
import json
import time
import uuid
from collections import OrderedDict
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import gradio
import gradio as gr
import huggingface_hub
from gradio import FlaggingCallback
from gradio_client import utils as client_utils
class HuggingFaceDatasetSaver(gradio.HuggingFaceDatasetSaver):
    """Gradio ``HuggingFaceDatasetSaver`` variant with timestamped sample folders.

    When ``separate_dirs`` is enabled, each flagged sample goes into a folder
    named after the current UTC time (millisecond precision). The folder name is
    followed by ``_U_<username>`` for logged-in users, or by a short random
    suffix otherwise. Paths written into dataset rows are relative to
    ``self.dataset_dir`` and always use forward slashes.
    """

    def flag(
        self,
        flag_data: list[Any],
        flag_option: str = "",
        username: str | None = None,
    ) -> int:
        """Save one flagged sample by delegating to the inherited ``_flag_in_dir``.

        Args:
            flag_data: One value per component in ``self.components``.
            flag_option: The flag option chosen by the user.
            username: Name of the logged-in user, or ``None``/``""`` if anonymous.

        Returns:
            Whatever ``_flag_in_dir`` returns (annotated as ``int``).
        """
        if self.separate_dirs:
            # JSONL files to support dataset preview on the Hub
            unique_id = self._timestamped_unique_id(username)
            components_dir = self.dataset_dir / unique_id
            data_file = components_dir / "metadata.jsonl"
            path_in_repo = unique_id  # upload in sub folder (safer for concurrency)
        else:
            # One CSV file shared by every flagged sample
            components_dir = self.dataset_dir
            data_file = components_dir / "data.csv"
            path_in_repo = None  # upload at root level
        return self._flag_in_dir(
            data_file=data_file,
            components_dir=components_dir,
            path_in_repo=path_in_repo,
            flag_data=flag_data,
            flag_option=flag_option,
            username=username or "",
        )

    @staticmethod
    def _timestamped_unique_id(username: str | None) -> str:
        """Build a folder name such as ``2024-01-31T12:34:56.789Z_U_alice``.

        Anonymous users (or names that sanitize to nothing) get an
        8-hex-character random suffix instead of ``_U_<username>``.
        """
        now = datetime.now(timezone.utc)
        unique_id = f"{now:%Y-%m-%dT%H:%M:%S}.{now.microsecond // 1000:03d}Z"
        # The username becomes part of a directory name, so drop characters
        # that are not valid in file names (e.g. path separators), exactly as
        # is done for component labels below.
        safe_username = client_utils.strip_invalid_filename_characters(username or "")
        if safe_username:
            return f"{unique_id}_U_{safe_username}"
        return f"{unique_id}_{uuid.uuid4().hex[:8]}"

    def _to_repo_path(self, path: Any) -> str:
        """Return ``path`` relative to ``self.dataset_dir`` with forward slashes.

        Raises:
            ValueError: if ``path`` is not located under ``self.dataset_dir``.
        """
        return str(Path(path).relative_to(self.dataset_dir)).replace("\\", "/")

    def _hub_file_url(self, deserialized: Any) -> str:
        """Return the Hub URL of a saved file, or ``""`` if there is none.

        An empty/``None`` value (images and audio can be None) or a path outside
        ``self.dataset_dir`` yields ``""``: such a file is not in the repo.
        """
        if not deserialized:
            return ""
        try:
            # Returned file paths are absolute; the URL needs the repo-relative path.
            path_in_repo = self._to_repo_path(deserialized)
        except ValueError:
            return ""
        return huggingface_hub.hf_hub_url(
            repo_id=self.dataset_id,
            filename=path_in_repo,
            repo_type="dataset",
        )

    def _deserialize_components(
        self,
        data_dir: Path,
        flag_data: list[Any],
        flag_option: str = "",
        username: str = "",
    ) -> tuple[dict[Any, Any], list[Any]]:
        """Deserialize components and return the dataset row for the flagged sample.

        Each component's ``flag`` method receives its own sub-folder of
        ``data_dir`` (named after the sanitized component label), so images,
        audio and other files are saved to disk individually.

        Returns:
            ``(features, row)``. ``features`` maps column names to feature specs,
            and ``row`` holds the matching values in the same order. Audio/Image
            components get an extra ``"<label> file"`` column containing the
            file's Hub URL (or ``""``). ``flag`` and ``username`` columns are
            appended last.
        """
        # Components that can have a preview on dataset repos
        file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}

        features = OrderedDict()
        row = []
        for component, sample in zip(self.components, flag_data):
            label = component.label or ""
            save_dir = data_dir / client_utils.strip_invalid_filename_characters(label)
            save_dir.mkdir(exist_ok=True, parents=True)
            # Saves the sample to disk when applicable (file, audio, image, ...)
            deserialized = component.flag(sample, save_dir)
            # Base component .flag method returns JSON; extract path from it when it is FileData
            if component.data_model:
                data = component.data_model.from_json(json.loads(deserialized))
                if component.data_model == gr.data_classes.FileData:
                    deserialized = data.path

            features[label] = {"dtype": "string", "_type": "Value"}
            try:
                deserialized_path = Path(deserialized)
                if not deserialized_path.exists():
                    raise FileNotFoundError(f"File {deserialized} not found")
                row.append(self._to_repo_path(deserialized_path))
            except (FileNotFoundError, TypeError, ValueError):
                # Not a file inside the dataset: store the raw value as text.
                deserialized = "" if deserialized is None else str(deserialized)
                row.append(deserialized)

            preview_type = next(
                (kind for cls, kind in file_preview_types.items() if isinstance(component, cls)),
                None,
            )
            if preview_type is not None:
                features[label + " file"] = {"_type": preview_type}
                row.append(self._hub_file_url(deserialized))

        features["flag"] = {"dtype": "string", "_type": "Value"}
        features["username"] = {"dtype": "string", "_type": "Value"}
        row.append(flag_option)
        row.append(username)
        return features, row
class FlagMethod:
    """Callable that forwards a flag/share click to a flagging callback.

    Calling an instance passes the component values to
    ``flagging_callback.flag``. When ``visual_feedback`` is enabled, it returns
    a disabled button whose text reports success or failure. The
    ``gr.Request``/``gr.OAuthProfile`` parameters suggest it is registered as a
    Gradio event handler (confirm at the call site).
    """

    def __init__(
        self,
        flagging_callback: FlaggingCallback,
        label: str,
        value: str,
        visual_feedback: bool = True,
    ):
        self.flagging_callback = flagging_callback
        self.label = label
        self.value = value
        self.visual_feedback = visual_feedback
        # Name this callable reports, mirroring a plain function's ``__name__``.
        self.__name__ = "Flag"

    def __call__(
        self,
        request: gr.Request,
        profile: gr.OAuthProfile | None,
        *flag_data,
    ):
        """Flag ``flag_data`` with option ``self.value`` for the current user.

        ``request`` is accepted but unused. Any exception from the callback is
        printed and not re-raised.

        Returns:
            ``None`` when ``visual_feedback`` is off. Otherwise, a
            non-interactive ``gr.Button`` reading "Sharing error" or
            "Sharing complete".
        """
        username = profile.username if profile is not None else None
        failed = False
        try:
            self.flagging_callback.flag(
                list(flag_data), flag_option=self.value, username=username
            )
        except Exception as e:
            print(f"Error while sharing: {e}")
            failed = True

        if not self.visual_feedback:
            return None
        if failed:
            return gr.Button(value="Sharing error", interactive=False)
        # Pause so the user has time to notice the button change.
        time.sleep(0.8)
        return gr.Button(value="Sharing complete", interactive=False)