from gradio import FlaggingCallback, utils import csv import datetime import os import re import secrets from pathlib import Path from typing import Any, List, Union class CSVLogger(FlaggingCallback): """ The default implementation of the FlaggingCallback abstract class. Each flagged sample (both the input and output data) is logged to a CSV file with headers on the machine running the gradio app. Example: import gradio as gr def image_classifier(inp): return {'cat': 0.3, 'dog': 0.7} demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label", flagging_callback=CSVLogger()) Guides: using_flagging """ def __init__(self): pass def setup( self, components: List[Any], flagging_dir: Union[str, Path], ): self.components = components self.flagging_dir = flagging_dir os.makedirs(flagging_dir, exist_ok=True) def flag( self, flag_data: List[Any], flag_option: str = "", username: Union[str, None] = None, filename="log.csv", ) -> int: flagging_dir = self.flagging_dir filename = re.sub(r"[/\\?%*:|\"<>\x7F\x00-\x1F]", "-", filename) log_filepath = Path(flagging_dir) / filename is_new = not Path(log_filepath).exists() headers = [ getattr(component, "label", None) or f"component {idx}" for idx, component in enumerate(self.components) ] + [ "flag", "username", "timestamp", ] csv_data = [] for idx, (component, sample) in enumerate(zip(self.components, flag_data)): save_dir = Path( flagging_dir ) / ( getattr(component, "label", None) or f"component {idx}" ) if utils.is_update(sample): csv_data.append(str(sample)) else: csv_data.append( component.deserialize(sample, save_dir=save_dir) if sample is not None else "" ) csv_data.append(flag_option) csv_data.append(username if username is not None else "") csv_data.append(str(datetime.datetime.now())) try: with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile: writer = csv.writer(csvfile) if is_new: writer.writerow(utils.sanitize_list_for_csv(headers)) writer.writerow(utils.sanitize_list_for_csv(csv_data)) except Exception as e: # workaround "OSError: [Errno 95] Operation not supported" with open(log_filepath, "a") on some cloud mounted directory random_hex = secrets.token_hex(16) tmp_log_filepath = str(log_filepath) + f".tmp_{random_hex}" with open(tmp_log_filepath, "a", newline="", encoding="utf-8") as csvfile: writer = csv.writer(csvfile) if is_new: writer.writerow(utils.sanitize_list_for_csv(headers)) writer.writerow(utils.sanitize_list_for_csv(csv_data)) os.system(f"mv '{log_filepath}' '{log_filepath}.old_{random_hex}'") os.system(f"cat '{log_filepath}.old_{random_hex}' '{tmp_log_filepath}' > '{log_filepath}'") os.system(f"rm '{tmp_log_filepath}'") os.system(f"rm '{log_filepath}.old_{random_hex}'") with open(log_filepath, "r", encoding="utf-8") as csvfile: line_count = len([None for row in csv.reader(csvfile)]) - 1 return line_count