File size: 4,107 Bytes
0ad74ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import csv
import os
import pathlib
import tempfile
from unittest.mock import MagicMock, patch

import pytest

import gradio as gr
from gradio import flagging
# Turn Gradio analytics off for the tests in this module.
os.environ.update({"GRADIO_ANALYTICS_ENABLED": "False"})
class TestDefaultFlagging:
    """Tests for the flagging callback an Interface uses when none is passed."""

    def test_default_flagging_callback(self):
        """Flagging twice with the default callback returns row counts 1, then 2."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
            io.launch(prevent_thread_lock=True)
            try:
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 1  # 2 rows written including header
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 2  # 3 rows written including header
            finally:
                # Shut the app down even if an assertion fails, so a failing
                # test does not leave a running server behind.
                io.close()

    def test_files_saved_as_file_paths(self):
        """A flagged image is recorded in dataset1.csv as a path ending in bus.png."""
        image = {"path": str(pathlib.Path(__file__).parent / "test_files" / "bus.png")}
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(
                lambda x: x,
                "image",
                "image",
                flagging_dir=tmpdirname,
                flagging_mode="auto",
            )
            io.launch(prevent_thread_lock=True)
            try:
                io.flagging_callback.flag([image, image])
            finally:
                # Close exactly once (the original closed twice), and before
                # reading the log file.
                io.close()
            # Parse with the csv module instead of splitting on ",": a quoted
            # field such as a path containing a comma would otherwise be cut.
            csv_path = os.path.join(tmpdirname, "dataset1.csv")
            with open(csv_path, newline="", encoding="utf-8") as f:
                rows = list(csv.reader(f))
            flagged_data = rows[1][0]  # first column of the first data row
            assert flagged_data.endswith("bus.png")

    def test_flagging_does_not_create_unnecessary_directories(self):
        """Flagging text-only data leaves only dataset1.csv in the flagging dir."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
            io.launch(prevent_thread_lock=True)
            try:
                io.flagging_callback.flag(["test", "test"])
                assert os.listdir(tmpdirname) == ["dataset1.csv"]
            finally:
                # The original never closed this launched app.
                io.close()
class TestSimpleFlagging:
    """Tests for ``flagging.SimpleCSVLogger`` used as an Interface's callback."""

    def test_simple_csv_flagging_callback(self):
        """SimpleCSVLogger writes no header, so the returned counts are 0, then 1."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(
                lambda x: x,
                "text",
                "text",
                flagging_dir=tmpdirname,
                flagging_callback=flagging.SimpleCSVLogger(),
            )
            io.launch(prevent_thread_lock=True)
            try:
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 0  # no header in SimpleCSVLogger
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 1  # no header in SimpleCSVLogger
            finally:
                # Close the launched app even when an assertion fails.
                io.close()
class TestDisableFlagging:
    """Tests for an Interface built with flagging turned off."""

    def test_flagging_no_permission_error_with_flagging_disabled(self):
        """With flagging_mode="never", launching does not fail on an unwritable flagging_dir.

        NOTE(review): permission bits are not enforced for root on POSIX, so
        this test cannot detect a regression when the suite runs as root.
        """
        # TemporaryDirectory (not a bare mkdtemp) so the directory is removed
        # afterwards; the original leaked a read-only temp dir on every run.
        with tempfile.TemporaryDirectory() as tmpdirname:
            os.chmod(tmpdirname, 0o444)  # Make directory read-only
            try:
                nonwritable_path = os.path.join(tmpdirname, "flagging_dir")
                io = gr.Interface(
                    lambda x: x,
                    "text",
                    "text",
                    flagging_mode="never",
                    flagging_dir=nonwritable_path,
                )
                io.launch(prevent_thread_lock=True)
                io.close()
            finally:
                # Restore mkdtemp's default 0o700 so cleanup can remove the dir.
                os.chmod(tmpdirname, 0o700)
class TestInterfaceSetsUpFlagging:
    """Tests for how ``gr.Interface`` sets up flagging when it is constructed."""

    @pytest.mark.parametrize(
        "flagging_mode, called",
        [
            ("manual", True),
            ("auto", True),
            ("never", False),
        ],
    )
    def test_flag_method_init_called(self, flagging_mode, called):
        """FlagMethod is constructed for "manual"/"auto" modes but not for "never"."""
        # patch.object restores the real FlagMethod.__init__ when the block
        # exits. The original assigned a MagicMock directly and never undid
        # it, leaking the mock into every test that ran afterwards.
        # return_value=None because __init__ must return None.
        with patch.object(
            flagging.FlagMethod, "__init__", return_value=None
        ) as mock_init:
            gr.Interface(lambda x: x, "text", "text", flagging_mode=flagging_mode)
        assert mock_init.called == called

    @pytest.mark.parametrize(
        "options, processed_options",
        [
            (None, [("Flag", None)]),
            (["yes", "no"], [("Flag as yes", "yes"), ("Flag as no", "no")]),
            ([("abc", "de"), ("123", "45")], [("abc", "de"), ("123", "45")]),
        ],
    )
    def test_flagging_options_processed_correctly(self, options, processed_options):
        """flagging_options are normalized into (label, value) pairs."""
        io = gr.Interface(lambda x: x, "text", "text", flagging_options=options)
        assert io.flagging_options == processed_options
|