import shutil
import tempfile
from concurrent.futures import wait
from contextlib import contextmanager
from pathlib import Path
from unittest.mock import patch

import pytest

import gradio as gr


@contextmanager
def _temporary_cache_folder():
    """Patch ``gradio.utils.get_cache_folder`` to return a throwaway directory.

    The directory is created with ``tempfile.mkdtemp`` and removed again when
    the ``with`` block exits, so the example-caching tests do not leave
    directories behind. Removal is best-effort (``ignore_errors=True``) so a
    cleanup problem never masks the outcome of the test itself.

    Yields:
        The ``Path`` of the temporary directory.
    """
    folder = Path(tempfile.mkdtemp())
    try:
        with patch("gradio.utils.get_cache_folder", return_value=folder):
            yield folder
    finally:
        shutil.rmtree(folder, ignore_errors=True)


def invalid_fn(message):
    """Return ``message`` unchanged; takes no ``history`` argument."""
    return message


def double(message, history):
    """Return ``message`` twice, separated by a single space."""
    return message + " " + message


async def async_greet(message, history):
    """Coroutine that returns ``message`` prefixed with ``"hi, "``."""
    return "hi, " + message


def stream(message, history):
    """Yield every non-empty prefix of ``message``, shortest first."""
    for end in range(1, len(message) + 1):
        yield message[:end]


async def async_stream(message, history):
    """Async generator yielding every non-empty prefix of ``message``."""
    for end in range(1, len(message) + 1):
        yield message[:end]


def count(message, history):
    """Return the number of entries in ``history`` as a string."""
    return str(len(history))


def echo_system_prompt_plus_message(message, history, system_prompt, tokens):
    """Yield growing prefixes of ``"{system_prompt} {message}"``.

    At most ``int(tokens)`` prefixes are produced, so the final value is the
    response truncated to ``int(tokens)`` characters (or the full response if
    it is shorter).
    """
    response = f"{system_prompt} {message}"
    for end in range(1, min(len(response), int(tokens)) + 1):
        yield response[:end]


class TestInit:
    """Construction-time behaviour of ``gr.ChatInterface``."""

    def test_no_fn(self):
        """Constructing without a chat function raises ``TypeError``."""
        with pytest.raises(TypeError):
            gr.ChatInterface()

    def test_configuring_buttons(self):
        """Buttons passed as ``None`` are stored as ``None``."""
        chatbot = gr.ChatInterface(double, submit_btn=None, retry_btn=None)
        assert chatbot.submit_btn is None
        assert chatbot.retry_btn is None

    def test_concurrency_limit(self):
        """``concurrency_limit`` is kept on the interface and its submit fns."""
        chat = gr.ChatInterface(double, concurrency_limit=10)
        assert chat.concurrency_limit == 10
        fns = [fn for fn in chat.fns if fn.name in {"_submit_fn", "_api_submit_fn"}]
        assert all(fn.concurrency_limit == 10 for fn in fns)

    def test_events_attached(self):
        """Submit, retry, clear and undo each have a matching dependency."""
        chatbot = gr.ChatInterface(double)
        dependencies = chatbot.dependencies
        textbox = chatbot.textbox._id
        submit_btn = chatbot.submit_btn._id
        # One dependency must target both the textbox submit and the button click.
        assert next(
            (
                d
                for d in dependencies
                if d["targets"] == [(textbox, "submit"), (submit_btn, "click")]
            ),
            None,
        )
        for btn_id in [
            chatbot.retry_btn._id,
            chatbot.clear_btn._id,
            chatbot.undo_btn._id,
        ]:
            assert next(
                (d for d in dependencies if d["targets"][0] == (btn_id, "click")),
                None,
            )

    def test_example_caching(self, monkeypatch):
        """Cached examples for a plain function hold (input, output) pairs."""
        with _temporary_cache_folder():
            chatbot = gr.ChatInterface(
                double, examples=["hello", "hi"], cache_examples=True
            )
            prediction_hello = chatbot.examples_handler.load_from_cache(0)
            prediction_hi = chatbot.examples_handler.load_from_cache(1)
            assert prediction_hello[0].root[0] == ("hello", "hello hello")
            assert prediction_hi[0].root[0] == ("hi", "hi hi")

    def test_example_caching_async(self, monkeypatch):
        """Cached examples work when the chat function is a coroutine."""
        with _temporary_cache_folder():
            chatbot = gr.ChatInterface(
                async_greet, examples=["abubakar", "tom"], cache_examples=True
            )
            prediction_hello = chatbot.examples_handler.load_from_cache(0)
            prediction_hi = chatbot.examples_handler.load_from_cache(1)
            assert prediction_hello[0].root[0] == ("abubakar", "hi, abubakar")
            assert prediction_hi[0].root[0] == ("tom", "hi, tom")

    def test_example_caching_with_streaming(self, monkeypatch):
        """For a generator function the cache holds the last yielded value."""
        with _temporary_cache_folder():
            chatbot = gr.ChatInterface(
                stream, examples=["hello", "hi"], cache_examples=True
            )
            prediction_hello = chatbot.examples_handler.load_from_cache(0)
            prediction_hi = chatbot.examples_handler.load_from_cache(1)
            assert prediction_hello[0].root[0] == ("hello", "hello")
            assert prediction_hi[0].root[0] == ("hi", "hi")

    def test_example_caching_with_streaming_async(self, monkeypatch):
        """For an async generator the cache holds the last yielded value."""
        with _temporary_cache_folder():
            chatbot = gr.ChatInterface(
                async_stream, examples=["hello", "hi"], cache_examples=True
            )
            prediction_hello = chatbot.examples_handler.load_from_cache(0)
            prediction_hi = chatbot.examples_handler.load_from_cache(1)
            assert prediction_hello[0].root[0] == ("hello", "hello")
            assert prediction_hi[0].root[0] == ("hi", "hi")

    def test_default_accordion_params(self):
        """The default accordion is closed and labelled "Additional Inputs"."""
        chatbot = gr.ChatInterface(
            echo_system_prompt_plus_message,
            additional_inputs=["textbox", "slider"],
        )
        accordion = [
            comp
            for comp in chatbot.blocks.values()
            if comp.get_config().get("name") == "accordion"
        ][0]
        assert accordion.get_config().get("open") is False
        assert accordion.get_config().get("label") == "Additional Inputs"

    def test_setting_accordion_params(self, monkeypatch):
        """A user-supplied accordion keeps its own ``open`` and ``label``."""
        chatbot = gr.ChatInterface(
            echo_system_prompt_plus_message,
            additional_inputs=["textbox", "slider"],
            additional_inputs_accordion=gr.Accordion(open=True, label="MOAR"),
        )
        accordion = [
            comp
            for comp in chatbot.blocks.values()
            if comp.get_config().get("name") == "accordion"
        ][0]
        assert accordion.get_config().get("open") is True
        assert accordion.get_config().get("label") == "MOAR"

    def test_example_caching_with_additional_inputs(self, monkeypatch):
        """Cached examples pass the additional inputs to the chat function."""
        with _temporary_cache_folder():
            chatbot = gr.ChatInterface(
                echo_system_prompt_plus_message,
                additional_inputs=["textbox", "slider"],
                examples=[["hello", "robot", 100], ["hi", "robot", 2]],
                cache_examples=True,
            )
            prediction_hello = chatbot.examples_handler.load_from_cache(0)
            prediction_hi = chatbot.examples_handler.load_from_cache(1)
            assert prediction_hello[0].root[0] == ("hello", "robot hello")
            # tokens=2 truncates "robot hi" to its first two characters.
            assert prediction_hi[0].root[0] == ("hi", "ro")

    def test_example_caching_with_additional_inputs_already_rendered(self, monkeypatch):
        """Example caching also works with pre-built additional inputs."""
        with _temporary_cache_folder():
            # NOTE(review): the nesting below was reconstructed; confirm the
            # ChatInterface is meant to be created inside the Accordion.
            with gr.Blocks():
                with gr.Accordion("Inputs"):
                    text = gr.Textbox()
                    slider = gr.Slider()
                    chatbot = gr.ChatInterface(
                        echo_system_prompt_plus_message,
                        additional_inputs=[text, slider],
                        examples=[["hello", "robot", 100], ["hi", "robot", 2]],
                        cache_examples=True,
                    )
            prediction_hello = chatbot.examples_handler.load_from_cache(0)
            prediction_hi = chatbot.examples_handler.load_from_cache(1)
            assert prediction_hello[0].root[0] == ("hello", "robot hello")
            # tokens=2 truncates "robot hi" to its first two characters.
            assert prediction_hi[0].root[0] == ("hi", "ro")


