alessandro trinca tornidor committed on
Commit
c1390d7
1 Parent(s): a6fcc1e

test: update tests

Browse files
aip_trainer/lambdas/lambdaTTS.py CHANGED
@@ -5,7 +5,7 @@ from pathlib import Path
5
  from aip_trainer import app_logger
6
 
7
 
8
- def get_tts(text: str, language: str):
9
  from aip_trainer.models import models
10
 
11
  if text is None or len(text) == 0:
@@ -24,7 +24,7 @@ def get_tts(text: str, language: str):
24
  )
25
  app_logger.info(f"model speaker #0: {speaker} ...")
26
 
27
- with tempfile.NamedTemporaryFile(prefix="audio_", suffix=".wav", delete=False) as tmp_audio_file:
28
  app_logger.info(f"tmp_audio_file output: {tmp_audio_file.name} ...")
29
  audio_paths = model.save_wav(text=text, speaker=speaker, sample_rate=sample_rate, audio_path=str(tmp_audio_file.name))
30
  app_logger.info(f"audio_paths output: {audio_paths} ...")
 
5
  from aip_trainer import app_logger
6
 
7
 
8
+ def get_tts(text: str, language: str, tmp_prefix="audio_", tmp_suffix=".wav") -> str:
9
  from aip_trainer.models import models
10
 
11
  if text is None or len(text) == 0:
 
24
  )
25
  app_logger.info(f"model speaker #0: {speaker} ...")
26
 
27
+ with tempfile.NamedTemporaryFile(prefix=tmp_prefix, suffix=tmp_suffix, delete=False) as tmp_audio_file:
28
  app_logger.info(f"tmp_audio_file output: {tmp_audio_file.name} ...")
29
  audio_paths = model.save_wav(text=text, speaker=speaker, sample_rate=sample_rate, audio_path=str(tmp_audio_file.name))
30
  app_logger.info(f"audio_paths output: {audio_paths} ...")
