Spaces:
Sleeping
Sleeping
| import unittest | |
| import torch | |
| from src.data.utils import ( | |
| S2_BANDS, | |
| SPACE_TIME_BANDS, | |
| SPACE_TIME_BANDS_GROUPS_IDX, | |
| construct_galileo_input, | |
| ) | |
class TestDataUtils(unittest.TestCase):
    """Tests for helpers in ``src.data.utils``."""

    def test_construct_galileo_input_s2(self):
        """An S2-only input unmasks exactly the S2 space-time channel groups.

        ``construct_galileo_input`` is called with only ``s2`` supplied, both
        with and without normalization. In the returned masks a value of 1 is
        treated as "masked" and 0 as "unmasked" (as asserted below):

        * the space, time and static masks stay fully masked;
        * every non-S2 space-time channel group stays masked;
        * every S2 space-time channel group is unmasked;
        * ``space_time_x`` keeps the (t, h, w) leading layout of the input;
        * without normalization, the raw S2 values are copied unchanged into
          the S2 band positions of ``space_time_x``.
        """
        t, h, w = 2, 4, 4
        # Seeded generator so that any failure is reproducible across runs.
        generator = torch.Generator().manual_seed(0)
        s2 = torch.randn((t, h, w, len(S2_BANDS)), generator=generator)

        # The index lists below depend only on module-level constants, so they
        # are computed once rather than on every loop iteration.
        # Channel-group positions (last axis of the space-time mask).
        s2_group_idx = [
            idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S2" in key
        ]
        non_s2_group_idx = [
            idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S2" not in key
        ]
        # Band positions (last axis of space_time_x) that should hold S2 data.
        s2_band_idx = [
            idx for idx, val in enumerate(SPACE_TIME_BANDS) if val in S2_BANDS
        ]
        # Guard against a vacuous pass: `.all()` over an empty selection is True.
        self.assertTrue(s2_group_idx, "no S2 groups found in SPACE_TIME_BANDS_GROUPS_IDX")

        for normalize in (True, False):
            # subTest reports which variant failed instead of stopping at the first.
            with self.subTest(normalize=normalize):
                masked_output = construct_galileo_input(s2=s2, normalize=normalize)

                # Only space-time (S2) data was supplied, so these stay fully masked.
                self.assertTrue((masked_output.space_mask == 1).all())
                self.assertTrue((masked_output.time_mask == 1).all())
                self.assertTrue((masked_output.static_mask == 1).all())

                # Non-S2 groups received no data and must remain masked ...
                self.assertTrue(
                    (masked_output.space_time_mask[:, :, :, non_s2_group_idx] == 1).all()
                )
                # ... while every S2 group must be unmasked.
                self.assertTrue(
                    (masked_output.space_time_mask[:, :, :, s2_group_idx] == 0).all()
                )

                # The output keeps the (t, h, w) leading layout of the input.
                self.assertEqual(
                    tuple(masked_output.space_time_x.shape[:3]), (t, h, w)
                )

                if not normalize:
                    # Un-normalized values must land, unchanged, at the S2 band
                    # positions (assumes SPACE_TIME_BANDS lists S2 bands in
                    # S2_BANDS order, as the original test did).
                    self.assertTrue(
                        torch.equal(masked_output.space_time_x[:, :, :, s2_band_idx], s2)
                    )