Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions .github/workflows/build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ on:
tags:
- 'v*'
pull_request:
branches:
- main


jobs:
Expand All @@ -33,7 +35,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.10', '3.11', '3.12', "3.13"]
python-version: ['3.12', "3.13"]


steps:
Expand All @@ -47,6 +49,7 @@ jobs:
python -m pip install --upgrade pip
pip install wheel
pip install --upgrade-strategy eager -e .[dev]
pip install pyspark[connect]
- name: Build
run: |
pip install build
Expand All @@ -61,7 +64,7 @@ jobs:
runs-on: windows-latest
strategy:
matrix:
python-version: ['3.10', '3.11', '3.12', "3.13"]
python-version: ['3.12', "3.13"]


steps:
Expand Down
41 changes: 37 additions & 4 deletions datachecker/checks_loaders_and_exporters/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ def get_dtype_lib(df):

return pap

if (mod, name) == ("pyspark.sql.dataframe", "DataFrame"):
if (mod, name) == ("pyspark.sql.dataframe", "DataFrame") or (mod, name) == (
"pyspark.sql.classic.dataframe",
"DataFrame",
):
import pandera.pyspark as paspk

return paspk
Expand Down Expand Up @@ -352,10 +355,11 @@ def validate_using_pandera(
# ISSUE - NA pass not present in output log

try:
converted_schema.validate(data, lazy=True)
df_output = converted_schema.validate(data, lazy=True)

# The following code is to add all checks when validation passes
grouped_validation_return = None
# The following code is to add all checks when validation passes or pyspark
# validation
grouped_validation_return = process_pyspark_errors(df_output)
except get_dtype_lib(data).errors.SchemaErrors as e:
# validation_return is now a pandas dataframe
validation_return = e.failure_cases[["column", "check", "failure_case", "index"]]
Expand Down Expand Up @@ -431,3 +435,32 @@ def convert_schema_into_log_entries(converted_schema: pa.DataFrameSchema) -> pd.
"invalid_ids": [[]] * len(list_of_checks),
}
)


def process_pyspark_errors(df_output):
    """Convert pandera-on-pyspark validation errors into a grouped DataFrame.

    pandera's pyspark backend does not raise ``SchemaErrors`` the way the
    pandas backend does; instead it attaches a nested dict to the validated
    dataframe: ``df_output.pandera.errors`` maps error category -> check name
    -> list of error records (each record has ``column``/``check``/``error``).

    Args:
        df_output: The object returned by ``converted_schema.validate``.

    Returns:
        ``None`` when ``df_output`` is not a pyspark dataframe or no checks
        failed; otherwise a pandas DataFrame with columns
        ``["column", "check", "failure_case", "invalid_ids"]``. pyspark does
        not report failing row indices, so ``invalid_ids`` holds ``None``
        placeholders.
    """
    # Only pyspark dataframes carry the ``pandera.errors`` accessor used below.
    if "pyspark" not in str(type(df_output)):
        return None

    errors = df_output.pandera.errors
    if not errors:
        # Covers both ``None`` and ``{}`` -- no failed validation checks.
        return None

    # Flatten category -> check -> [records] into one frame. Collecting the
    # pieces and concatenating once avoids the quadratic cost of calling
    # ``pd.concat`` inside the loop.
    frames = [
        pd.DataFrame(records)
        for checks in errors.values()
        for records in checks.values()
    ]
    if not frames:
        # errors dict had categories but no check entries -- nothing failed.
        return None
    failed_cases = pd.concat(frames, ignore_index=True)
    if failed_cases.empty:
        # All record lists were empty; avoid a KeyError on the columns below.
        return None

    validation_return = pd.DataFrame(
        {
            "column": failed_cases["column"],
            "check": failed_cases["check"],
            "failure_case": failed_cases["error"],
            # pyspark does not expose the failing row indices.
            "index": None,
        }
    )
    grouped_validation_return = (
        validation_return.groupby(["column", "check"])
        .agg({"failure_case": list, "index": list})
        .reset_index()
        .rename(columns={"index": "invalid_ids"})
    )
    return grouped_validation_return
80 changes: 49 additions & 31 deletions datachecker/data_checkers/pyspark_validator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import pandas as pd
import re

import pyspark.sql.types as T
from pyspark.sql import functions as F

from datachecker.data_checkers.general_validator import Validator
Expand All @@ -14,8 +16,53 @@ def __init__(
hard_check: bool = True,
custom_checks: dict = None,
):
raise NotImplementedError("PySpark support is not implemented yet")
# raise NotImplementedError("PySpark support is not implemented yet")
super().__init__(schema, data, file, format, hard_check, custom_checks)
self._convert_schema_dtypes()

