diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml
index be630e0..90e0df3 100644
--- a/.github/workflows/build-and-test.yml
+++ b/.github/workflows/build-and-test.yml
@@ -7,6 +7,8 @@ on:
     tags:
       - 'v*'
   pull_request:
+    branches:
+      - main
 
 jobs:
 
@@ -33,7 +35,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ['3.10', '3.11', '3.12', "3.13"]
+        python-version: ['3.12', "3.13"]
 
     steps:
 
@@ -47,6 +49,7 @@
         python -m pip install --upgrade pip
         pip install wheel
         pip install --upgrade-strategy eager -e .[dev]
+        pip install pyspark[connect]
     - name: Build
       run: |
         pip install build
@@ -61,7 +64,7 @@
     runs-on: windows-latest
     strategy:
      matrix:
-        python-version: ['3.10', '3.11', '3.12', "3.13"]
+        python-version: ['3.12', "3.13"]
 
     steps:
 
diff --git a/datachecker/checks_loaders_and_exporters/checks.py b/datachecker/checks_loaders_and_exporters/checks.py
index 6d14d72..0151765 100644
--- a/datachecker/checks_loaders_and_exporters/checks.py
+++ b/datachecker/checks_loaders_and_exporters/checks.py
@@ -20,7 +20,10 @@ def get_dtype_lib(df):
 
         return pap
 
-    if (mod, name) == ("pyspark.sql.dataframe", "DataFrame"):
+    if (mod, name) == ("pyspark.sql.dataframe", "DataFrame") or (mod, name) == (
+        "pyspark.sql.classic.dataframe",
+        "DataFrame",
+    ):
         import pandera.pyspark as paspk
 
         return paspk
diff --git a/datachecker/data_checkers/pyspark_validator.py b/datachecker/data_checkers/pyspark_validator.py
index 8c663af..7e92862 100644
--- a/datachecker/data_checkers/pyspark_validator.py
+++ b/datachecker/data_checkers/pyspark_validator.py
@@ -1,4 +1,5 @@
 import pandas as pd
+from pyspark.sql import SparkSession
 from pyspark.sql import functions as F
 
 from datachecker.data_checkers.general_validator import Validator
@@ -14,7 +15,7 @@ def __init__(
         hard_check: bool = True,
         custom_checks: dict = None,
     ):
-        raise NotImplementedError("PySpark support is not implemented yet")
+        # raise NotImplementedError("PySpark support is not implemented yet")
         super().__init__(schema, data, file, format, hard_check, custom_checks)
 
     def _check_duplicates(self):
@@ -75,27 +76,29 @@ def _check_completeness(self):
 
 if __name__ == "__main__":
-    # Example usage (pandas)
+    # Example usage (PySpark)
+    spark = SparkSession.builder.master("local[2]").appName("create-DFs").getOrCreate()
+
     data = pd.DataFrame(
         [
             (1, "A"),
-            (2, "B"),
+            (2.1, "B"),
             (1, "A"),  # Duplicate row
             (3, "C"),
         ],
         columns=["id", "value"],
     )
+    data_spark = spark.createDataFrame(data)
 
     schema = {
         "check_duplicates": True,
         "check_completeness": True,
         "completeness_columns": ["id", "value"],
         "columns": {
-            "id": {"type": "integer", "check_duplicates": True},
+            "id": {"type": "int", "check_duplicates": True},
             "value": {"type": "string", "check_duplicates": True},
         },
     }
 
-    validator = PySparkValidator(schema, data, "datafile.csv", "csv")
-    validator.run_checks()
-    for entry in validator.qa_report:
-        print(entry)
+    validator = PySparkValidator(schema, data_spark, "datafile.csv", "csv")
+    validator.validate()
+    print(validator)
diff --git a/pyproject.toml b/pyproject.toml
index ea5643e..6f70fd7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,7 +16,7 @@ dependencies = [
     "pandas<3.0.0",
     "tomli",
     "jinja2",
-    "pandera[pandas]>=0.26.1",
+    "pandera>=0.26.1",
 ]
 
 [tool.setuptools.packages.find]
@@ -53,8 +53,8 @@ dev = [
     "bump_my_version",
     "pre-commit",
     "polars",
-    "pandera[polars]",
-    "pyarrow"
+    "pyarrow",
+    "grpcio"
 ]
 docs = [
     "mkdocs",
@@ -65,12 +65,11 @@ docs = [
     "mkdocs-mermaid2-plugin"
 ]
 polars = [
-    "polars",
-    "pandera[polars]",
+    "polars>=0.20.0",
     "pyarrow"
 ]
 pyspark = [
-    "pandera[pyspark]"
+    "grpcio"
 ]
 
 [tool.setuptools]
diff --git a/tests/test_pyspark_validator.py b/tests/test_pyspark_validator.py
new file mode 100644
index 0000000..6aa0857
--- /dev/null
+++ b/tests/test_pyspark_validator.py
@@ -0,0 +1,109 @@
+import importlib.util
+import os
+
+import pandas as pd
+import pytest
+
+
+@pytest.mark.skipif(
+    importlib.util.find_spec("pyspark") is None,
+    reason="pyspark is not installed",
+)
+class TestPysparkValidator:
+    def setup_method(self):
+        from pyspark.sql import SparkSession
+
+        self.spark = SparkSession.builder.master("local[1]").appName("Test").getOrCreate()
+
+    def test_pyspark_validator(self):
+        from datachecker.data_checkers.pyspark_validator import PySparkValidator
+
+        df = pd.DataFrame(
+            {
+                "id": [1, 2, 3, 2],
+                "name": ["Alice", "Bob", "Charlie", "Bob"],
+                "score": [90.5, 82.0, 95.25, 82.0],
+                "passed": [True, True, True, True],
+            }
+        )
+        spark_df = self.spark.createDataFrame(df)
+
+        schema = {
+            "check_duplicates": True,
+            "check_completeness": True,
+            "columns": {
+                "id": {"type": "int", "nullable": False},
+                "name": {"type": "str", "nullable": False},
+                "score": {"type": "float", "nullable": False, "min": 0, "max": 100},
+                "passed": {"type": "bool", "nullable": False},
+            },
+        }
+
+        new_validator = PySparkValidator(
+            schema=schema, data=spark_df, file="temp.html", format="html", hard_check=False
+        )
+        new_validator.validate()
+        new_validator.export()
+
+        assert isinstance(new_validator, PySparkValidator)
+        assert len(new_validator.log) > 0
+        assert os.path.exists("temp.html")
+
+        # Clean up
+        os.remove("temp.html")
+
+    def test_pyspark_all_dtypes(self):
+        from datachecker.data_checkers.pyspark_validator import PySparkValidator
+
+        df = pd.DataFrame(
+            {
+                "id": [1, 2, 3, 2],
+                "name": ["Alice", "Bob", "Charlie", "Bob"],
+                "score": [90.5, 82.0, 95.25, 82.0],
+                "passed": [True, True, True, True],
+            }
+        )
+        spark_df = self.spark.createDataFrame(df)
+
+        schema = {
+            "check_duplicates": True,
+            "check_completeness": True,
+            "columns": {
+                "id": {
+                    "type": "int",
+                    "allow_na": False,
+                    "max_val": 2,
+                    "min_val": 0,
+                    "optional": False,
+                },
+                "name": {
+                    "type": "str",
+                    "allow_na": False,
+                    "optional": False,
+                    "min_length": 4,
+                    "max_length": 10,
+                },
+                "score": {
+                    "type": "float",
+                    "allow_na": False,
+                    "min_val": 0,
+                    "max_val": 100,
+                    "max_decimal": 5,
+                    "min_decimal": 2,
+                    "optional": False,
+                },
+                "passed": {"type": "bool", "allow_na": False, "optional": False},
+            },
+        }
+        new_validator = PySparkValidator(
+            schema=schema, data=spark_df, file="temp.html", format="html", hard_check=False
+        )
+        new_validator.validate()
+        new_validator.export()
+
+        assert isinstance(new_validator, PySparkValidator)
+        assert len(new_validator.log) > 0
+        assert os.path.exists("temp.html")
+
+        # Clean up
+        os.remove("temp.html")
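
Reviewer note: for anyone who wants to exercise the new PySpark path locally, below is a minimal smoke-test sketch. It assumes this branch of datachecker is installed together with pyspark (which the updated workflow now installs via pip install pyspark[connect]); the PySparkValidator calls (validate(), export(), the log attribute) mirror the new test file above, and the output file name report.html is arbitrary.

    # Minimal local smoke test for the new PySpark path
    # (mirrors tests/test_pyspark_validator.py in this diff).
    import pandas as pd
    from pyspark.sql import SparkSession

    from datachecker.data_checkers.pyspark_validator import PySparkValidator

    spark = SparkSession.builder.master("local[1]").appName("smoke").getOrCreate()

    # Row (2, "Bob", 82.0, True) is duplicated on purpose to trigger the duplicate check.
    pdf = pd.DataFrame(
        {
            "id": [1, 2, 3, 2],
            "name": ["Alice", "Bob", "Charlie", "Bob"],
            "score": [90.5, 82.0, 95.25, 82.0],
            "passed": [True, True, True, True],
        }
    )
    sdf = spark.createDataFrame(pdf)

    schema = {
        "check_duplicates": True,
        "check_completeness": True,
        "columns": {
            "id": {"type": "int", "nullable": False},
            "name": {"type": "str", "nullable": False},
            "score": {"type": "float", "nullable": False, "min": 0, "max": 100},
            "passed": {"type": "bool", "nullable": False},
        },
    }

    # hard_check=False collects failures into validator.log instead of raising.
    validator = PySparkValidator(
        schema=schema, data=sdf, file="report.html", format="html", hard_check=False
    )
    validator.validate()
    validator.export()  # writes report.html
    print(validator.log)

    spark.stop()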