File size: 3,698 Bytes
db1ee85
 
 
 
 
 
 
4f361e4
db1ee85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f361e4
db1ee85
 
 
 
 
 
 
 
 
4f361e4
db1ee85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import csv
import datetime
import os
import re
import secrets
import shutil
from pathlib import Path
from typing import Any, List, Union

from gradio import FlaggingCallback, utils

class CSVLogger(FlaggingCallback):
    """
    The default implementation of the FlaggingCallback abstract class. Each flagged
    sample (both the input and output data) is logged to a CSV file with headers on the machine running the gradio app.
    Example:
        import gradio as gr
        def image_classifier(inp):
            return {'cat': 0.3, 'dog': 0.7}
        demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
                            flagging_callback=CSVLogger())
    Guides: using_flagging
    """

    def __init__(self):
        pass

    def setup(
        self,
        components: List[Any],
        flagging_dir: Union[str, Path],
    ):
        """Remember the components to log and create ``flagging_dir`` if missing.

        Args:
            components: components whose values are flagged, in the same order
                as the ``flag_data`` values later passed to :meth:`flag`.
            flagging_dir: directory that holds the CSV log and any files the
                components write when deserializing flagged values.
        """
        self.components = components
        self.flagging_dir = flagging_dir
        os.makedirs(flagging_dir, exist_ok=True)

    def flag(
        self,
        flag_data: List[Any],
        flag_option: str = "",
        username: Union[str, None] = None,
        filename="log.csv",
    ) -> int:
        """Append one flagged sample as a row of the CSV log.

        A header row (component labels + "flag", "username", "timestamp") is
        written first when the log file does not exist yet.

        Args:
            flag_data: one value per component passed to :meth:`setup`.
            flag_option: value stored in the "flag" column.
            username: value stored in the "username" column ("" when None).
            filename: name of the CSV file inside ``flagging_dir``; path
                separators, control characters and other characters matched by
                the sanitizing regex are replaced with "-".

        Returns:
            The number of rows in the log file, excluding the header row.
        """
        flagging_dir = Path(self.flagging_dir)
        filename = re.sub(r"[/\\?%*:|\"<>\x7F\x00-\x1F]", "-", filename)
        log_filepath = flagging_dir / filename
        is_new = not log_filepath.exists()

        # Each component's label doubles as its CSV column header and as the
        # name of the sub-directory its deserialized files are saved under.
        labels = [
            getattr(component, "label", None) or f"component {idx}"
            for idx, component in enumerate(self.components)
        ]
        headers = labels + ["flag", "username", "timestamp"]

        csv_data = []
        for label, component, sample in zip(labels, self.components, flag_data):
            if utils.is_update(sample):
                csv_data.append(str(sample))
            elif sample is None:
                csv_data.append("")
            else:
                csv_data.append(
                    component.deserialize(sample, save_dir=flagging_dir / label)
                )
        csv_data.append(flag_option)
        csv_data.append(username if username is not None else "")
        csv_data.append(str(datetime.datetime.now()))

        rows = [utils.sanitize_list_for_csv(headers)] if is_new else []
        rows.append(utils.sanitize_list_for_csv(csv_data))

        try:
            with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile:
                csv.writer(csvfile).writerows(rows)
        except OSError:
            # Workaround for "OSError: [Errno 95] Operation not supported" raised
            # by open(log_filepath, "a") on some cloud-mounted directories.
            self._append_rows_by_rewrite(log_filepath, rows)

        # newline="" is required by the csv module so quoted fields containing
        # line breaks are parsed as part of a single row.
        with open(log_filepath, "r", newline="", encoding="utf-8") as csvfile:
            line_count = sum(1 for _ in csv.reader(csvfile)) - 1
        return line_count

    @staticmethod
    def _append_rows_by_rewrite(log_filepath: Path, rows: List[List[Any]]) -> None:
        """Append ``rows`` to ``log_filepath`` without opening it in append mode.

        The new rows are written to a uniquely named temporary file, the current
        log (if any) is moved aside, and the log is re-created as the previous
        contents followed by the new rows. Plain Python file operations are used
        instead of shell commands, so paths containing quotes or other shell
        metacharacters are safe and the fallback also works on Windows. The side
        files are deleted only after the new log has been fully written, so a
        failure leaves the previous contents in ``<log>.old_<hex>``.
        """
        random_hex = secrets.token_hex(16)
        tmp_path = Path(f"{log_filepath}.tmp_{random_hex}")
        old_path = Path(f"{log_filepath}.old_{random_hex}")

        with open(tmp_path, "w", newline="", encoding="utf-8") as tmp_file:
            csv.writer(tmp_file).writerows(rows)

        had_old = log_filepath.exists()
        if had_old:
            os.replace(log_filepath, old_path)
        sources = [old_path, tmp_path] if had_old else [tmp_path]

        # Byte-level concatenation: both parts are already UTF-8 CSV text.
        with open(log_filepath, "wb") as out_file:
            for source in sources:
                with open(source, "rb") as src_file:
                    shutil.copyfileobj(src_file, out_file)

        tmp_path.unlink()
        if had_old:
            old_path.unlink()