def validate(self):
    """Run the standard validation, then post-process pyspark error messages.

    Delegates to the parent ``validate`` and afterwards rewrites the raw
    ``<Schema Column ...>`` failure strings that the pyspark pandera backend
    produces into a human-readable placeholder
    (see ``_convert_pyspark_error_messages``).
    """
    super().validate()
    self._convert_pyspark_error_messages()

def _convert_pyspark_error_messages(self):
message = "Pyspark does not return cases or index"
for i in range(1, len(self.log) - 1):
entry = self.log[i]
if (
entry["failing_ids"] is None
or len(entry["failing_ids"]) == 0
or not isinstance(entry["failing_ids"][0], str)
):
continue
# <Schema Column ...> is the message when a check fails for pyspark
# replace it with blanket statement. Should still pass other important errors
# back to user if they are not related to pyspark checks
elif re.search(r"<Schema Column", entry["failing_ids"][0]) is not None:
entry["failing_ids"][0] = message
else:
continue

def _convert_schema_dtypes(self):
    """Translate string dtype names in the schema into pyspark type objects.

    Without this conversion the schema looks well-formed but pandera does not
    actually trigger the type checks (or the other checks) for pyspark
    dataframes. Columns whose type is already a pyspark type instance are
    accepted and left untouched.

    Raises:
        ValueError: when a column's declared type is neither a known string
            alias nor an existing pyspark type instance.
    """
    # NOTE(review): "datetime" maps to DateType while "timestamp" maps to
    # TimestampType -- confirm "datetime" should not also be TimestampType.
    alias_to_pyspark = {
        "int": T.IntegerType(),
        "float": T.FloatType(),
        "string": T.StringType(),
        "str": T.StringType(),
        "bool": T.BooleanType(),
        "date": T.DateType(),
        "datetime": T.DateType(),
        "timestamp": T.TimestampType(),
    }
    columns = self.schema.get("columns", {})
    for col in columns:
        declared = columns[col].get("type")
        known_alias = declared in alias_to_pyspark
        already_pyspark = declared in alias_to_pyspark.values()
        if not known_alias and not already_pyspark:
            raise ValueError(
                f"Unsupported data type '{declared}' for column '{col}' in schema. "
                f"Supported types are: {list(alias_to_pyspark.keys())}"
            )
        # Map string aliases; leave pyspark type instances as they are.
        columns[col]["type"] = alias_to_pyspark.get(declared, declared)

def _check_duplicates(self):
# Check for duplicate rows in the dataframe
Expand Down Expand Up @@ -70,32 +117,3 @@ def _check_completeness(self):
outcome=result,
entry_type="error",
)


if __name__ == "__main__":
    # Example usage (pandas)

    # Small demo frame containing one duplicated row (1, "A") so the
    # duplicate check has something to report.
    data = pd.DataFrame(
        [
            (1, "A"),
            (2, "B"),
            (1, "A"),  # Duplicate row
            (3, "C"),
        ],
        columns=["id", "value"],
    )

    # Schema enabling duplicate and completeness checks on both columns.
    # NOTE(review): the type name "integer" is not one of the string aliases
    # handled by ``_convert_schema_dtypes`` ("int" is) -- confirm this demo
    # still runs against the pyspark validator.
    schema = {
        "check_duplicates": True,
        "check_completeness": True,
        "completeness_columns": ["id", "value"],
        "columns": {
            "id": {"type": "integer", "check_duplicates": True},
            "value": {"type": "string", "check_duplicates": True},
        },
    }

    # Run all configured checks and print every entry of the QA report.
    validator = PySparkValidator(schema, data, "datafile.csv", "csv")
    validator.run_checks()
    for entry in validator.qa_report:
        print(entry)
11 changes: 5 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dependencies = [
"pandas<3.0.0",
"tomli",
"jinja2",
"pandera[pandas]>=0.26.1",
"pandera>=0.26.1",
]

[tool.setuptools.packages.find]
Expand Down Expand Up @@ -53,8 +53,8 @@ dev = [
"bump_my_version",
"pre-commit",
"polars",
"pandera[polars]",
"pyarrow"
"pyarrow",
"grpcio"
]
docs = [
"mkdocs",
Expand All @@ -65,12 +65,11 @@ docs = [
"mkdocs-mermaid2-plugin"
]
polars = [
"polars",
"pandera[polars]",
"polars>=0.20.0",
"pyarrow"
]
pyspark = [
"pandera[pyspark]"
"grpcio"
]

[tool.setuptools]
Expand Down
Loading