File size: 4,107 Bytes
0ad74ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import csv
import os
import pathlib
import tempfile
from unittest.mock import MagicMock

import pytest

import gradio as gr
from gradio import flagging

# Turn Gradio analytics off for this test module; done at import time so the
# setting is already in place before any test constructs an Interface.
os.environ.update({"GRADIO_ANALYTICS_ENABLED": "False"})


class TestDefaultFlagging:
    """Tests for the flagging callback an Interface uses when none is supplied.

    Every test that launches an Interface closes it in a ``finally`` block so a
    failing assertion cannot leave a running server behind for later tests.
    """

    def test_default_flagging_callback(self):
        """Each ``flag()`` call appends a row and returns the data-row count."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
            try:
                io.launch(prevent_thread_lock=True)
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 1  # 2 rows written including header
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 2  # 3 rows written including header
            finally:
                # Shut the server down before its flagging directory is deleted,
                # not after the ``with`` block has already removed it.
                io.close()

    def test_files_saved_as_file_paths(self):
        """A flagged image is recorded in the CSV as a path to the saved file."""
        image = {"path": str(pathlib.Path(__file__).parent / "test_files" / "bus.png")}
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(
                lambda x: x,
                "image",
                "image",
                flagging_dir=tmpdirname,
                flagging_mode="auto",
            )
            try:
                io.launch(prevent_thread_lock=True)
                io.flagging_callback.flag([image, image])
            finally:
                io.close()
            # Parse with the csv module rather than splitting on "," so a path
            # containing a comma (and thus quoted by the writer) is read correctly.
            with open(
                os.path.join(tmpdirname, "dataset1.csv"), newline="", encoding="utf-8"
            ) as f:
                rows = list(csv.reader(f))
            # rows[0] is the header; rows[1] is the flagged sample.
            flagged_data = rows[1][0]
            assert flagged_data.endswith("bus.png")

    def test_flagging_does_not_create_unnecessary_directories(self):
        """Flagging text-only data writes nothing but the CSV file."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
            try:
                io.launch(prevent_thread_lock=True)
                io.flagging_callback.flag(["test", "test"])
                assert os.listdir(tmpdirname) == ["dataset1.csv"]
            finally:
                io.close()


class TestSimpleFlagging:
    """Tests for ``flagging.SimpleCSVLogger`` used as an Interface's callback."""

    def test_simple_csv_flagging_callback(self):
        """Each ``flag()`` call returns the row count; this logger writes no header."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(
                lambda x: x,
                "text",
                "text",
                flagging_dir=tmpdirname,
                flagging_callback=flagging.SimpleCSVLogger(),
            )
            try:
                io.launch(prevent_thread_lock=True)
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 0  # no header in SimpleCSVLogger
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 1  # no header in SimpleCSVLogger
            finally:
                # Always stop the server, even when an assertion above fails,
                # and do it before the temp dir is removed.
                io.close()


class TestDisableFlagging:
    """Tests that disabled flagging never touches the flagging directory."""

    def test_flagging_no_permission_error_with_flagging_disabled(self):
        """Launching with ``flagging_mode="never"`` must not try to create the
        flagging dir, so a non-writable parent causes no error."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            nonwritable_path = os.path.join(tmpdirname, "flagging_dir")
            os.chmod(tmpdirname, 0o444)  # Make directory read-only
            try:
                io = gr.Interface(
                    lambda x: x,
                    "text",
                    "text",
                    flagging_mode="never",
                    flagging_dir=nonwritable_path,
                )
                try:
                    io.launch(prevent_thread_lock=True)
                finally:
                    io.close()
            finally:
                # Restore owner permissions so the directory can be inspected
                # below and removed when the ``with`` block exits.
                os.chmod(tmpdirname, 0o700)
            # NOTE: when running as root the chmod above does not block writes,
            # so this check is what catches an unwanted directory in that case.
            assert not os.path.exists(nonwritable_path)


class TestInterfaceSetsUpFlagging:
    """Tests for how Interface configures flagging from its constructor args."""

    @pytest.mark.parametrize(
        "flagging_mode, called",
        [
            ("manual", True),
            ("auto", True),
            ("never", False),
        ],
    )
    def test_flag_method_init_called(self, flagging_mode, called, monkeypatch):
        """``FlagMethod`` is constructed unless flagging is disabled.

        ``monkeypatch`` replaces ``FlagMethod.__init__`` only for the duration
        of this test and restores the real constructor afterwards; assigning the
        mock directly on the class would leak into every later test.
        """
        # __init__ must return None, so pin the mock's return value.
        init_mock = MagicMock(return_value=None)
        monkeypatch.setattr(flagging.FlagMethod, "__init__", init_mock)
        gr.Interface(lambda x: x, "text", "text", flagging_mode=flagging_mode)
        assert init_mock.called == called

    @pytest.mark.parametrize(
        "options, processed_options",
        [
            (None, [("Flag", None)]),
            (["yes", "no"], [("Flag as yes", "yes"), ("Flag as no", "no")]),
            ([("abc", "de"), ("123", "45")], [("abc", "de"), ("123", "45")]),
        ],
    )
    def test_flagging_options_processed_correctly(self, options, processed_options):
        """``flagging_options`` is normalized into (label, value) pairs."""
        io = gr.Interface(lambda x: x, "text", "text", flagging_options=options)
        assert io.flagging_options == processed_options