33import warnings
44from collections import Counter
55from functools import partial
6- from typing import TYPE_CHECKING , Any , Callable , Dict , Literal , Optional , Tuple , Union
6+ from typing import TYPE_CHECKING , Any , Callable , Dict , Literal , Optional , Tuple , Union , cast
77
88import numpy as np
99import pandas as pd
1313 from matplotlib .axes import Axes
1414
1515
16+ ModeTestCallable = Callable [[np .ndarray , int , int , int ], Tuple [float , float ]]
17+
18+
1619def count_descriptive_peaks (x : np .ndarray ) -> int :
1720 """
1821 Count descriptive peaks in a histogram-like vector, matching the R implementation.
@@ -226,7 +229,7 @@ def __init__(
226229 Literal ["descriptive_peaks" , "max_modes" , "none" ], Callable
227230 ] = "descriptive_peaks" ,
228231 statistical_test : Union [
229- Literal ["bootstrap" , "binomial" , "fisher" ], Callable
232+ Literal ["bootstrap" , "binomial" , "fisher" ], ModeTestCallable
230233 ] = "bootstrap" ,
231234 definition : Literal ["shape_based" , "peak_based" ] = "shape_based" ,
232235 n_boot : int = 10000 ,
@@ -278,22 +281,27 @@ def _create_alpha_correction(
278281
279282 self ._create_alpha_correction = _create_alpha_correction
280283
284+ test_callable : ModeTestCallable
285+ sign_test_name : str
281286 if isinstance (statistical_test , str ):
282287 test_name = statistical_test .lower ()
283288 if test_name == "bootstrap" :
284289 bootstrap_rng = np .random .default_rng (random_state )
285- self .statistical_test = partial (
286- perform_bootstrap_test ,
287- n_boot = self .n_boot ,
288- rng = bootstrap_rng ,
290+ test_callable = cast (
291+ ModeTestCallable ,
292+ partial (
293+ perform_bootstrap_test ,
294+ n_boot = self .n_boot ,
295+ rng = bootstrap_rng ,
296+ ),
289297 )
290- self . sign_test = "bootstrap"
298+ sign_test_name = "bootstrap"
291299 elif test_name == "binomial" :
292- self . statistical_test = perform_binomial_test
293- self . sign_test = "binomial"
300+ test_callable = perform_binomial_test
301+ sign_test_name = "binomial"
294302 elif test_name == "fisher" :
295- self . statistical_test = perform_fisher_test
296- self . sign_test = "fisher"
303+ test_callable = perform_fisher_test
304+ sign_test_name = "fisher"
297305 else :
298306 raise ValueError (
299307 "The statistical_test argument must be a callable or one of "
@@ -307,20 +315,26 @@ def _create_alpha_correction(
307315 )
308316 if statistical_test is perform_bootstrap_test :
309317 bootstrap_rng = np .random .default_rng (random_state )
310- self .statistical_test = partial (
311- perform_bootstrap_test ,
312- n_boot = self .n_boot ,
313- rng = bootstrap_rng ,
318+ test_callable = cast (
319+ ModeTestCallable ,
320+ partial (
321+ perform_bootstrap_test ,
322+ n_boot = self .n_boot ,
323+ rng = bootstrap_rng ,
324+ ),
314325 )
315- self . sign_test = "bootstrap"
326+ sign_test_name = "bootstrap"
316327 else :
317- self . statistical_test = statistical_test
328+ test_callable = cast ( ModeTestCallable , statistical_test )
318329 if statistical_test is perform_fisher_test :
319- self . sign_test = "fisher"
330+ sign_test_name = "fisher"
320331 elif statistical_test is perform_binomial_test :
321- self . sign_test = "binomial"
332+ sign_test_name = "binomial"
322333 else :
323- self .sign_test = "custom"
334+ sign_test_name = "custom"
335+
336+ self .statistical_test = test_callable
337+ self .sign_test = sign_test_name
324338
325339 if not isinstance (definition , str ):
326340 raise ValueError ("definition must be either 'shape_based' or 'peak_based'." )
@@ -383,10 +397,10 @@ def _find_maxima(self, xt: np.ndarray) -> list:
383397
384398 result = []
385399 alpha_cor = self ._create_alpha_correction (xt , self .alpha )
386- candidate = np .argmax (xt )
400+ candidate = int ( np .argmax (xt ) )
387401 if candidate != 0 and candidate != len (xt ) - 1 :
388- left_min = np .argmin (xt [:candidate ])
389- right_min = np .argmin (xt [candidate :]) + candidate
402+ left_min = int ( np .argmin (xt [:candidate ]) )
403+ right_min = int ( np .argmin (xt [candidate :]) + candidate )
390404 p_left , p_right = self .statistical_test (xt , candidate , left_min , right_min )
391405 if self .sign_test == "fisher" :
392406 p_value = 1 - (1 - p_left ) * (1 - p_right )
0 commit comments