bowdbeg commited on
Commit
6d01d6a
·
1 Parent(s): bdc074b

implemented

Browse files
Files changed (3) hide show
  1. .gitignore +133 -0
  2. __main__.py +56 -0
  3. patch_series.py +99 -22
.gitignore ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .vscode
2
+ data/
3
+ output/
4
+
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ pip-wheel-metadata/
28
+ share/python-wheels/
29
+ *.egg-info/
30
+ .installed.cfg
31
+ *.egg
32
+ MANIFEST
33
+
34
+ # PyInstaller
35
+ # Usually these files are written by a python script from a template
36
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
37
+ *.manifest
38
+ *.spec
39
+
40
+ # Installer logs
41
+ pip-log.txt
42
+ pip-delete-this-directory.txt
43
+
44
+ # Unit test / coverage reports
45
+ htmlcov/
46
+ .tox/
47
+ .nox/
48
+ .coverage
49
+ .coverage.*
50
+ .cache
51
+ nosetests.xml
52
+ coverage.xml
53
+ *.cover
54
+ *.py,cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ .python-version
90
+
91
+ # pipenv
92
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
94
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
95
+ # install all needed dependencies.
96
+ #Pipfile.lock
97
+
98
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
99
+ __pypackages__/
100
+
101
+ # Celery stuff
102
+ celerybeat-schedule
103
+ celerybeat.pid
104
+
105
+ # SageMath parsed files
106
+ *.sage.py
107
+
108
+ # Environments
109
+ .env
110
+ .venv
111
+ env/
112
+ venv/
113
+ ENV/
114
+ env.bak/
115
+ venv.bak/
116
+
117
+ # Spyder project settings
118
+ .spyderproject
119
+ .spyproject
120
+
121
+ # Rope project settings
122
+ .ropeproject
123
+
124
+ # mkdocs documentation
125
+ /site
126
+
127
+ # mypy
128
+ .mypy_cache/
129
+ .dmypy.json
130
+ dmypy.json
131
+
132
+ # Pyre type checker
133
+ .pyre/
__main__.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import time
4
+ from argparse import ArgumentParser
5
+
6
+ import evaluate
7
+ import numpy as np
8
+
9
+ logging.basicConfig(level=logging.INFO)
10
+ logger = logging.getLogger(__name__)
11
+ parser = ArgumentParser(
12
+ description="Compute the matching series score between two time series freezed in a numpy array"
13
+ )
14
+ parser.add_argument("predictions", type=str, help="Path to the numpy array containing the predictions")
15
+ parser.add_argument("references", type=str, help="Path to the numpy array containing the references")
16
+ parser.add_argument("--output", type=str, help="Path to the output file")
17
+ parser.add_argument("--batch_size", type=int, help="Batch size to use for the computation")
18
+ parser.add_argument("--num_processes", type=int, help="Batch size to use for the computation", default=1)
19
+ parser.add_argument("--dtype", type=str, help="Data type to use for the computation", default="float32")
20
+ parser.add_argument("--debug", action="store_true", help="Debug mode")
21
+ args = parser.parse_args()
22
+
23
+ if not args.predictions or not args.references:
24
+ raise ValueError("You must provide the path to the predictions and references numpy arrays")
25
+
26
+
27
+ predictions = np.load(args.predictions).astype(args.dtype)
28
+ references = np.load(args.references).astype(args.dtype)
29
+
30
+ if args.debug:
31
+ predictions = predictions[:1000]
32
+ references = references[:1000]
33
+
34
+ logger.info(f"predictions shape: {predictions.shape}")
35
+ logger.info(f"references shape: {references.shape}")
36
+
37
+ import patch_series
38
+
39
+ s = time.time()
40
+ metric = patch_series.patch_series()
41
+ # metric = evaluate.load("patch_series.py")
42
+ results = metric.compute(
43
+ predictions=predictions,
44
+ references=references,
45
+ batch_size=args.batch_size,
46
+ num_processes=args.num_process,
47
+ return_each_features=True,
48
+ return_coverages=True,
49
+ dtype=args.dtype,
50
+ )
51
+ logger.info(f"Time taken: {time.time() - s}")
52
+
53
+ print(json.dumps(results))
54
+ if args.output:
55
+ with open(args.output, "w") as f:
56
+ json.dump(results, f)
patch_series.py CHANGED
@@ -13,9 +13,14 @@
13
  # limitations under the License.
