@wagner-austin wagner-austin commented Dec 30, 2025

Summary

Add return type annotations to all predict methods in the sklearn module:

  • LGBMModel.predict
  • LGBMClassifier.predict
  • LGBMClassifier.predict_proba

Type narrowing in predict_proba

The predict() method returns Union[np.ndarray, spmatrix, List[spmatrix]]. The else branch in predict_proba does np.vstack((1.0 - result, result)) which only works with np.ndarray.

This else branch only runs when pred_contrib=False, and predict() always returns np.ndarray when pred_contrib=False. I added an isinstance assertion with an error message so that mypy can narrow the type, following the pattern in _get_label_from_constructed_dataset.
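The narrowing described above can be sketched roughly as follows. This is a minimal illustration, not the code from the PR: `SparseStub`, `fake_predict`, and `binary_proba_sketch` are hypothetical stand-ins (for `scipy.sparse.spmatrix`, `predict()`, and the binary branch of `predict_proba`), and the assertion message is made up.

```python
from typing import List, Union

import numpy as np


class SparseStub:
    """Hypothetical stand-in for scipy.sparse.spmatrix."""


# Mirrors the shape of the union described above, using the stand-in class.
PredictReturn = Union[np.ndarray, SparseStub, List[SparseStub]]


def fake_predict(pred_contrib: bool = False) -> PredictReturn:
    """Hypothetical predict(): dense output unless pred_contrib=True."""
    if pred_contrib:
        return SparseStub()
    return np.array([0.2, 0.9, 0.5])


def binary_proba_sketch(pred_contrib: bool = False) -> PredictReturn:
    result = fake_predict(pred_contrib=pred_contrib)
    if pred_contrib:
        # Contributions are passed through unchanged.
        return result
    # Without this assertion, mypy still sees the full union here and
    # rejects `1.0 - result`; after it, `result` is narrowed to np.ndarray.
    assert isinstance(result, np.ndarray), (
        "expected a numpy array when pred_contrib=False"
    )
    # Stack P(class 0) and P(class 1), then transpose to (n_samples, 2).
    return np.vstack((1.0 - result, result)).transpose()
```

The assertion costs one isinstance check at runtime and doubles as a guard: if the assumption about pred_contrib=False ever stops holding, it fails with a readable message instead of an obscure error inside np.vstack.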

Mypy Errors (before fix)

python-package/lightgbm/sklearn.py:1093: error: Function is missing a return type annotation  [no-untyped-def]
python-package/lightgbm/sklearn.py:1586: error: Function is missing a return type annotation  [no-untyped-def]
python-package/lightgbm/sklearn.py:1616: error: Function is missing a return type annotation  [no-untyped-def]

Related Issue

Fixes: #3867

Test Plan

  • Ran pre-commit run --files python-package/lightgbm/sklearn.py - all checks passed
  • Verified mypy no longer reports no-untyped-def errors for predict methods

Collaborator

@jameslamb jameslamb left a comment

Thanks for your interest in LightGBM and for working on this!

There are multiple predict() methods implemented in this file. Could you please try fixing the hints on all of them? I think they all pass through the output of Booster.predict() without type coercion, so they should all have the same return type.

python-package/lightgbm/sklearn.py:1093: error: Function is missing a return type annotation  [no-untyped-def]
python-package/lightgbm/sklearn.py:1586: error: Function is missing a return type annotation  [no-untyped-def]
python-package/lightgbm/sklearn.py:1616: error: Function is missing a return type annotation  [no-untyped-def]

      validate_features: bool = False,
      **kwargs: Any,
-   ):
+   ) -> Union[np.ndarray, scipy.sparse.spmatrix, List[scipy.sparse.spmatrix]]:
Collaborator

I think you're right, this should just be identical to the return type hint on Booster.predict()!

I'd been avoiding investigating these because I thought it was possible for pandas outputs to make it in here via the set_output() API (scikit-learn docs) but I think that's something scikit-learn estimators have to opt into and implement themselves, and we haven't done that in LightGBM.

Since you've identified this... let's enforce that relationship by moving the Booster.predict() type hint into a shared type hint variable and re-using it in both places.

  1. add a _LGBM_PredictReturnType after this:
         _LGBM_PredictDataType = Union[
             str,
             Path,
             np.ndarray,
             pd_DataFrame,
             scipy.sparse.spmatrix,
             pa_Table,
         ]
  2. re-use it in all the relevant places in basic.py and sklearn.py

@jameslamb
Collaborator

side note: in 2026 I am finally eliminating that List[scipy.sparse.spmatrix] in this return type hint, it really bothers me 😅 (ref: #6348)

@wagner-austin wagner-austin force-pushed the fix-lgbmmodel-predict-return-type branch from d88c791 to b99db2b Compare December 30, 2025 07:22
@wagner-austin wagner-austin changed the title from "[python-package] Add return type annotation to LGBMModel.predict method" to "[python-package] Add return type annotations to predict methods in sklearn module" on Dec 30, 2025
…learn module

Add return type annotations to LGBMModel.predict, LGBMClassifier.predict, and LGBMClassifier.predict_proba methods. The predict_proba method requires an isinstance assertion for type narrowing in the binary classification branch, since predict() returns a union type but only returns np.ndarray when pred_contrib=False. This follows the existing error message pattern used in _get_label_from_constructed_dataset and similar helper functions.

Fixes: microsoft#3867
@wagner-austin wagner-austin force-pushed the fix-lgbmmodel-predict-return-type branch from b99db2b to ac6a308 Compare December 30, 2025 14:31
Collaborator

@jameslamb jameslamb left a comment

Thanks for expanding this to all of the predict() methods. I'll wait to re-review until my suggestions in #7116 (comment) are fully addressed.

Please also do comment on the suggestion threads when you push new changes, explaining what you did.

Development

Successfully merging this pull request may close these issues.

[ci] warnings from mypy

2 participants