|
4 | 4 | from sklearn.datasets import load_wine |
5 | 5 |
|
6 | 6 | from risksyn import AIMGenerator, Risk |
| 7 | +from risksyn.generator import _auto_categorize, _requires_private_preprocessing |
7 | 8 |
|
8 | 9 |
|
9 | 10 | # Simple dataset for fast tests |
@@ -112,6 +113,71 @@ def test_generate_before_fit_raises(): |
112 | 113 | gen.generate(count=10) |
113 | 114 |
|
114 | 115 |
|
| 116 | +def test_auto_categorize_binary_int_columns(): |
| 117 | + """Low-cardinality int columns should be auto-categorized and cast to str.""" |
| 118 | + df = pd.DataFrame({ |
| 119 | + "binary": [0, 1, 0, 1, 1], |
| 120 | + "ternary": [0, 1, 2, 0, 1], |
| 121 | + "continuous": np.random.uniform(0, 100, 5), |
| 122 | + "cat": ["a", "b", "c", "a", "b"], |
| 123 | + }) |
| 124 | + out_data, domain = _auto_categorize(df, None) |
| 125 | + assert domain["binary"] == ["0", "1"] |
| 126 | + assert domain["ternary"] == ["0", "1", "2"] |
| 127 | + assert out_data["binary"].dtype == object |
| 128 | + assert out_data["ternary"].dtype == object |
| 129 | + assert "continuous" not in domain # float, not int |
| 130 | + assert "cat" not in domain # not numeric |
| 131 | + |
| 132 | + |
| 133 | +def test_auto_categorize_respects_existing_domain(): |
| 134 | + """Auto-categorization should not override user-provided domain.""" |
| 135 | + df = pd.DataFrame({"x": [0, 1, 0, 1]}) |
| 136 | + user_domain = {"x": {"lower": 0, "upper": 1}} |
| 137 | + _, domain = _auto_categorize(df, user_domain) |
| 138 | + assert domain["x"] == {"lower": 0, "upper": 1} |
| 139 | + |
| 140 | + |
| 141 | +def test_auto_categorize_skips_high_cardinality(): |
| 142 | + """Int columns with >10 unique values should not be auto-categorized.""" |
| 143 | + df = pd.DataFrame({"x": list(range(11))}) |
| 144 | + _, domain = _auto_categorize(df, None) |
| 145 | + assert "x" not in domain |
| 146 | + |
| 147 | + |
| 148 | +def test_requires_private_preprocessing_false_for_list_domain(): |
| 149 | + """List domain on a numeric column means no preprocessing needed.""" |
| 150 | + df = pd.DataFrame({"x": [0, 1, 0, 1]}) |
| 151 | + assert not _requires_private_preprocessing(df, {"x": [0, 1]}) |
| 152 | + |
| 153 | + |
| 154 | +def test_binary_int_columns_fit_without_domain(): |
| 155 | + """Binary int columns should fit without explicit domain via auto-categorization.""" |
| 156 | + np.random.seed(42) |
| 157 | + df = pd.DataFrame({ |
| 158 | + "a": np.random.choice([0, 1], 100), |
| 159 | + "b": np.random.choice([0, 1], 100), |
| 160 | + "cat": np.random.choice(["x", "y"], 100), |
| 161 | + }) |
| 162 | + risk = Risk.from_advantage(0.25) |
| 163 | + gen = AIMGenerator(risk=risk) |
| 164 | + gen.fit(df) |
| 165 | + synth = gen.generate(count=10) |
| 166 | + assert len(synth) == 10 |
| 167 | + assert list(synth.columns) == list(df.columns) |
| 168 | + |
| 169 | + |
| 170 | +def test_bounds_estimation_failure_raises_value_error(): |
| 171 | + """Should raise ValueError (not TypeError) when private bounds estimation fails.""" |
| 172 | + np.random.seed(42) |
| 173 | + # High-cardinality float column with tiny budget -> approx_bounds will fail |
| 174 | + df = pd.DataFrame({"x": np.random.uniform(0, 1, 50)}) |
| 175 | + risk = Risk.from_zcdp(0.001) |
| 176 | + gen = AIMGenerator(risk=risk, proc_epsilon=0.001) |
| 177 | + with pytest.raises(ValueError, match="Private bounds estimation failed"): |
| 178 | + gen.fit(df) |
| 179 | + |
| 180 | + |
115 | 181 | def test_categorical_only_no_preprocessing(): |
116 | 182 | """Categorical-only data should not require preprocessing.""" |
117 | 183 | df = pd.DataFrame({ |
|
0 commit comments