|
import os |
|
|
|
from injector import Injector |
|
|
|
from taskweaver.config.config_mgt import AppConfigSource |
|
from taskweaver.logging import LoggingModule |
|
from taskweaver.memory.plugin import PluginModule, PluginRegistry |
|
|
|
|
|
def test_load_plugin_yaml():
    """Load plugins from the ``data/plugins`` fixture directory and check the parsed spec.

    Wires ``LoggingModule`` and ``PluginModule`` into an ``Injector`` and binds an
    ``AppConfigSource`` whose ``plugin.base_path`` points next to this test file.
    It then resolves the ``PluginRegistry`` and checks the registry contents and
    the ``anomaly_detection`` spec: name, description, impl, args and returns.
    """
    app_injector = Injector(
        [LoggingModule, PluginModule],
    )
    app_config = AppConfigSource(
        config={
            "plugin.base_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/plugins"),
        },
    )
    app_injector.binder.bind(AppConfigSource, to=app_config)

    plugin_registry = app_injector.get(PluginRegistry)

    # The fixture directory is expected to yield exactly four registered plugins.
    assert len(plugin_registry.registry) == 4
    assert "anomaly_detection" in plugin_registry.registry

    # Bind the entry once instead of repeating the registry lookup on every assertion.
    plugin = plugin_registry.registry["anomaly_detection"]
    spec = plugin.spec

    assert spec.name == "anomaly_detection"
    assert spec.description.startswith(
        "anomaly_detection function identifies anomalies",
    )
    assert plugin.impl == "anomaly_detection"

    # Arguments: three in total; the first is the required input DataFrame.
    assert len(spec.args) == 3
    first_arg = spec.args[0]
    assert first_arg.name == "df"
    assert first_arg.type == "DataFrame"
    assert (
        first_arg.description == "the input data from which we can identify the "
        "anomalies with the 3-sigma algorithm."
    )
    # `is True` (not `== True`) so a truthy non-bool such as 1 does not pass.
    assert first_arg.required is True

    # Returns: two values; the first is the augmented DataFrame.
    assert len(spec.returns) == 2
    first_return = spec.returns[0]
    assert first_return.name == "df"
    assert first_return.type == "DataFrame"
    assert (
        first_return.description == "This DataFrame extends the input "
        "DataFrame with a newly-added column "
        '"Is_Anomaly" containing the anomaly detection result.'
    )
|
|
|
|
|
def test_plugin_format_prompt():
    """Check the prompt text rendered for the ``anomaly_detection`` plugin.

    Builds the dependency graph from ``PluginModule`` and ``LoggingModule`` and
    points ``plugin.base_path`` at the ``data/plugins`` fixture directory beside
    this file. It then asserts that ``format_prompt()`` returns exactly the
    expected function-signature-style text.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    plugin_dir = os.path.join(here, "data/plugins")

    injector = Injector(
        [PluginModule, LoggingModule],
    )
    config_source = AppConfigSource(config={"plugin.base_path": plugin_dir})
    injector.binder.bind(AppConfigSource, to=config_source)

    registry = injector.get(PluginRegistry)
    anomaly_plugin = registry.registry["anomaly_detection"]

    # Adjacent literals are concatenated by the parser; each fragment matches the
    # rendered prompt piece for piece (description, params, then return tuple).
    expected_prompt = (
        "# anomaly_detection function identifies anomalies from an input DataFrame of time series. It will add a new "
        'column "Is_Anomaly", where each entry will be marked with "True" if the value is an anomaly or "False" '
        "otherwise.\n"
        "def anomaly_detection(\n"
        "# the input data from which we can identify the anomalies with the 3-sigma algorithm.\n"
        "df: Any,\n"
        "# name of the column that contains the datetime\n"
        "time_col_name: Any,\n"
        "# name of the column that contains the numeric values.\n"
        "value_col_name: Any) -> Tuple[\n"
        '# df: This DataFrame extends the input DataFrame with a newly-added column "Is_Anomaly" containing the '
        "anomaly detection result.\n"
        "DataFrame,\n"
        "# description: This is a string describing the anomaly detection results.\n"
        "str]:...\n"
    )

    assert anomaly_plugin.format_prompt() == expected_prompt
|
|