diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml
index be630e0..90e0df3 100644
--- a/.github/workflows/build-and-test.yml
+++ b/.github/workflows/build-and-test.yml
@@ -7,6 +7,8 @@ on:
     tags:
       - 'v*'
   pull_request:
+    branches:
+      - main
 
 jobs:
 
@@ -33,7 +35,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ['3.10', '3.11', '3.12', "3.13"]
+        python-version: ['3.12', "3.13"]
 
     steps:
 
@@ -47,6 +49,7 @@ jobs:
         python -m pip install --upgrade pip
         pip install wheel
         pip install --upgrade-strategy eager -e .[dev]
+        pip install pyspark[connect]
     - name: Build
       run: |
         pip install build
@@ -61,7 +64,7 @@ jobs:
     runs-on: windows-latest
     strategy:
       matrix:
-        python-version: ['3.10', '3.11', '3.12', "3.13"]
+        python-version: ['3.12', "3.13"]
 
     steps:
diff --git a/datachecker/checks_loaders_and_exporters/checks.py b/datachecker/checks_loaders_and_exporters/checks.py
index 6d14d72..bd69970 100644
--- a/datachecker/checks_loaders_and_exporters/checks.py
+++ b/datachecker/checks_loaders_and_exporters/checks.py
@@ -20,7 +20,10 @@ def get_dtype_lib(df):
 
         return pap
 
-    if (mod, name) == ("pyspark.sql.dataframe", "DataFrame"):
+    if (mod, name) == ("pyspark.sql.dataframe", "DataFrame") or (mod, name) == (
+        "pyspark.sql.classic.dataframe",
+        "DataFrame",
+    ):
         import pandera.pyspark as paspk
 
         return paspk
@@ -352,10 +355,11 @@ def validate_using_pandera(
 
     # ISSUE - NA pass not present in output log
     try:
-        converted_schema.validate(data, lazy=True)
+        df_output = converted_schema.validate(data, lazy=True)
 
-        # The following code is to add all checks when validation passes
-        grouped_validation_return = None
+        # The following code is to add all checks when validation passes or pyspark
+        # validation
+        grouped_validation_return = process_pyspark_errors(df_output)
     except get_dtype_lib(data).errors.SchemaErrors as e:
         # validation_return is now a pandas dataframe
         validation_return = e.failure_cases[["column", "check", "failure_case", "index"]]
@@ -431,3 +435,32 @@ def convert_schema_into_log_entries(converted_schema: pa.DataFrameSchema) -> pd.
"invalid_ids": [[]] * len(list_of_checks), } ) + + +def process_pyspark_errors(df_output): + if "pyspark" not in str(type(df_output)): + return None + elif df_output.pandera.errors is None or df_output.pandera.errors == {}: + # No failed validation checks + return None + failed_cases = pd.DataFrame({}) + for _key in df_output.pandera.errors: + for check_name in df_output.pandera.errors[_key]: + failed_cases = pd.concat( + [failed_cases, pd.DataFrame(df_output.pandera.errors[_key][check_name])], + ignore_index=True, + ) + + # failed_cases = pd.DataFrame(df_output.pandera.errors["DATA"]["DATAFRAME_CHECK"]) + validation_return = pd.DataFrame({}) + validation_return["column"] = failed_cases["column"] + validation_return["check"] = failed_cases["check"] + validation_return["failure_case"] = failed_cases["error"] + validation_return["index"] = None + grouped_validation_return = ( + validation_return.groupby(["column", "check"]) + .agg({"failure_case": list, "index": list}) + .reset_index() + .rename(columns={"index": "invalid_ids"}) + ) + return grouped_validation_return diff --git a/datachecker/data_checkers/pyspark_validator.py b/datachecker/data_checkers/pyspark_validator.py index 8c663af..28b027d 100644 --- a/datachecker/data_checkers/pyspark_validator.py +++ b/datachecker/data_checkers/pyspark_validator.py @@ -1,4 +1,6 @@ -import pandas as pd +import re + +import pyspark.sql.types as T from pyspark.sql import functions as F from datachecker.data_checkers.general_validator import Validator @@ -14,8 +16,53 @@ def __init__( hard_check: bool = True, custom_checks: dict = None, ): - raise NotImplementedError("PySpark support is not implemented yet") + # raise NotImplementedError("PySpark support is not implemented yet") super().__init__(schema, data, file, format, hard_check, custom_checks) + self._convert_schema_dtypes() + + def validate(self): + super().validate() + self._convert_pyspark_error_messages() + + def _convert_pyspark_error_messages(self): + message = "Pyspark does not return cases or index" + for i in range(1, len(self.log) - 1): + entry = self.log[i] + if ( + entry["failing_ids"] is None + or len(entry["failing_ids"]) == 0 + or not isinstance(entry["failing_ids"][0], str) + ): + continue + # is the message when a check fails for pyspark + # replace it with blanket statement. 
+            # back to user if they are not related to pyspark checks
+            elif re.search(r"
=0.26.1",
+    "pandera>=0.26.1",
 ]
 
 [tool.setuptools.packages.find]
@@ -53,8 +53,8 @@ dev = [
     "bump_my_version",
     "pre-commit",
     "polars",
-    "pandera[polars]",
-    "pyarrow"
+    "pyarrow",
+    "grpcio"
 ]
 docs = [
     "mkdocs",
@@ -65,12 +65,11 @@ docs = [
     "mkdocs-mermaid2-plugin"
 ]
 polars = [
-    "polars",
-    "pandera[polars]",
+    "polars>=0.20.0",
     "pyarrow"
 ]
 pyspark = [
-    "pandera[pyspark]"
+    "grpcio"
 ]
 
 [tool.setuptools]
diff --git a/tests/test_pyspark_validator.py b/tests/test_pyspark_validator.py
new file mode 100644
index 0000000..7357222
--- /dev/null
+++ b/tests/test_pyspark_validator.py
@@ -0,0 +1,174 @@
+import importlib.util
+import os
+
+import pandas as pd
+import pytest
+
+
+@pytest.mark.skipif(
+    importlib.util.find_spec("pyspark") is None,
+    reason="pyspark is not installed",
+)
+class TestPysparkValidator:
+    def setup_method(self):
+        from pyspark.sql import SparkSession
+
+        self.spark = SparkSession.builder.master("local[1]").appName("Test").getOrCreate()
+
+    def test_pyspark_validator(self):
+        import pyspark.sql.types as T
+
+        from datachecker.data_checkers.pyspark_validator import PySparkValidator
+
+        df = pd.DataFrame(
+            {
+                "id": [1, 2, 3, 2],
+                "name": ["Alice", "Bob", "Charlie", "Bob"],
+                "score": [90.5, 82.0, 95.25, 82.0],
+                "passed": [True, True, True, True],
+            }
+        )
+        spark_df = self.spark.createDataFrame(df)
+        spark_df = spark_df.withColumn("id", spark_df["id"].cast(T.IntegerType()))
+        spark_df = spark_df.withColumn("name", spark_df["name"].cast(T.StringType()))
+        spark_df = spark_df.withColumn("score", spark_df["score"].cast(T.FloatType()))
+        spark_df = spark_df.withColumn("passed", spark_df["passed"].cast(T.BooleanType()))
+
+        schema = {
+            "check_duplicates": True,
+            "check_completeness": True,
+            "columns": {
+                "id": {"type": "int", "nullable": False},
+                "name": {"type": "string", "nullable": False},
+                "score": {"type": "float", "nullable": False},
+                "passed": {"type": "bool", "nullable": False},
+            },
+        }
+
+        new_validator = PySparkValidator(
+            schema=schema, data=spark_df, file="temp.html", format="html", hard_check=False
+        )
+        new_validator.validate()
+        new_validator.export()
+
+        assert isinstance(new_validator, PySparkValidator)
+        assert len(new_validator.log) > 0
+        assert os.path.exists("temp.html")
+
+        # Clean up
+        os.remove("temp.html")
+
+    @pytest.mark.xfail(
+        reason="Decimal checks seem to be broken in pyspark, need to investigate further"
+    )
+    def test_pyspark_all_dtypes(self):
+        import pyspark.sql.types as T
+
+        from datachecker.data_checkers.pyspark_validator import PySparkValidator
+
+        df = pd.DataFrame(
+            {
+                "id": [1, 2, 3, 2],
+                "name": ["Alice", "Bob", "@", "Bob"],
+                "score": [90.5, 82.0, 95.2, 82.0],
+                "passed": [True, True, True, True],
+            }
+        )
+        spark_df = self.spark.createDataFrame(df)
+        spark_df = spark_df.withColumn("id", spark_df["id"].cast(T.IntegerType()))
+        spark_df = spark_df.withColumn("name", spark_df["name"].cast(T.StringType()))
+        spark_df = spark_df.withColumn("score", spark_df["score"].cast(T.FloatType()))
+        spark_df = spark_df.withColumn("passed", spark_df["passed"].cast(T.BooleanType()))
+
+        schema = {
+            "check_duplicates": True,
+            "check_completeness": True,
+            "columns": {
+                "id": {
+                    "type": T.IntegerType(),
+                    "allow_na": False,
+                    "max_val": 2,
+                    "min_val": 0,
+                    "optional": False,
+                },
+                "name": {
+                    "type": T.StringType(),
+                    "allow_na": False,
+                    "optional": False,
+                    "min_length": 4,
+                    "max_length": 10,
+                    # allowed strings not working for pyspark with regex
+                    "allowed_strings": ["Alice", "Bob", "Charlie"],
+                },
+                "score": {
+                    "type": T.FloatType(),
+                    "allow_na": False,
+                    "min_val": 0,
+                    "max_val": 95,
+                    "max_decimal": 3,
+                    "min_decimal": 2,
+                    "optional": False,
+                },
+                "passed": {"type": T.BooleanType(), "allow_na": False, "optional": False},
+            },
+        }
+        new_validator = PySparkValidator(
+            schema=schema, data=spark_df, file="temp.html", format="html", hard_check=False
+        )
+        new_validator.validate()
+
+        entries_with_fails = [
+            entry for entry in new_validator.log[1:] if "fail" in str(entry["outcome"]).lower()
+        ]
+
+        assert len(new_validator.log) == 23
+        # Decimal checks seem to be broken currently have 6 fails, expect
+        # one extra for the decimal check
+        assert len(entries_with_fails) == 7
+        assert [entry["description"] for entry in entries_with_fails] == [
+            "Checking id less than or equal to 2",
+            "Checking name contains only ['Alice', 'Bob', 'Charlie']",
+            "Checking name string length greater than or equal to 4",
+            "Checking score less than or equal to 95",
+            "Checking score has at least 2 decimal places",
+            "Checking for duplicate rows in the dataframe",
+            "Checking for missing rows in the dataframe columns: id, name, score, passed",
+        ]
+
+    def test_pyspark_validate_boilerplate_checks(self):
+        import pyspark.sql.types as T
+
+        from datachecker.data_checkers.pyspark_validator import PySparkValidator
+
+        df = pd.DataFrame(
+            {
+                "id": [1, 2, 7, 5],
+            }
+        )
+        spark_df = self.spark.createDataFrame(df)
+        spark_df = spark_df.withColumn("id", spark_df["id"].cast(T.IntegerType()))
+
+        schema = {
+            "check_duplicates": False,
+            "check_completeness": False,
+            "columns": {
+                "id": {
+                    "type": "int",
+                    "allow_na": False,
+                    "max_val": 2,
+                    "min_val": 0,
+                    "optional": False,
+                },
+            },
+        }
+        new_validator = PySparkValidator(
+            schema=schema, data=spark_df, file="temp.json", format="json", hard_check=False
+        )
+        new_validator.validate()
+        entries_with_fails = [
+            entry for entry in new_validator.log[1:-1] if "fail" in str(entry["outcome"]).lower()
+        ]
+        assert len(entries_with_fails) == 1
+        assert "Pyspark does not return cases or index" in str(
+            entries_with_fails[0]["failing_ids"][0]
+        )
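
A minimal usage sketch distilled from the added tests, for quick reference. It assumes pyspark is installed and a local Spark session can be created; the PySparkValidator signature and the schema keys are taken from tests/test_pyspark_validator.py, while the data values and the output file name are illustrative. As in the tests, columns may need to be cast to the exact Spark types the schema declares.

import pandas as pd
from pyspark.sql import SparkSession

from datachecker.data_checkers.pyspark_validator import PySparkValidator

# Local Spark session, mirroring the test's setup_method
spark = SparkSession.builder.master("local[1]").appName("Example").getOrCreate()
spark_df = spark.createDataFrame(
    pd.DataFrame({"id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"]})
)

# Schema dict in the same format used by test_pyspark_validator
schema = {
    "check_duplicates": True,
    "check_completeness": True,
    "columns": {
        "id": {"type": "int", "nullable": False},
        "name": {"type": "string", "nullable": False},
    },
}

validator = PySparkValidator(
    schema=schema, data=spark_df, file="temp.html", format="html", hard_check=False
)
validator.validate()  # runs the checks, then rewrites pyspark failure messages in the log
validator.export()    # writes the report to temp.html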