class TestAPI:
    """API exposed by ``gr.ChatInterface``, exercised through ``connect``."""

    def test_get_api_info(self):
        """Exactly one named endpoint, ``/chat``, and no unnamed endpoints."""
        chatbot = gr.ChatInterface(double)
        api_info = chatbot.get_api_info()
        assert len(api_info["named_endpoints"]) == 1
        assert len(api_info["unnamed_endpoints"]) == 0
        assert "/chat" in api_info["named_endpoints"]

    def test_streaming_api(self, connect):
        """A generator chat function streams each yielded value to the client."""
        chatbot = gr.ChatInterface(stream).queue()
        with connect(chatbot) as client:
            job = client.submit("hello")
            wait([job])
            assert job.outputs() == ["h", "he", "hel", "hell", "hello"]

    def test_streaming_api_async(self, connect):
        """An async generator streams each yielded value to the client."""
        chatbot = gr.ChatInterface(async_stream).queue()
        with connect(chatbot) as client:
            job = client.submit("hello")
            wait([job])
            assert job.outputs() == ["h", "he", "hel", "hell", "hello"]

    def test_non_streaming_api(self, connect):
        """``predict`` returns the value of a plain chat function."""
        chatbot = gr.ChatInterface(double)
        with connect(chatbot) as client:
            result = client.predict("hello")
            assert result == "hello hello"

    def test_non_streaming_api_async(self, connect):
        """``predict`` returns the value of a coroutine chat function."""
        chatbot = gr.ChatInterface(async_greet)
        with connect(chatbot) as client:
            result = client.predict("gradio")
            assert result == "hi, gradio"

    def test_streaming_api_with_additional_inputs(self, connect):
        """Additional inputs given to ``submit`` reach the chat function."""
        chatbot = gr.ChatInterface(
            echo_system_prompt_plus_message,
            additional_inputs=["textbox", "slider"],
        ).queue()
        with connect(chatbot) as client:
            job = client.submit("hello", "robot", 7)
            wait([job])
            # tokens=7 stops the stream after the first seven characters.
            assert job.outputs() == [
                "r",
                "ro",
                "rob",
                "robo",
                "robot",
                "robot ",
                "robot h",
            ]