14
  """TODO: Add a description here."""
15
 
16
- import evaluate
 
 
17
  import datasets
 
 
18
 
 
19
 
20
  # TODO: Add BibTeX citation
21
  _CITATION = """\
@@ -53,13 +58,13 @@ Examples:
53
  {'accuracy': 1.0}
54
  """
55
 
56
- # TODO: Define external resources urls if needed
57
- BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
58
-
59
 
60
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
61
  class patch_series(evaluate.Metric):
62
- """TODO: Short description of my evaluation module."""
 
 
 
63
 
64
  def _info(self):
65
  # TODO: Specifies the evaluate.EvaluationModuleInfo object
@@ -70,26 +75,98 @@ class patch_series(evaluate.Metric):
70
  citation=_CITATION,
71
  inputs_description=_KWARGS_DESCRIPTION,
72
  # This defines the format of each prediction and reference
73
- features=datasets.Features({
74
- 'predictions': datasets.Value('int64'),
75
- 'references': datasets.Value('int64'),
76
- }),
 
 
77
  # Homepage of the module for documentation
78
  homepage="http://module.homepage",
79
  # Additional links to the codebase or references
80
  codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
81
- reference_urls=["http://path.to.reference.url/new_module"]
82
  )
83
 
84
- def _download_and_prepare(self, dl_manager):
85
- """Optional: download external resources useful to compute the scores"""
86
- # TODO: Download external resources if needed
87
- pass
88
-
89
- def _compute(self, predictions, references):
90
- """Returns the scores"""
91
- # TODO: Compute the different scores of the module
92
- accuracy = sum(i == j for i, j in zip(predictions, references)) / len(predictions)
93
- return {
94
- "accuracy": accuracy,
95
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  # limitations under the License.
14
  """TODO: Add a description here."""
15
 
16
+ import logging
17
+ from typing import List, Optional, Union
18
+
19
  import datasets
20
+ import evaluate
21
+ import numpy as np
22
 
23
+ logger = logging.getLogger(__name__)
24
 
25
  # TODO: Add BibTeX citation
26
  _CITATION = """\
 
58
  {'accuracy': 1.0}
59
  """
60
 
 
 
 
61
 
62
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
63
  class patch_series(evaluate.Metric):
64
+
65
+ def __init__(self, *args, **kwargs):
66
+ super().__init__(*args, **kwargs)
67
+ self.matching_series_metric = evaluate.load("bowdbeg/matching_series")
68
 
69
  def _info(self):
70
  # TODO: Specifies the evaluate.EvaluationModuleInfo object
 
75
  citation=_CITATION,
76
  inputs_description=_KWARGS_DESCRIPTION,
77
  # This defines the format of each prediction and reference
78
+ features=datasets.Features(
79
+ {
80
+ "predictions": datasets.Value("int64"),
81
+ "references": datasets.Value("int64"),
82
+ }
83
+ ),
84
  # Homepage of the module for documentation
85
  homepage="http://module.homepage",
86
  # Additional links to the codebase or references
87
  codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
88
+ reference_urls=["http://path.to.reference.url/new_module"],
89
  )
90
 
