Skip to content

Commit b307ab2

Browse files
Treat low-cardinality numeric features as categorical
1 parent 24db31f commit b307ab2

2 files changed

Lines changed: 116 additions & 2 deletions

File tree

risksyn/generator.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,35 @@
1212
# Empirically tested default from dpmm library
1313
_DEFAULT_PROC_EPSILON = 0.1
1414

15+
# Numeric columns with at most this many unique values are auto-treated as categorical
16+
_AUTO_CATEGORICAL_THRESHOLD = 10
17+
18+
19+
def _auto_categorize(
20+
data: pd.DataFrame, domain: Optional[dict]
21+
) -> tuple[pd.DataFrame, dict]:
22+
"""Auto-detect low-cardinality numeric columns and treat as categorical.
23+
24+
Numeric columns with <= _AUTO_CATEGORICAL_THRESHOLD unique values are
25+
cast to string dtype and given categorical domains to avoid private
26+
bounds estimation.
27+
28+
Returns a (possibly modified) copy of data and the augmented domain.
29+
"""
30+
domain = dict(domain) if domain else {}
31+
cols_to_cast = []
32+
for col, series in data.items():
33+
if col in domain:
34+
continue
35+
if series.dtype.kind in "ui": # uint, int only
36+
if series.nunique() <= _AUTO_CATEGORICAL_THRESHOLD:
37+
domain[col] = sorted(str(v) for v in series.unique())
38+
cols_to_cast.append(col)
39+
if cols_to_cast:
40+
data = data.copy()
41+
data[cols_to_cast] = data[cols_to_cast].astype(str)
42+
return data, domain
43+
1544

1645
def _requires_private_preprocessing(data: pd.DataFrame, domain: Optional[dict]) -> bool:
1746
"""Check if any numeric column lacks bounds in domain.
@@ -24,8 +53,10 @@ def _requires_private_preprocessing(data: pd.DataFrame, domain: Optional[dict])
2453
if domain is None:
2554
return True
2655
col_domain = domain.get(col, {})
56+
if isinstance(col_domain, list):
57+
continue # categorical domain provided, no bounds needed
2758
if not isinstance(col_domain, dict):
28-
return True # categorical-style domain for numeric column
59+
return True
2960
if col_domain.get("lower") is None or col_domain.get("upper") is None:
3061
return True
3162
return False
@@ -103,6 +134,8 @@ def fit(self, data: pd.DataFrame, domain: Optional[dict] = None) -> "AIMGenerato
103134
UserWarning
104135
If privacy budget for generation is smaller than for processing.
105136
"""
137+
data, domain = _auto_categorize(data, domain)
138+
106139
needs_preprocessing = _requires_private_preprocessing(data, domain)
107140

108141
if needs_preprocessing:
@@ -135,7 +168,22 @@ def fit(self, data: pd.DataFrame, domain: Optional[dict] = None) -> "AIMGenerato
135168
proc_epsilon=params.get("proc_epsilon"),
136169
gen_kwargs={"degree": self._degree},
137170
)
138-
self._pipeline.fit(data, domain)
171+
_BOUNDS_ERROR_MSG = (
172+
"Private bounds estimation failed for one or more numeric columns. "
173+
"This typically happens when the privacy budget is too small to detect "
174+
"data bounds. Remedies: (1) provide explicit domain bounds for numeric "
175+
"columns via the domain parameter, e.g. domain={'col': {'lower': 0, "
176+
"'upper': 100}}, (2) increase proc_epsilon, or (3) relax the risk "
177+
"requirement."
178+
)
179+
try:
180+
self._pipeline.fit(data, domain)
181+
except (TypeError, KeyError) as e:
182+
raise ValueError(_BOUNDS_ERROR_MSG) from e
183+
except ValueError as e:
184+
if "Private bounds estimation failed" not in str(e):
185+
raise ValueError(_BOUNDS_ERROR_MSG) from e
186+
raise
139187
return self
140188

141189
def generate(self, count: int) -> pd.DataFrame:

tests/test_generator.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from sklearn.datasets import load_wine
55

66
from risksyn import AIMGenerator, Risk
7+
from risksyn.generator import _auto_categorize, _requires_private_preprocessing
78

89

910
# Simple dataset for fast tests
@@ -112,6 +113,71 @@ def test_generate_before_fit_raises():
112113
gen.generate(count=10)
113114

114115

116+
def test_auto_categorize_binary_int_columns():
117+
"""Low-cardinality int columns should be auto-categorized and cast to str."""
118+
df = pd.DataFrame({
119+
"binary": [0, 1, 0, 1, 1],
120+
"ternary": [0, 1, 2, 0, 1],
121+
"continuous": np.random.uniform(0, 100, 5),
122+
"cat": ["a", "b", "c", "a", "b"],
123+
})
124+
out_data, domain = _auto_categorize(df, None)
125+
assert domain["binary"] == ["0", "1"]
126+
assert domain["ternary"] == ["0", "1", "2"]
127+
assert out_data["binary"].dtype == object
128+
assert out_data["ternary"].dtype == object
129+
assert "continuous" not in domain # float, not int
130+
assert "cat" not in domain # not numeric
131+
132+
133+
def test_auto_categorize_respects_existing_domain():
134+
"""Auto-categorization should not override user-provided domain."""
135+
df = pd.DataFrame({"x": [0, 1, 0, 1]})
136+
user_domain = {"x": {"lower": 0, "upper": 1}}
137+
_, domain = _auto_categorize(df, user_domain)
138+
assert domain["x"] == {"lower": 0, "upper": 1}
139+
140+
141+
def test_auto_categorize_skips_high_cardinality():
142+
"""Int columns with >10 unique values should not be auto-categorized."""
143+
df = pd.DataFrame({"x": list(range(11))})
144+
_, domain = _auto_categorize(df, None)
145+
assert "x" not in domain
146+
147+
148+
def test_requires_private_preprocessing_false_for_list_domain():
149+
"""List domain on a numeric column means no preprocessing needed."""
150+
df = pd.DataFrame({"x": [0, 1, 0, 1]})
151+
assert not _requires_private_preprocessing(df, {"x": [0, 1]})
152+
153+
154+
def test_binary_int_columns_fit_without_domain():
155+
"""Binary int columns should fit without explicit domain via auto-categorization."""
156+
np.random.seed(42)
157+
df = pd.DataFrame({
158+
"a": np.random.choice([0, 1], 100),
159+
"b": np.random.choice([0, 1], 100),
160+
"cat": np.random.choice(["x", "y"], 100),
161+
})
162+
risk = Risk.from_advantage(0.25)
163+
gen = AIMGenerator(risk=risk)
164+
gen.fit(df)
165+
synth = gen.generate(count=10)
166+
assert len(synth) == 10
167+
assert list(synth.columns) == list(df.columns)
168+
169+
170+
def test_bounds_estimation_failure_raises_value_error():
171+
"""Should raise ValueError (not TypeError) when private bounds estimation fails."""
172+
np.random.seed(42)
173+
# High-cardinality float column with tiny budget -> approx_bounds will fail
174+
df = pd.DataFrame({"x": np.random.uniform(0, 1, 50)})
175+
risk = Risk.from_zcdp(0.001)
176+
gen = AIMGenerator(risk=risk, proc_epsilon=0.001)
177+
with pytest.raises(ValueError, match="Private bounds estimation failed"):
178+
gen.fit(df)
179+
180+
115181
def test_categorical_only_no_preprocessing():
116182
"""Categorical-only data should not require preprocessing."""
117183
df = pd.DataFrame({

0 commit comments

Comments
 (0)