aip_trainer/utils/serialize.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Serialize objects"""
2
+ from typing import Mapping
3
+
4
+ from aip_trainer import app_logger
5
+
6
+
7
def serialize(obj: object, include_none: bool = False):
    """
    Convert an arbitrary object into a JSON-serializable equivalent.

    Args:
        obj: object to serialize (primitives, lists, tuples, bytes, numpy
            scalars/arrays, exceptions, dicts and generic objects)
        include_none: whether to keep keys whose value is None when
            serializing dict-like objects

    Returns:
        a serializable representation of ``obj``
    """
    # NOTE: annotated as ``object`` — the previous ``any`` referred to the
    # builtin function, not a type.
    return _serialize(obj, include_none)
19
+
20
+
21
def _serialize(obj: object, include_none: bool):
    """Dispatch on the runtime type of ``obj`` and return a serializable value.

    Best-effort: never raises — on failure it logs an error and returns a
    descriptive string instead.
    """
    from numpy import ndarray as np_ndarray, floating as np_floating, integer as np_integer

    # bool is a subclass of int, so it is covered by this tuple as well
    primitive = (int, float, str, bool)
    try:
        if obj is None:
            return None
        elif isinstance(obj, np_integer):
            return int(obj)
        elif isinstance(obj, np_floating):
            return float(obj)
        elif isinstance(obj, np_ndarray):
            return obj.tolist()
        elif isinstance(obj, primitive):
            return obj
        # exact-type checks on purpose: subclasses of list/tuple/bytes may
        # carry extra state and fall through to the generic object branch
        elif type(obj) is list:
            return _serialize_list(obj, include_none)
        elif type(obj) is tuple:
            return list(obj)
        elif type(obj) is bytes:
            return _serialize_bytes(obj)
        elif isinstance(obj, Exception):
            return _serialize_exception(obj)
        else:
            return _serialize_object(obj, include_none)
    except Exception as e_serialize:
        app_logger.error(f"e_serialize::{e_serialize}, type_obj:{type(obj)}, obj:{obj}.")
        return f"object_name:{str(obj)}__object_type_str:{str(type(obj))}."
52
+
53
+
54
def _serialize_object(obj, include_none: bool) -> dict:
    """Serialize a dict, or a generic object's instance attributes, into a plain dict.

    Keys whose value is None are dropped unless ``include_none`` is True.
    For dict inputs, values that are bson.ObjectId are skipped entirely
    (they are not JSON-serializable).
    """
    from bson import ObjectId

    res = {}
    if type(obj) is not dict:
        # generic object: walk its instance attributes
        keys = [i for i in obj.__dict__.keys() if (getattr(obj, i) is not None) or include_none]
    else:
        keys = [i for i in obj.keys() if (obj[i] is not None) or include_none]
    for key in keys:
        if type(obj) is not dict:
            res[key] = _serialize(getattr(obj, key), include_none)
        elif isinstance(obj[key], ObjectId):
            # drop MongoDB ObjectId values from dicts
            continue
        else:
            res[key] = _serialize(obj[key], include_none)
    return res
70
+
71
+
72
def _serialize_list(ls: list, include_none: bool) -> list:
    """Serialize every element of ``ls``, preserving order."""
    return [_serialize(element, include_none) for element in ls]
74
+
75
+
76
+ def _serialize_bytes(b: bytes) -> dict[str, str]:
77
+ import base64
78
+ encoded = base64.b64encode(b)
79
+ return {"value": encoded.decode('ascii'), "type": "bytes"}
80
+
81
+
82
+ def _serialize_exception(e: Exception) -> dict[str, str]:
83
+ return {"msg": str(e), "type": str(type(e)), **e.__dict__}
aip_trainer/utils/utilities.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Various utilities (logger, time benchmark, args dump, numerical and stats info)"""
2
+
3
+ from copy import deepcopy
4
+ from aip_trainer import app_logger
5
+ from aip_trainer.utils.serialize import serialize
6
+
7
+
8
def hash_calculate(arr_or_path, is_file: bool, read_mode: str = "rb") -> bytes:
    """
    Return the base64-encoded sha256 hash of the input variable.

    Args:
        arr_or_path: value to hash (numpy array, dict, str, bytes) or, when
            ``is_file`` is True, the path of the file whose content is hashed
        is_file: if True, read the file at ``arr_or_path`` and hash its content
        read_mode: mode used to open the file (only relevant when ``is_file``)

    Returns:
        base64-encoded sha256 digest, as bytes

    Raises:
        ValueError: if the input type is not handled
    """
    from hashlib import sha256
    from base64 import b64encode
    from numpy import ndarray as np_ndarray

    if is_file:
        with open(arr_or_path, read_mode) as file_to_check:
            # hash the raw file content (str if a text read_mode is used,
            # bytes for the default "rb"; both are handled below)
            arr_or_path = file_to_check.read()

    if isinstance(arr_or_path, np_ndarray):
        # hash the array's underlying buffer
        hash_fn = sha256(arr_or_path.data)
    elif isinstance(arr_or_path, dict):
        import json

        # deterministic dict hashing: serialize, then dump with sorted keys
        serialized = serialize(arr_or_path)
        variable_to_hash = json.dumps(serialized, sort_keys=True).encode("utf-8")
        hash_fn = sha256(variable_to_hash)
    elif isinstance(arr_or_path, str):
        # sha256 requires bytes-like input: encode text directly instead of
        # relying on a guaranteed-to-fail try/except TypeError round-trip
        hash_fn = sha256(arr_or_path.encode("utf-8"))
    elif isinstance(arr_or_path, bytes):
        hash_fn = sha256(arr_or_path)
    else:
        raise ValueError(
            f"variable 'arr':{arr_or_path} of type '{type(arr_or_path)}' not yet handled."
        )
    return b64encode(hash_fn.digest())
tests/test_lambdaTTS.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import unittest
3
+
4
+ from aip_trainer import app_logger
5
+
6
+
7
def exec_test_lambda_tts(text, language, expected_hash):
    """Run get_tts end-to-end and verify the produced wav file's hash."""
    import random
    from aip_trainer.lambdas import lambdaTTS
    from aip_trainer.utils import utilities

    # randomize the temp file prefix so parallel runs never collide
    tmp_rnd = str(random.random())
    tmp_prefix = f"test_lambdaTTS_{language}_ok_{tmp_rnd}_"
    tmp_suffix = ".wav"
    output = lambdaTTS.get_tts(
        text, language, tmp_prefix=tmp_prefix, tmp_suffix=tmp_suffix
    )
    # the returned path must carry the requested prefix/suffix and be a real file
    assert tmp_prefix in output
    assert tmp_suffix in output
    assert os.path.exists(output) and os.path.isfile(output)
    output_hash = utilities.hash_calculate(output, is_file=True, read_mode="rb")
    app_logger.info(f"output_hash '{text}', '{language}' => {output_hash}")
    assert expected_hash == output_hash
    # clean up the generated audio file
    os.unlink(output)
