Spaces:
Runtime error
Runtime error
File size: 6,691 Bytes
743fd42 2d0d0c7 743fd42 2d0d0c7 743fd42 2d0d0c7 743fd42 7e1c3f7 2d0d0c7 743fd42 7e1c3f7 743fd42 7e1c3f7 743fd42 2d0d0c7 743fd42 2d0d0c7 743fd42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
from gradio.flagging import FlaggingCallback, _get_dataset_features_info
from gradio.components import IOComponent
from gradio import utils
from typing import Any, List, Optional
from dotenv import load_dotenv
from datetime import datetime
import csv, os, pytz
# --- Load environment variables ---
# Populate os.environ from a local .env file (if present) before the classes
# below are defined, so their os.getenv lookups (e.g. HF_TOKEN, ORG_NAME) can
# see values declared there.
load_dotenv()
# --- Class declarations ---
class DateLogs:
    """Format the current wall-clock time in a fixed, configurable time zone."""

    def __init__(
        self,
        zone: str = "America/Argentina/Cordoba",
    ) -> None:
        """
        Parameters:
            zone: IANA time-zone name, resolved once through pytz.
        """
        self.time_zone = pytz.timezone(zone)

    def _now(self) -> datetime:
        """Return the current, time-zone-aware datetime in ``self.time_zone``."""
        return datetime.now(self.time_zone)

    def full(self) -> str:
        """Return the current time formatted as ``HH:MM:SS DD-MM-YYYY``."""
        stamp = self._now()
        return stamp.strftime("%H:%M:%S %d-%m-%Y")

    def day(self) -> str:
        """Return the current date formatted as ``DD-MM-YYYY``."""
        stamp = self._now()
        return stamp.strftime("%d-%m-%Y")
