# Copyright 2022 The T5X Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for state_utils.""" import re from absl.testing import absltest from absl.testing import parameterized import numpy as np from t5x import state_utils class StateUtilsTest(parameterized.TestCase): @parameterized.parameters( dict( state_dict={"a": { "b": 2, "c": 3 }}, intersect_state_dict={ "a": { "b": 4 }, "d": 5 }, expect_state={"a": { "b": 2 }})) def test_intersect_state(self, state_dict, intersect_state_dict, expect_state): actual_state = state_utils.intersect_state(state_dict, intersect_state_dict) self.assertEqual(actual_state, expect_state) @parameterized.parameters( dict( state_dict={"a": { "b": 2, "c": 3 }}, merge_state_dict={ "a": { "b": 4 }, "d": 5 }, expect_state={ "a": { "b": 2, "c": 3 }, "d": 5 })) def test_merge_state(self, state_dict, merge_state_dict, expect_state): actual_state = state_utils.merge_state(state_dict, merge_state_dict) self.assertEqual(actual_state, expect_state) def test_tensorstore_leaf(self): leaf = { "driver": "zarr", "kvstore": { "driver": "gfile", "path": "target.bias" }, "metadata": { "chunks": [4, 1], "compressor": { "id": "gzip", "level": 1 }, "dtype": "