25
+
26
+
27
def assert_raises_get_tts(self, real_text, language, exc, error_message):
    """Assert that get_tts raises ``exc`` carrying exactly ``error_message``."""
    from aip_trainer.lambdas import lambdaTTS

    with self.assertRaises(exc):
        try:
            lambdaTTS.get_tts(real_text, language)
        except exc as raised:
            # check the message first, then re-raise for assertRaises
            self.assertEqual(str(raised), error_message)
            raise raised
38
+
39
+
40
class TestLambdaTTS(unittest.TestCase):
    """End-to-end tests for lambdaTTS.get_tts (happy paths and error paths)."""

    def test_lambdaTTS_en_ok(self):
        expected = b'vf1QWORrGWlvCEQuvI0fajGG7iz2Zqp83p8dVH8pZtY='
        exec_test_lambda_tts("Hi there, how are you?", "en", expected)

    def test_lambdaTTS_de_ok(self):
        expected = b'jkvM+0Whlb1nwf9eiuDyoQZ1ekzb46k3DcW2glYh0YY='
        exec_test_lambda_tts("Ich bin Alex!", "de", expected)

    def test_lambdaTTS_empty_text(self):
        assert_raises_get_tts(
            self, "", "fake language", ValueError, "cannot read an empty/None text: ''..."
        )

    def test_lambdaTTS_empty_language(self):
        assert_raises_get_tts(
            self, "fake text", "", NotImplementedError, "Not tested/supported with '' language..."
        )


if __name__ == "__main__":
    unittest.main()
tests/test_serialize.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+
3
+ import numpy as np
4
+
5
+ from aip_trainer.utils.serialize import serialize
6
+
7
+
8
# GeoJSON-like fixture: nested dict -> list -> dict, including a None-valued
# property and a numpy array, to exercise all the serializer branches.
test_dict_list_dict = {
    "type": "FeatureCollection",
    "name": "volcanoes",
    "crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}},
    "features": [
        {"type": "Feature", "properties": {"Volcano_Number": 283010, "Volcano_Name": "Izu-Tobu", "prop_none": None},
         "geometry": {"type": "Point", "coordinates": [139.098, 34.9]}},
        {"type": "Feature",
         "properties": {"Volcano_Number": 283020, "Volcano_Name": "Hakoneyama", "ndarray": np.array([1])},
         "geometry": {"type": "Point", "coordinates": [139.021, 35.233]}}
    ]
}
20
+
21
+
22
class TestSerialize(unittest.TestCase):
    """Tests for aip_trainer.utils.serialize.serialize."""

    def test_serialize(self):
        from bson import ObjectId

        # keys whose values are bson.ObjectId instances are dropped
        d1 = {"_id": ObjectId()}
        self.assertDictEqual(serialize(d1), dict())

        # numpy scalars become plain python numbers; tuples become lists.
        # NOTE: use the concrete np.int64 dtype — the abstract np.integer is
        # deprecated as a dtype (NumPy 1.20) and rejected by newer releases.
        np_int_4 = np.asarray([87], dtype=np.int64)[0]
        d2 = {"b": np.float32(45.0), 3: 33, 1.56: np_int_4, 3.5: 44.0, "d": "b", "tuple": (1, 2)}
        expected_d2 = {
            'b': 45.0,
            3: 33,
            1.56: 87,
            3.5: 44.0,
            'd': 'b',
            "tuple": [1, 2]
        }
        serialized_d2 = serialize(d2)
        self.assertDictEqual(serialized_d2, expected_d2)

        # nested dict of list of dict; ndarray serialized via tolist()
        d3 = {"e": [{"q": 123}, {"q": 456}], "a": np.arange(1.1, 16.88).reshape(4, 4)}
        expected_d3 = {
            "e": [{"q": 123}, {"q": 456}],
            'a': [[1.1, 2.1, 3.1, 4.1], [5.1, 6.1, 7.1, 8.1], [9.1, 10.1, 11.1, 12.1], [13.1, 14.1, 15.1, 16.1]]
        }
        self.assertDictEqual(serialize(d3), expected_d3)

    def test_serialize_dict_exception(self):
        from json import JSONDecodeError

        # exceptions serialize to msg + type + their instance attributes
        e = JSONDecodeError(msg="x", doc="what we are?", pos=111)
        exception = serialize({"k": e})
        self.assertDictEqual(
            exception,
            {'k': {'msg': 'x', 'type': "<class 'json.decoder.JSONDecodeError'>", 'doc': 'what we are?', 'pos': 111,
                   'lineno': 1, 'colno': 112}}
        )

    def test_serialize_bytes(self):
        # bytes serialize to a base64 value wrapped in a type-tagged dict
        self.assertDictEqual(
            serialize({"k": b"x"}),
            {'k': {'value': 'eA==', 'type': 'bytes'}}
        )

    def test_serialize_dict_list_dict(self):
        # include_none=False drops the 'prop_none' key
        serialized_dict_no_none = serialize(test_dict_list_dict, include_none=False)
        self.assertDictEqual(serialized_dict_no_none, {
            'type': 'FeatureCollection',
            'name': 'volcanoes',
            'crs': {'type': 'name', 'properties': {'name': 'urn:ogc:def:crs:OGC:1.3:CRS84'}},
            'features': [
                {'type': 'Feature', 'properties': {'Volcano_Number': 283010, 'Volcano_Name': 'Izu-Tobu'},
                 'geometry': {'type': 'Point', 'coordinates': [139.098, 34.9]}},
                {'type': 'Feature',
                 'properties': {'Volcano_Number': 283020, 'Volcano_Name': 'Hakoneyama', 'ndarray': [1]},
                 'geometry': {'type': 'Point', 'coordinates': [139.021, 35.233]}}
            ]
        })

        # include_none=True keeps 'prop_none': None
        serialized_dict_with_none = serialize(test_dict_list_dict, include_none=True)
        self.assertDictEqual(serialized_dict_with_none, {
            'type': 'FeatureCollection',
            'name': 'volcanoes',
            'crs': {'type': 'name', 'properties': {'name': 'urn:ogc:def:crs:OGC:1.3:CRS84'}},
            'features': [
                {'type': 'Feature',
                 'properties': {'Volcano_Number': 283010, 'Volcano_Name': 'Izu-Tobu', 'prop_none': None},
                 'geometry': {'type': 'Point', 'coordinates': [139.098, 34.9]}},
                {'type': 'Feature',
                 'properties': {'Volcano_Number': 283020, 'Volcano_Name': 'Hakoneyama', 'ndarray': [1]},
                 'geometry': {'type': 'Point', 'coordinates': [139.021, 35.233]}}
            ]
        })
