Skip to content

Commit 4338b3e

Browse files
committed
fix: resolve mypy type errors in statistical test callback
1 parent 064b92b commit 4338b3e

1 file changed

Lines changed: 37 additions & 23 deletions

File tree

remode/remode.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import warnings
44
from collections import Counter
55
from functools import partial
6-
from typing import TYPE_CHECKING, Any, Callable, Dict, Literal, Optional, Tuple, Union
6+
from typing import TYPE_CHECKING, Any, Callable, Dict, Literal, Optional, Tuple, Union, cast
77

88
import numpy as np
99
import pandas as pd
@@ -13,6 +13,9 @@
1313
from matplotlib.axes import Axes
1414

1515

16+
ModeTestCallable = Callable[[np.ndarray, int, int, int], Tuple[float, float]]
17+
18+
1619
def count_descriptive_peaks(x: np.ndarray) -> int:
1720
"""
1821
Count descriptive peaks in a histogram-like vector, matching the R implementation.
@@ -226,7 +229,7 @@ def __init__(
226229
Literal["descriptive_peaks", "max_modes", "none"], Callable
227230
] = "descriptive_peaks",
228231
statistical_test: Union[
229-
Literal["bootstrap", "binomial", "fisher"], Callable
232+
Literal["bootstrap", "binomial", "fisher"], ModeTestCallable
230233
] = "bootstrap",
231234
definition: Literal["shape_based", "peak_based"] = "shape_based",
232235
n_boot: int = 10000,
@@ -278,22 +281,27 @@ def _create_alpha_correction(
278281

279282
self._create_alpha_correction = _create_alpha_correction
280283

284+
test_callable: ModeTestCallable
285+
sign_test_name: str
281286
if isinstance(statistical_test, str):
282287
test_name = statistical_test.lower()
283288
if test_name == "bootstrap":
284289
bootstrap_rng = np.random.default_rng(random_state)
285-
self.statistical_test = partial(
286-
perform_bootstrap_test,
287-
n_boot=self.n_boot,
288-
rng=bootstrap_rng,
290+
test_callable = cast(
291+
ModeTestCallable,
292+
partial(
293+
perform_bootstrap_test,
294+
n_boot=self.n_boot,
295+
rng=bootstrap_rng,
296+
),
289297
)
290-
self.sign_test = "bootstrap"
298+
sign_test_name = "bootstrap"
291299
elif test_name == "binomial":
292-
self.statistical_test = perform_binomial_test
293-
self.sign_test = "binomial"
300+
test_callable = perform_binomial_test
301+
sign_test_name = "binomial"
294302
elif test_name == "fisher":
295-
self.statistical_test = perform_fisher_test
296-
self.sign_test = "fisher"
303+
test_callable = perform_fisher_test
304+
sign_test_name = "fisher"
297305
else:
298306
raise ValueError(
299307
"The statistical_test argument must be a callable or one of "
@@ -307,20 +315,26 @@ def _create_alpha_correction(
307315
)
308316
if statistical_test is perform_bootstrap_test:
309317
bootstrap_rng = np.random.default_rng(random_state)
310-
self.statistical_test = partial(
311-
perform_bootstrap_test,
312-
n_boot=self.n_boot,
313-
rng=bootstrap_rng,
318+
test_callable = cast(
319+
ModeTestCallable,
320+
partial(
321+
perform_bootstrap_test,
322+
n_boot=self.n_boot,
323+
rng=bootstrap_rng,
324+
),
314325
)
315-
self.sign_test = "bootstrap"
326+
sign_test_name = "bootstrap"
316327
else:
317-
self.statistical_test = statistical_test
328+
test_callable = cast(ModeTestCallable, statistical_test)
318329
if statistical_test is perform_fisher_test:
319-
self.sign_test = "fisher"
330+
sign_test_name = "fisher"
320331
elif statistical_test is perform_binomial_test:
321-
self.sign_test = "binomial"
332+
sign_test_name = "binomial"
322333
else:
323-
self.sign_test = "custom"
334+
sign_test_name = "custom"
335+
336+
self.statistical_test = test_callable
337+
self.sign_test = sign_test_name
324338

325339
if not isinstance(definition, str):
326340
raise ValueError("definition must be either 'shape_based' or 'peak_based'.")
@@ -383,10 +397,10 @@ def _find_maxima(self, xt: np.ndarray) -> list:
383397

384398
result = []
385399
alpha_cor = self._create_alpha_correction(xt, self.alpha)
386-
candidate = np.argmax(xt)
400+
candidate = int(np.argmax(xt))
387401
if candidate != 0 and candidate != len(xt) - 1:
388-
left_min = np.argmin(xt[:candidate])
389-
right_min = np.argmin(xt[candidate:]) + candidate
402+
left_min = int(np.argmin(xt[:candidate]))
403+
right_min = int(np.argmin(xt[candidate:]) + candidate)
390404
p_left, p_right = self.statistical_test(xt, candidate, left_min, right_min)
391405
if self.sign_test == "fisher":
392406
p_value = 1 - (1 - p_left) * (1 - p_right)

0 commit comments

Comments
 (0)