diff --git a/tests/ccl/test_all_gather.py b/tests/ccl/test_all_gather.py index ae649043..47c5b4f4 100644 --- a/tests/ccl/test_all_gather.py +++ b/tests/ccl/test_all_gather.py @@ -17,14 +17,12 @@ [ torch.float16, torch.float32, - torch.bfloat16, ], ) @pytest.mark.parametrize( "M, N", [ (128, 64), # Small - (1024, 256), # Medium (8192, 8192), # Large ], ) diff --git a/tests/ccl/test_all_reduce.py b/tests/ccl/test_all_reduce.py index ffd55e9d..13d1bd1b 100644 --- a/tests/ccl/test_all_reduce.py +++ b/tests/ccl/test_all_reduce.py @@ -16,11 +16,7 @@ "variant", [ "atomic", - # "ring", "two_shot", - "one_shot", - # TODO enable these tests when support for cache-modifiers is in place. - # "spinlock", ], ) @pytest.mark.parametrize( @@ -28,14 +24,12 @@ [ torch.float16, torch.float32, - torch.bfloat16, ], ) @pytest.mark.parametrize( "M, N", [ (128, 64), # Small - (1024, 256), # Medium (8192, 8192), # Large ], ) diff --git a/tests/ccl/test_all_to_all.py b/tests/ccl/test_all_to_all.py index 76478f5a..1534dc8b 100644 --- a/tests/ccl/test_all_to_all.py +++ b/tests/ccl/test_all_to_all.py @@ -17,14 +17,12 @@ [ torch.float16, torch.float32, - torch.bfloat16, ], ) @pytest.mark.parametrize( "M, N", [ (128, 64), # Small - (1024, 256), # Medium (8192, 8192), # Large ], ) diff --git a/tests/ccl/test_all_to_all_gluon.py b/tests/ccl/test_all_to_all_gluon.py index 1dc485d4..4109b741 100644 --- a/tests/ccl/test_all_to_all_gluon.py +++ b/tests/ccl/test_all_to_all_gluon.py @@ -26,14 +26,12 @@ [ torch.float16, torch.float32, - torch.bfloat16, ], ) @pytest.mark.parametrize( "M, N", [ (128, 64), # Small - (1024, 256), # Medium (8192, 8192), # Large ], ) diff --git a/tests/unittests/test_empty.py b/tests/unittests/test_empty.py index e51fb4c2..b7c6173f 100644 --- a/tests/unittests/test_empty.py +++ b/tests/unittests/test_empty.py @@ -12,10 +12,8 @@ torch.int8, torch.int16, torch.int32, - torch.int64, torch.float16, torch.float32, - torch.float64, torch.bool, ], ) @@ -23,11 +21,9 @@ "size", [ (1,), - (5,), - (2, 3), - (3, 4, 5), - (1, 1, 1), - (10, 20), + (100,), + (32, 32), + (4, 8, 16), ], ) def test_empty_basic(dtype, size): @@ -169,7 +165,7 @@ def test_empty_size_variations(): def test_empty_edge_cases(): - shmem = iris.iris(1 << 20) + shmem = iris.iris(1 << 24) # Empty tensor empty_result = shmem.empty(0) @@ -183,10 +179,10 @@ def test_empty_edge_cases(): assert single_result.numel() == 1 assert shmem._Iris__on_symmetric_heap(single_result) - # Large tensor - large_result = shmem.empty(100, 100) - assert large_result.shape == (100, 100) - assert large_result.numel() == 10000 + # Large tensor for memory validation + large_result = shmem.empty(1024, 1024) + assert large_result.shape == (1024, 1024) + assert large_result.numel() == 1024 * 1024 assert shmem._Iris__on_symmetric_heap(large_result) # Zero-dimensional tensor (scalar) @@ -195,6 +191,21 @@ def test_empty_edge_cases(): assert scalar_result.numel() == 1 assert shmem._Iris__on_symmetric_heap(scalar_result) + # Edge dtype: int8 + int8_result = shmem.empty(10, 20, dtype=torch.int8) + assert int8_result.dtype == torch.int8 + assert shmem._Iris__on_symmetric_heap(int8_result) + + # Edge dtype: float64 + float64_result = shmem.empty(5, 10, dtype=torch.float64) + assert float64_result.dtype == torch.float64 + assert shmem._Iris__on_symmetric_heap(float64_result) + + # Complex shape for multi-dimensional handling + complex_result = shmem.empty(2, 3, 4, 5) + assert complex_result.shape == (2, 3, 4, 5) + assert shmem._Iris__on_symmetric_heap(complex_result) + def test_empty_pytorch_equivalence(): shmem = iris.iris(1 << 20) diff --git a/tests/unittests/test_full.py b/tests/unittests/test_full.py index a42d4ddb..86b19bbb 100644 --- a/tests/unittests/test_full.py +++ b/tests/unittests/test_full.py @@ -15,20 +15,15 @@ 3.141592, -2.718, 42, - -100, - 0.5, - -0.25, ], ) @pytest.mark.parametrize( "size", [ (1,), - (5,), - (2, 3), - (3, 4, 5), - (1, 1, 1), - (10, 20), + (100,), + (32, 32), + (4, 8, 16), ], ) def test_full_basic(fill_value, size): @@ -194,7 +189,7 @@ def test_full_size_variations(): def test_full_edge_cases(): - shmem = iris.iris(1 << 20) + shmem = iris.iris(1 << 24) # Empty tensor empty_result = shmem.full((0,), 1.0) @@ -209,10 +204,10 @@ def test_full_edge_cases(): assert single_result[0] == 5.0 assert shmem._Iris__on_symmetric_heap(single_result) - # Large tensor - large_result = shmem.full((100, 100), 0.1) - assert large_result.shape == (100, 100) - assert large_result.numel() == 10000 + # Large tensor for memory validation + large_result = shmem.full((1024, 1024), 0.1) + assert large_result.shape == (1024, 1024) + assert large_result.numel() == 1024 * 1024 assert torch.all(large_result == 0.1) assert shmem._Iris__on_symmetric_heap(large_result) @@ -223,6 +218,33 @@ def test_full_edge_cases(): assert torch.allclose(scalar_result, torch.tensor(2.718)) assert shmem._Iris__on_symmetric_heap(scalar_result) + # Edge dtype: int8 + int8_result = shmem.full((10, 20), 42, dtype=torch.int8) + assert int8_result.dtype == torch.int8 + assert torch.all(int8_result == 42) + assert shmem._Iris__on_symmetric_heap(int8_result) + + # Edge dtype: float64 + float64_result = shmem.full((5, 10), -2.718, dtype=torch.float64) + assert float64_result.dtype == torch.float64 + assert torch.allclose(float64_result, torch.tensor(-2.718, dtype=torch.float64)) + assert shmem._Iris__on_symmetric_heap(float64_result) + + # Complex shape for multi-dimensional handling + complex_result = shmem.full((2, 3, 4, 5), 0.5) + assert complex_result.shape == (2, 3, 4, 5) + assert torch.all(complex_result == 0.5) + assert shmem._Iris__on_symmetric_heap(complex_result) + + # Additional fill values + fill_values_result = shmem.full((5, 5), -100) + assert torch.all(fill_values_result == -100) + assert shmem._Iris__on_symmetric_heap(fill_values_result) + + fill_values_result2 = shmem.full((5, 5), -0.25) + assert torch.allclose(fill_values_result2, torch.tensor(-0.25)) + assert shmem._Iris__on_symmetric_heap(fill_values_result2) + def test_full_pytorch_equivalence(): shmem = iris.iris(1 << 20) diff --git a/tests/unittests/test_ones.py b/tests/unittests/test_ones.py index e70c63f8..e98b5e8f 100644 --- a/tests/unittests/test_ones.py +++ b/tests/unittests/test_ones.py @@ -12,10 +12,8 @@ torch.int8, torch.int16, torch.int32, - torch.int64, torch.float16, torch.float32, - torch.float64, torch.bool, ], ) @@ -23,11 +21,9 @@ "size", [ (1,), - (5,), - (2, 3), - (3, 4, 5), - (1, 1, 1), - (10, 20), + (100,), + (32, 32), + (4, 8, 16), ], ) def test_ones_basic(dtype, size): @@ -183,7 +179,7 @@ def test_ones_size_variations(): def test_ones_edge_cases(): - shmem = iris.iris(1 << 20) + shmem = iris.iris(1 << 24) # Empty tensor empty_result = shmem.ones(0) @@ -198,10 +194,10 @@ def test_ones_edge_cases(): assert single_result[0] == 1 assert shmem._Iris__on_symmetric_heap(single_result) - # Large tensor - large_result = shmem.ones(100, 100) - assert large_result.shape == (100, 100) - assert large_result.numel() == 10000 + # Large tensor for memory validation + large_result = shmem.ones(1024, 1024) + assert large_result.shape == (1024, 1024) + assert large_result.numel() == 1024 * 1024 assert torch.all(large_result == 1) assert shmem._Iris__on_symmetric_heap(large_result) @@ -212,6 +208,24 @@ def test_ones_edge_cases(): assert scalar_result.item() == 1 assert shmem._Iris__on_symmetric_heap(scalar_result) + # Edge dtype: int8 + int8_result = shmem.ones(10, 20, dtype=torch.int8) + assert int8_result.dtype == torch.int8 + assert torch.all(int8_result == 1) + assert shmem._Iris__on_symmetric_heap(int8_result) + + # Edge dtype: float64 + float64_result = shmem.ones(5, 10, dtype=torch.float64) + assert float64_result.dtype == torch.float64 + assert torch.all(float64_result == 1) + assert shmem._Iris__on_symmetric_heap(float64_result) + + # Complex shape for multi-dimensional handling + complex_result = shmem.ones(2, 3, 4, 5) + assert complex_result.shape == (2, 3, 4, 5) + assert torch.all(complex_result == 1) + assert shmem._Iris__on_symmetric_heap(complex_result) + def test_ones_pytorch_equivalence(): shmem = iris.iris(1 << 20) diff --git a/tests/unittests/test_randint.py b/tests/unittests/test_randint.py index a636be38..2cce569f 100644 --- a/tests/unittests/test_randint.py +++ b/tests/unittests/test_randint.py @@ -12,7 +12,6 @@ torch.int8, torch.int16, torch.int32, - torch.int64, torch.uint8, ], ) @@ -20,11 +19,9 @@ "size", [ (1,), - (5,), - (2, 3), - (3, 4, 5), - (1, 1, 1), - (10, 20), + (100,), + (32, 32), + (4, 8, 16), ], ) def test_randint_basic(dtype, size): @@ -177,7 +174,7 @@ def test_randint_size_variations(): def test_randint_edge_cases(): - shmem = iris.iris(1 << 20) + shmem = iris.iris(1 << 24) # Empty tensor empty_result = shmem.randint(0, 5, (0,)) @@ -193,10 +190,10 @@ def test_randint_edge_cases(): assert torch.all(single_result < 10) assert shmem._Iris__on_symmetric_heap(single_result) - # Large tensor - large_result = shmem.randint(0, 100, (100, 100)) - assert large_result.shape == (100, 100) - assert large_result.numel() == 10000 + # Large tensor for memory validation + large_result = shmem.randint(0, 100, (1024, 1024)) + assert large_result.shape == (1024, 1024) + assert large_result.numel() == 1024 * 1024 assert torch.all(large_result >= 0) assert torch.all(large_result < 100) assert shmem._Iris__on_symmetric_heap(large_result) @@ -209,6 +206,20 @@ def test_randint_edge_cases(): assert torch.all(scalar_result < 10) assert shmem._Iris__on_symmetric_heap(scalar_result) + # Edge dtype: int16 + int16_result = shmem.randint(0, 10, (10, 20), dtype=torch.int16) + assert int16_result.dtype == torch.int16 + assert torch.all(int16_result >= 0) + assert torch.all(int16_result < 10) + assert shmem._Iris__on_symmetric_heap(int16_result) + + # Complex shape for multi-dimensional handling + complex_result = shmem.randint(0, 10, (2, 3, 4, 5)) + assert complex_result.shape == (2, 3, 4, 5) + assert torch.all(complex_result >= 0) + assert torch.all(complex_result < 10) + assert shmem._Iris__on_symmetric_heap(complex_result) + def test_randint_pytorch_equivalence(): shmem = iris.iris(1 << 20) diff --git a/tests/unittests/test_zeros.py b/tests/unittests/test_zeros.py index 51126fed..83869371 100644 --- a/tests/unittests/test_zeros.py +++ b/tests/unittests/test_zeros.py @@ -12,10 +12,8 @@ torch.int8, torch.int16, torch.int32, - torch.int64, torch.float16, torch.float32, - torch.float64, torch.bool, ], ) @@ -23,11 +21,9 @@ "size", [ (1,), - (5,), - (2, 3), - (3, 4, 5), - (1, 1, 1), - (10, 20), + (100,), + (32, 32), + (4, 8, 16), ], ) def test_zeros_basic(dtype, size): @@ -183,7 +179,7 @@ def test_zeros_size_variations(): def test_zeros_edge_cases(): - shmem = iris.iris(1 << 20) + shmem = iris.iris(1 << 24) # Empty tensor empty_result = shmem.zeros(0) @@ -198,10 +194,10 @@ def test_zeros_edge_cases(): assert single_result[0] == 0 assert shmem._Iris__on_symmetric_heap(single_result) - # Large tensor - large_result = shmem.zeros(100, 100) - assert large_result.shape == (100, 100) - assert large_result.numel() == 10000 + # Large tensor for memory validation + large_result = shmem.zeros(1024, 1024) + assert large_result.shape == (1024, 1024) + assert large_result.numel() == 1024 * 1024 assert torch.all(large_result == 0) assert shmem._Iris__on_symmetric_heap(large_result) @@ -212,6 +208,24 @@ def test_zeros_edge_cases(): assert scalar_result.item() == 0 assert shmem._Iris__on_symmetric_heap(scalar_result) + # Edge dtype: int8 + int8_result = shmem.zeros(10, 20, dtype=torch.int8) + assert int8_result.dtype == torch.int8 + assert torch.all(int8_result == 0) + assert shmem._Iris__on_symmetric_heap(int8_result) + + # Edge dtype: float64 + float64_result = shmem.zeros(5, 10, dtype=torch.float64) + assert float64_result.dtype == torch.float64 + assert torch.all(float64_result == 0) + assert shmem._Iris__on_symmetric_heap(float64_result) + + # Complex shape for multi-dimensional handling + complex_result = shmem.zeros(2, 3, 4, 5) + assert complex_result.shape == (2, 3, 4, 5) + assert torch.all(complex_result == 0) + assert shmem._Iris__on_symmetric_heap(complex_result) + def test_zeros_pytorch_equivalence(): shmem = iris.iris(1 << 20) diff --git a/tests/unittests/test_zeros_like.py b/tests/unittests/test_zeros_like.py index b7a0ff0c..84758dda 100644 --- a/tests/unittests/test_zeros_like.py +++ b/tests/unittests/test_zeros_like.py @@ -12,10 +12,8 @@ torch.int8, torch.int16, torch.int32, - torch.int64, torch.float16, torch.float32, - torch.float64, torch.bool, ], ) @@ -23,11 +21,9 @@ "shape", [ (1,), - (5,), - (2, 3), - (3, 4, 5), - (1, 1, 1), - (10, 20), + (100,), + (32, 32), + (4, 8, 16), ], ) def test_zeros_like_basic(dtype, shape): @@ -322,7 +318,7 @@ def test_zeros_like_pytorch_equivalence(): def test_zeros_like_edge_cases(): - shmem = iris.iris(1 << 20) + shmem = iris.iris(1 << 24) # Empty tensor empty_tensor = shmem.full((0,), 1, dtype=torch.float32) @@ -337,17 +333,38 @@ def test_zeros_like_edge_cases(): assert single_result.numel() == 1 assert single_result[0] == 0 - # Large tensor - large_tensor = shmem.full((100, 100), 10, dtype=torch.float32) + # Large tensor for memory validation + large_tensor = shmem.full((1024, 1024), 10, dtype=torch.float32) large_result = shmem.zeros_like(large_tensor) - assert large_result.shape == (100, 100) - assert large_result.numel() == 10000 + assert large_result.shape == (1024, 1024) + assert large_result.numel() == 1024 * 1024 assert torch.all(large_result == 0) + # Edge dtype: int8 + int8_tensor = shmem.full((10, 20), 5, dtype=torch.int8) + int8_result = shmem.zeros_like(int8_tensor) + assert int8_result.dtype == torch.int8 + assert torch.all(int8_result == 0) + + # Edge dtype: float64 + float64_tensor = shmem.full((5, 10), 3.14, dtype=torch.float64) + float64_result = shmem.zeros_like(float64_tensor) + assert float64_result.dtype == torch.float64 + assert torch.all(float64_result == 0) + + # Complex shape for multi-dimensional handling + complex_tensor = shmem.full((2, 3, 4, 5), 7, dtype=torch.float32) + complex_result = shmem.zeros_like(complex_tensor) + assert complex_result.shape == (2, 3, 4, 5) + assert torch.all(complex_result == 0) + # Verify all edge case results are on symmetric heap assert shmem._Iris__on_symmetric_heap(empty_result) assert shmem._Iris__on_symmetric_heap(single_result) assert shmem._Iris__on_symmetric_heap(large_result) + assert shmem._Iris__on_symmetric_heap(int8_result) + assert shmem._Iris__on_symmetric_heap(float64_result) + assert shmem._Iris__on_symmetric_heap(complex_result) @pytest.mark.parametrize(