| | |
| |
|
| | import unittest |
| |
|
| | import mlx.core as mx |
| | import mlx.nn as nn |
| | import mlx_tests |
| | import numpy as np |
| |
|
| | try: |
| | import torch |
| | import torch.nn.functional as F |
| |
|
| | has_torch = True |
| | except ImportError as e: |
| | has_torch = False |
| |
|
| |
|
| | class TestUpsample(mlx_tests.MLXTestCase): |
| | @unittest.skipIf(not has_torch, "requires Torch") |
| | def test_torch_upsample(self): |
| | def run_upsample( |
| | N, |
| | C, |
| | idim, |
| | scale_factor, |
| | mode, |
| | align_corner, |
| | dtype="float32", |
| | atol=1e-5, |
| | ): |
| | with self.subTest( |
| | N=N, |
| | C=C, |
| | idim=idim, |
| | scale_factor=scale_factor, |
| | mode=mode, |
| | align_corner=align_corner, |
| | ): |
| | np_dtype = getattr(np, dtype) |
| | np.random.seed(0) |
| | iH, iW = idim |
| | in_np = np.random.normal(-1.0, 1.0, (N, iH, iW, C)).astype(np_dtype) |
| |
|
| | in_mx = mx.array(in_np) |
| | in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)).to("cpu") |
| |
|
| | out_mx = nn.Upsample( |
| | scale_factor=scale_factor, |
| | mode=mode, |
| | align_corners=align_corner, |
| | )(in_mx) |
| | mode_pt = { |
| | "nearest": "nearest", |
| | "linear": "bilinear", |
| | "cubic": "bicubic", |
| | }[mode] |
| | out_pt = F.interpolate( |
| | in_pt, |
| | scale_factor=scale_factor, |
| | mode=mode_pt, |
| | align_corners=align_corner if mode != "nearest" else None, |
| | ) |
| | out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True) |
| | self.assertEqual(out_pt.shape, out_mx.shape) |
| | self.assertTrue(np.allclose(out_pt, out_mx, atol=atol)) |
| |
|
| | for dtype in ("float32",): |
| | for N, C in ((1, 1), (2, 3)): |
| | |
| | |
| | |
| | for idim, scale_factor in ( |
| | ((2, 2), (1.0, 1.0)), |
| | ((2, 2), (1.5, 1.5)), |
| | ((2, 2), (2.0, 2.0)), |
| | ((4, 4), (0.5, 0.5)), |
| | ((7, 7), (2.0, 2.0)), |
| | ((10, 10), (0.2, 0.2)), |
| | ((10, 10), (0.3, 0.3)), |
| | ((11, 21), (3.0, 3.0)), |
| | ((11, 21), (3.0, 2.0)), |
| | ): |
| | for mode in ("cubic", "linear", "nearest"): |
| | for align_corner in (False, True): |
| | if mode == "nearest" and align_corner: |
| | continue |
| | run_upsample( |
| | N, |
| | C, |
| | idim, |
| | scale_factor, |
| | mode, |
| | align_corner, |
| | dtype=dtype, |
| | ) |
| |
|
| |
|
| | if __name__ == "__main__": |
| | mlx_tests.MLXTestRunner() |
| |
|