Spaces:
Sleeping
Sleeping
finished
Browse files- sklearn_proxy.py +16 -1
- tests.py +18 -5
sklearn_proxy.py
CHANGED
@@ -16,6 +16,7 @@
|
|
16 |
import evaluate
|
17 |
import datasets
|
18 |
from sklearn.metrics import get_scorer
|
|
|
19 |
|
20 |
|
21 |
# TODO: Add BibTeX citation
|
@@ -61,6 +62,9 @@ BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
|
|
61 |
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
62 |
class SklearnProxy(evaluate.Metric):
|
63 |
"""TODO: Short description of my evaluation module."""
|
|
|
|
|
|
|
64 |
|
65 |
def _info(self):
|
66 |
# TODO: Specifies the evaluate.EvaluationModuleInfo object
|
@@ -74,6 +78,7 @@ class SklearnProxy(evaluate.Metric):
|
|
74 |
features=datasets.Features({
|
75 |
'predictions': datasets.Value('int64'),
|
76 |
'references': datasets.Value('int64'),
|
|
|
77 |
}),
|
78 |
# Homepage of the module for documentation
|
79 |
homepage="http://module.homepage",
|
@@ -89,4 +94,14 @@ class SklearnProxy(evaluate.Metric):
|
|
89 |
|
90 |
def _compute(self, predictions, references, metric_name="accuracy", **kwargs):
|
91 |
scorer = get_scorer(metric_name)
|
92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
import evaluate
|
17 |
import datasets
|
18 |
from sklearn.metrics import get_scorer
|
19 |
+
from sklearn.base import BaseEstimator
|
20 |
|
21 |
|
22 |
# TODO: Add BibTeX citation
|
|
|
62 |
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
63 |
class SklearnProxy(evaluate.Metric):
|
64 |
"""TODO: Short description of my evaluation module."""
|
65 |
+
def __init__(self, **kwargs):
|
66 |
+
super().__init__(**kwargs)
|
67 |
+
self.dummy_estimator = PassThroughtEstimator()
|
68 |
|
69 |
def _info(self):
|
70 |
# TODO: Specifies the evaluate.EvaluationModuleInfo object
|
|
|
78 |
features=datasets.Features({
|
79 |
'predictions': datasets.Value('int64'),
|
80 |
'references': datasets.Value('int64'),
|
81 |
+
|
82 |
}),
|
83 |
# Homepage of the module for documentation
|
84 |
homepage="http://module.homepage",
|
|
|
94 |
|
95 |
def _compute(self, predictions, references, metric_name="accuracy", **kwargs):
|
96 |
scorer = get_scorer(metric_name)
|
97 |
+
|
98 |
+
return {metric_name: scorer(self.dummy_estimator, references, predictions, **kwargs)}
|
99 |
+
|
100 |
+
|
101 |
+
class PassThroughtEstimator(BaseEstimator):
|
102 |
+
def __init__(self):
|
103 |
+
pass
|
104 |
+
def fit(self, X, y):
|
105 |
+
return self
|
106 |
+
def predict(self, X):
|
107 |
+
return X
|
tests.py
CHANGED
@@ -1,17 +1,30 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
2 |
{
|
3 |
"predictions": [0, 0],
|
4 |
"references": [1, 1],
|
5 |
-
"result": {"
|
6 |
},
|
7 |
{
|
8 |
"predictions": [1, 1],
|
9 |
"references": [1, 1],
|
10 |
-
"result": {"
|
11 |
},
|
12 |
{
|
13 |
"predictions": [1, 0],
|
14 |
"references": [1, 1],
|
15 |
-
"result": {"
|
16 |
}
|
17 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sklearn_proxy import SklearnProxy
|
2 |
+
import unittest
|
3 |
+
|
4 |
+
|
5 |
+
accuracy_test_cases = [
|
6 |
{
|
7 |
"predictions": [0, 0],
|
8 |
"references": [1, 1],
|
9 |
+
"result": {"accuracy": 0.0}
|
10 |
},
|
11 |
{
|
12 |
"predictions": [1, 1],
|
13 |
"references": [1, 1],
|
14 |
+
"result": {"accuracy": 1.0}
|
15 |
},
|
16 |
{
|
17 |
"predictions": [1, 0],
|
18 |
"references": [1, 1],
|
19 |
+
"result": {"accuracy": 0.5}
|
20 |
}
|
21 |
+
]
|
22 |
+
|
23 |
+
|
24 |
+
class TestGeneral(unittest.TestCase):
|
25 |
+
|
26 |
+
def test_accuracy(self):
|
27 |
+
metric = SklearnProxy()
|
28 |
+
for test_case in accuracy_test_cases:
|
29 |
+
result = metric.compute(predictions=test_case["predictions"],references=test_case["references"], metric_name="accuracy")
|
30 |
+
self.assertEqual(result, test_case["result"])
|