91
+ def compute(self, *, predictions=None, references=None, **kwargs) -> Optional[dict]:
92
+ """"""
93
+ all_kwargs = {"predictions": predictions, "references": references, **kwargs}
94
+ if predictions is None and references is None:
95
+ missing_kwargs = {k: None for k in self._feature_names() if k not in all_kwargs}
96
+ all_kwargs.update(missing_kwargs)
97
+ else:
98
+ missing_inputs = [k for k in self._feature_names() if k not in all_kwargs]
99
+ if missing_inputs:
100
+ raise ValueError(
101
+ f"Evaluation module inputs are missing: {missing_inputs}. All required inputs are {list(self._feature_names())}"
102
+ )
103
+ inputs = {input_name: all_kwargs[input_name] for input_name in self._feature_names()}
104
+ compute_kwargs = {k: kwargs[k] for k in kwargs if k not in self._feature_names()}
105
+ return self._compute(**inputs, **compute_kwargs)
106
+
107
+ def _compute(
108
+ self,
109
+ predictions: Union[List, np.ndarray],
110
+ references: Union[List, np.ndarray],
111
+ patch_length: List[int] = [1],
112
+ strides: Union[List[int], None] = None,
113
+ **kwargs,
114
+ ):
115
+ """Compute the evaluation score for bowdbeg/matching_series for each patch and take mean."""
116
+ if strides is None:
117
+ strides = patch_length
118
+ assert len(patch_length) == len(strides), "The patch_length and strides should have the same length."
119
+ predictions = np.array(predictions)
120
+ references = np.array(references)
121
+ if not all(predictions.shape[1] % p == 0 for p in patch_length) and not all(
122
+ references.shape[1] % p == 0 for p in patch_length
123
+ ):
124
+ raise ValueError("The patch_length should divide the length of the predictions and references.")
125
+ if len(predictions.shape) != 3:
126
+ raise ValueError("Predictions should have shape (batch_size, sequence_length, num_features)")
127
+ if len(patch_length) == 0:
128
+ raise ValueError("The patch_length should be a list of integers.")
129
+ res_sum: Union[None, dict] = None
130
+ orig_pred_shape = predictions.shape
131
+ orig_ref_shape = references.shape
132
+ for patch, stride in zip(patch_length, strides):
133
+ # create patched predictions and references
134
+ patched_predictions = self.get_patches(predictions, patch, stride, axis=1)
135
+ patched_references = self.get_patches(references, patch, stride, axis=1)
136
+ patched_predictions = patched_predictions.reshape(-1, patch, orig_pred_shape[2])
137
+ patched_references = patched_references.reshape(-1, patch, orig_ref_shape[2])
138
+
139
+ # compute the score for each patch
140
+ res = self.matching_series_metric.compute(
141
+ predictions=patched_predictions, references=patched_references, **kwargs
142
+ )
143
+ # sum the results
144
+ if res_sum is None:
145
+ res_sum = res
146
+ else:
147
+ assert isinstance(res_sum, dict)
148
+ assert isinstance(res, dict)
149
+ for key in res_sum:
150
+ if isinstance(res_sum[key], (list, np.ndarray)):
151
+ res_sum[key] = np.array(res_sum[key]) + np.array(res[key])
152
+ elif isinstance(res_sum[key], (float, int)):
153
+ res_sum[key] += res[key]
154
+ else:
155
+ logger.warning(f"Unsupported type for key {key}: {type(res_sum[key])}")
156
+ del res_sum[key]
157
+ # take the mean of the results
158
+ assert isinstance(res_sum, dict)
159
+ for key in res_sum:
160
+ if isinstance(res_sum[key], (list, np.ndarray)):
161
+ res_sum[key] = np.array(res_sum[key]) / len(patch_length)
162
+ else:
163
+ res_sum[key] /= len(patch_length)
164
+
165
+ return res_sum
166
+
167
+ @staticmethod
168
+ def get_patches(series: np.ndarray, patch_length: int, stride: int, axis=0):
169
+ # create patched predictions and references
170
+ o = np.lib.stride_tricks.sliding_window_view(series, window_shape=patch_length, axis=axis)
171
+ o = o[::stride]
172
+ return o