tskwvr / tests /unit_tests /test_plugin.py
TRaw's picture
Upload 297 files
3d3d712
import os
from injector import Injector
from taskweaver.config.config_mgt import AppConfigSource
from taskweaver.logging import LoggingModule
from taskweaver.memory.plugin import PluginModule, PluginRegistry
def test_load_plugin_yaml():
    """Load plugin specs from the fixture directory and check the parsed registry.

    The registry is populated from ``data/plugins`` next to this test file,
    which is expected to contain exactly four plugins. The ``anomaly_detection``
    entry is then inspected in detail: its name, description, implementation
    name, its first argument, and its first return value.
    """
    app_injector = Injector(
        [LoggingModule, PluginModule],
    )
    app_config = AppConfigSource(
        config={
            "plugin.base_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "plugins"),
        },
    )
    app_injector.binder.bind(AppConfigSource, to=app_config)
    plugin_registry = app_injector.get(PluginRegistry)

    assert len(plugin_registry.registry) == 4
    assert "anomaly_detection" in plugin_registry.registry

    # Look the entry up once instead of re-indexing the registry in every assertion.
    plugin = plugin_registry.registry["anomaly_detection"]
    spec = plugin.spec
    assert spec.name == "anomaly_detection"
    assert spec.description.startswith(
        "anomaly_detection function identifies anomalies",
    )
    assert plugin.impl == "anomaly_detection"

    assert len(spec.args) == 3
    first_arg = spec.args[0]
    assert first_arg.name == "df"
    assert first_arg.type == "DataFrame"
    assert first_arg.description == (
        "the input data from which we can identify the anomalies with the 3-sigma algorithm."
    )
    # Identity check rather than `== True` (PEP 8 / E712), so a truthy non-bool
    # value such as 1 would not be accepted as "required".
    assert first_arg.required is True

    assert len(spec.returns) == 2
    first_return = spec.returns[0]
    assert first_return.name == "df"
    assert first_return.type == "DataFrame"
    assert first_return.description == (
        "This DataFrame extends the input DataFrame with a newly-added column "
        '"Is_Anomaly" containing the anomaly detection result.'
    )
def test_plugin_format_prompt():
    """Compare the prompt rendered by ``format_prompt()`` for anomaly_detection verbatim.

    The plugin registry is populated from the ``data/plugins`` fixture
    directory next to this test file. The expected prompt is a Python-like
    stub: the plugin description as a leading comment, each argument preceded
    by its description comment, and a ``Tuple[...]`` return annotation whose
    entries are each preceded by a comment.
    """
    fixture_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/plugins")

    injector = Injector(
        [PluginModule, LoggingModule],
    )
    config_source = AppConfigSource(
        config={"plugin.base_path": fixture_dir},
    )
    injector.binder.bind(AppConfigSource, to=config_source)
    registry = injector.get(PluginRegistry)

    # One entry per rendered line; the prompt is these lines joined by "\n"
    # with a trailing newline.
    expected_lines = [
        "# anomaly_detection function identifies anomalies from an input DataFrame of time series. "
        'It will add a new column "Is_Anomaly", where each entry will be marked with "True" if the value '
        'is an anomaly or "False" otherwise.',
        "def anomaly_detection(",
        "# the input data from which we can identify the anomalies with the 3-sigma algorithm.",
        "df: Any,",
        "# name of the column that contains the datetime",
        "time_col_name: Any,",
        "# name of the column that contains the numeric values.",
        "value_col_name: Any) -> Tuple[",
        '# df: This DataFrame extends the input DataFrame with a newly-added column "Is_Anomaly" '
        "containing the anomaly detection result.",
        "DataFrame,",
        "# description: This is a string describing the anomaly detection results.",
        "str]:...",
    ]
    expected_prompt = "\n".join(expected_lines) + "\n"

    rendered = registry.registry["anomaly_detection"].format_prompt()
    assert rendered == expected_prompt