class HuggingFaceDatasetSaver(FlaggingCallback):
    """
    A callback that saves each flagged sample (both the input and output data)
    to a HuggingFace dataset, as one row of a CSV file inside a local clone of
    the dataset repository that is pushed after every flag.

    When ``available_logs`` is False, nothing is cloned, written or pushed:
    ``flag`` only prints a message and returns 0.

    Example:
        import gradio as gr
        hf_writer = HuggingFaceDatasetSaver(
            dataset_name="image-classification-mistakes",
            hf_token=HF_API_TOKEN,
            available_logs=True,
        )
        def image_classifier(inp):
            return {'cat': 0.3, 'dog': 0.7}
        demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
                            allow_flagging="manual", flagging_callback=hf_writer)
    Guides: using_flagging
    """

    def __init__(
        self,
        dataset_name: Optional[str] = None,
        hf_token: Optional[str] = None,
        organization: Optional[str] = None,
        private: bool = True,
        available_logs: bool = False,
    ) -> None:
        """
        Parameters:
            dataset_name: The name of the dataset to save the data to, e.g. "image-classifier-1". Required.
            hf_token: The HuggingFace token used to create (and write the flagged sample to) the dataset.
                If None, it is read from the HF_TOKEN environment variable when the saver is constructed.
            organization: The organization to save the dataset under; hf_token must grant write access to it.
                If None, it is read from the ORG_NAME environment variable. If that is unset or empty, the
                dataset is created under the user owning hf_token.
            private: Whether the dataset should be private (defaults to True).
            available_logs: When False (default), flagging is a no-op "virtual push".
        Raises:
            ValueError: if dataset_name is None or empty.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not dataset_name:
            raise ValueError("Error: Parameter 'dataset_name' cannot be empty!.")
        self.dataset_name = dataset_name
        # Resolve env-var fallbacks at call time; defaults in the signature would be
        # frozen once, when the class body is first executed.
        self.hf_token = hf_token if hf_token is not None else os.getenv('HF_TOKEN')
        self.organization_name = (
            organization if organization is not None else os.getenv('ORG_NAME')
        )
        self.dataset_private = private
        self.datetime = DateLogs()
        self.available_logs = available_logs
        if not available_logs:
            print("Push: logs DISABLED!...")

    def _repo_id(self) -> str:
        """Return ``"<organization>/<dataset>"``, or just ``"<dataset>"`` when no organization is set."""
        # Hub repo ids always use "/", so os.path.join (OS-specific, and TypeError on None) is avoided.
        if self.organization_name:
            return f"{self.organization_name}/{self.dataset_name}"
        return self.dataset_name

    def setup(
        self,
        components: List[IOComponent],
        flagging_dir: str
    ) -> None:
        """
        Create (if needed) and clone the dataset repository. Does nothing when logs are disabled.

        Params:
            components: the components whose values are passed to ``flag``, in the same order.
            flagging_dir (str): base name of the CSV log file; rows are appended to
                ``<dataset_name>/<flagging_dir>.csv`` inside the local clone.
        Raises:
            ImportError: if ``huggingface_hub`` is not installed.
        """
        if not self.available_logs:
            return
        try:
            import huggingface_hub
        except ImportError as err:  # ModuleNotFoundError is a subclass of ImportError
            raise ImportError(
                "Package `huggingface_hub` not found is needed "
                "for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub'."
            ) from err
        path_to_dataset_repo = huggingface_hub.create_repo(
            repo_id=self._repo_id(),
            token=self.hf_token,
            private=self.dataset_private,
            repo_type="dataset",
            exist_ok=True,
        )
        self.path_to_dataset_repo = path_to_dataset_repo
        self.components = components
        self.flagging_dir = flagging_dir
        # The local clone lives in a directory named after the dataset.
        self.dataset_dir = self.dataset_name
        self.repo = huggingface_hub.Repository(
            local_dir=self.dataset_dir,
            clone_from=path_to_dataset_repo,
            use_auth_token=self.hf_token,
        )
        self.repo.git_pull(lfs=True)
        self.log_file = os.path.join(self.dataset_dir, self.flagging_dir + ".csv")

    @staticmethod
    def _component_label(component: Any, idx: int) -> str:
        """Return the component's label, or ``"component <idx>"`` when it has none."""
        return component.label or f"component {idx}"

    def _build_headers(self, preview_types: tuple) -> List[str]:
        """Return the CSV header row; it must stay column-aligned with ``_build_row``."""
        headers = []
        for idx, component in enumerate(self.components):
            label = self._component_label(component, idx)
            headers.append(label)
            # Preview components get an extra column holding the hub URL of the saved file.
            if isinstance(component, preview_types):
                headers.append(label + " file")
        return headers + ["flag", "username", "timestamp"]

    def _build_row(
        self,
        flag_data: List[Any],
        flag_option: Optional[str],
        username: Optional[str],
        preview_types: tuple,
    ) -> List[Any]:
        """Save each sample into the clone and return the CSV row describing the flag."""
        row = []
        for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
            save_dir = os.path.join(
                self.dataset_dir,
                utils.strip_invalid_filename_characters(
                    self._component_label(component, idx)
                ),
            )
            filepath = component.deserialize(sample, save_dir, None)
            row.append(filepath)
            if isinstance(component, preview_types):
                row.append(
                    "{}/resolve/main/{}".format(self.path_to_dataset_repo, filepath)
                )
        row.append(flag_option if flag_option is not None else "")
        row.append(username if username is not None else "")
        row.append(self.datetime.full())
        return row

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        """
        Append one flagged sample to the CSV log and push the dataset repo.

        Parameters:
            flag_data: one value per component passed to ``setup``, in the same order.
            flag_option: the flag option chosen by the user, if any.
            flag_index: unused; accepted for FlaggingCallback compatibility.
            username: the flagging user, if known.
        Returns:
            The number of data rows in the CSV log after this flag (header excluded),
            or 0 when logs are disabled.
        """
        if not self.available_logs:
            print("Logs: Virtual push...")
            return 0

        self.repo.git_pull(lfs=True)
        is_new = not os.path.exists(self.log_file)
        # Only the file-preview component types are needed here; headers are built
        # locally, so the helper's (unused) header/info generation is skipped.
        _, file_preview_types, _ = _get_dataset_features_info(False, self.components)
        preview_types = tuple(file_preview_types)

        with open(self.log_file, "a", newline="", encoding="utf-8") as csvfile:
            writer = csv.writer(csvfile)
            if is_new:
                writer.writerow(
                    utils.sanitize_list_for_csv(self._build_headers(preview_types))
                )
            row = self._build_row(flag_data, flag_option, username, preview_types)
            writer.writerow(utils.sanitize_list_for_csv(row))

        # Count CSV records (not physical lines) minus the header row.
        with open(self.log_file, "r", newline="", encoding="utf-8") as csvfile:
            line_count = sum(1 for _ in csv.reader(csvfile)) - 1
        self.repo.push_to_hub(commit_message="Flagged sample #{}".format(line_count))
        return line_count