tests/test_utilities.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import numpy as np
3
+
4
+ from tests import EVENTS_FOLDER
5
+ from aip_trainer import app_logger
6
+
7
+
8
class TestUtilities(unittest.TestCase):
    """Tests for aip_trainer.utils.utilities.hash_calculate."""

    def test_hash_calculate_not_file(self):
        from aip_trainer.utils.utilities import hash_calculate

        size = 5
        input_arr = np.arange(size**2).reshape((size, size))
        # ndarray: hash of the underlying buffer
        self.assertEqual(
            hash_calculate(input_arr, is_file=False),
            b'KgoWp86FwhH2tuinWOfsCfn9d+Iw6B10wwqFfdUeLeY=',
        )
        # dict containing an ndarray
        self.assertEqual(
            hash_calculate({"arr": input_arr}, is_file=False),
            b'M/EYsBPRQLVP9T299xHyOrtT7bdCkIDaMmW2hslMays=',
        )
        # plain strings
        self.assertEqual(
            hash_calculate("a test string...", is_file=False),
            b'29a8JwujQklQ6MKQhPyix6G1i/7Pp0uUg5wFybKuCC0=',
        )
        self.assertEqual(
            hash_calculate("123123123", is_file=False),
            b'ky88G1YlfOhTmsJp16q0JVDaz4gY0HXwvfGZBWKq4+8=',
        )
        # raw bytes
        self.assertEqual(
            hash_calculate(b"a byte test string...", is_file=False),
            b'dgSt/jiqLk0HJ09Xqe/BWzMvnYiOqzWlcSCCfN767zA=',
        )

        # unhandled types raise ValueError with a descriptive message
        with self.assertRaises(ValueError):
            try:
                hash_calculate(1, is_file=False)
            except ValueError as ve:
                self.assertEqual(str(ve), "variable 'arr':1 of type '<class 'int'>' not yet handled.")
                raise ve

    def test_hash_calculate_is_file(self):
        from aip_trainer.utils.utilities import hash_calculate

        output_hash = hash_calculate(EVENTS_FOLDER / "test_en.wav", is_file=True, read_mode="rb")
        app_logger.info(f"output_hash test_en: {output_hash}")
        self.assertEqual(output_hash, b'Dsvmm+mj/opHnmKLT7wIqyhqMLeIuVP4hTWi+DAXS8Y=')

        output_hash = hash_calculate(EVENTS_FOLDER / "GetAccuracyFromRecordedAudio.json", is_file=True)
        app_logger.info(f"output_hash json: {output_hash}")
        self.assertEqual(output_hash, b'i83jKpwzfcPitZsrHsnhyFt8xbc+DStpns9rb3vfigw=')