7 changes: 5 additions & 2 deletions .github/workflows/build-and-test.yml
@@ -7,6 +7,8 @@ on:
     tags:
       - 'v*'
   pull_request:
+    branches:
+      - main


 jobs:
@@ -33,7 +35,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ['3.10', '3.11', '3.12', "3.13"]
+        python-version: ['3.12', "3.13"]


     steps:
@@ -47,6 +49,7 @@ jobs:
           python -m pip install --upgrade pip
           pip install wheel
           pip install --upgrade-strategy eager -e .[dev]
+          pip install pyspark[connect]
       - name: Build
         run: |
           pip install build
@@ -61,7 +64,7 @@
     runs-on: windows-latest
     strategy:
       matrix:
-        python-version: ['3.10', '3.11', '3.12', "3.13"]
+        python-version: ['3.12', "3.13"]


     steps:
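Note on the new CI dependency: the `pyspark[connect]` extra layers the Spark Connect client stack (grpcio, among others) on top of plain pyspark. A quick, hypothetical sanity check for a local environment, not part of this workflow, might look like:

import importlib.util

# Both specs should resolve once `pip install pyspark[connect]` has run;
# "grpc" is the import name of the grpcio distribution.
for spec in ("pyspark.sql.connect", "grpc"):
    print(spec, "ok" if importlib.util.find_spec(spec) else "missing")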
5 changes: 4 additions & 1 deletion datachecker/checks_loaders_and_exporters/checks.py
@@ -20,7 +20,10 @@ def get_dtype_lib(df):

         return pap

-    if (mod, name) == ("pyspark.sql.dataframe", "DataFrame"):
+    if (mod, name) == ("pyspark.sql.dataframe", "DataFrame") or (mod, name) == (
+        "pyspark.sql.classic.dataframe",
+        "DataFrame",
+    ):
         import pandera.pyspark as paspk

         return paspk
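For context: `get_dtype_lib` dispatches on the DataFrame's module and class name (presumably derived from `type(df)` earlier in the function). In Spark 4.x the concrete classic DataFrame moved to `pyspark.sql.classic.dataframe`, so both spellings have to map to `pandera.pyspark`. A minimal sketch of the idea, assuming `(mod, name)` come from the object's type:

# Illustrative sketch only; the real get_dtype_lib derives (mod, name) itself.
def looks_like_spark_df(obj) -> bool:
    mod, name = type(obj).__module__, type(obj).__name__
    return (mod, name) in {
        ("pyspark.sql.dataframe", "DataFrame"),          # Spark <= 3.x
        ("pyspark.sql.classic.dataframe", "DataFrame"),  # Spark 4.x classic API
    }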
17 changes: 10 additions & 7 deletions datachecker/data_checkers/pyspark_validator.py
@@ -1,4 +1,5 @@
 import pandas as pd
+from pyspark.sql import SparkSession
 from pyspark.sql import functions as F

 from datachecker.data_checkers.general_validator import Validator
@@ -14,7 +15,7 @@ def __init__(
         hard_check: bool = True,
         custom_checks: dict = None,
     ):
-        raise NotImplementedError("PySpark support is not implemented yet")
+        # raise NotImplementedError("PySpark support is not implemented yet")
         super().__init__(schema, data, file, format, hard_check, custom_checks)

     def _check_duplicates(self):
@@ -75,27 +76,29 @@ def _check_completeness(self):
 if __name__ == "__main__":
     # Example usage (pandas)

+    spark = SparkSession.builder.master("local[2]").appName("create-DFs").getOrCreate()
+
     data = pd.DataFrame(
         [
             (1, "A"),
-            (2, "B"),
+            (2.1, "B"),
             (1, "A"),  # Duplicate row
             (3, "C"),
         ],
         columns=["id", "value"],
     )
+    data_spark = spark.createDataFrame(data)

     schema = {
         "check_duplicates": True,
         "check_completeness": True,
         "completeness_columns": ["id", "value"],
         "columns": {
-            "id": {"type": "integer", "check_duplicates": True},
+            "id": {"type": "int", "check_duplicates": True},
             "value": {"type": "string", "check_duplicates": True},
         },
     }

-    validator = PySparkValidator(schema, data, "datafile.csv", "csv")
-    validator.run_checks()
-    for entry in validator.qa_report:
-        print(entry)
+    validator = PySparkValidator(schema, data_spark, "datafile.csv", "csv")
+    validator.validate()
+    print(validator)
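As an aside, the duplicate-row check that `_check_duplicates` presumably performs can be expressed in PySpark roughly like this (a minimal sketch, not necessarily the validator's actual implementation):

from pyspark.sql import DataFrame
from pyspark.sql import functions as F

def count_duplicate_rows(df: DataFrame) -> int:
    # Count rows beyond the first occurrence of each distinct row.
    dupes = df.groupBy(df.columns).count().filter(F.col("count") > 1)
    row = dupes.agg(F.coalesce(F.sum(F.col("count") - 1), F.lit(0)).alias("n")).first()
    return row["n"]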
11 changes: 5 additions & 6 deletions pyproject.toml
@@ -16,7 +16,7 @@ dependencies = [
"pandas<3.0.0",
"tomli",
"jinja2",
"pandera[pandas]>=0.26.1",
"pandera>=0.26.1",
]

[tool.setuptools.packages.find]
@@ -53,8 +53,8 @@ dev = [
"bump_my_version",
"pre-commit",
"polars",
"pandera[polars]",
"pyarrow"
"pyarrow",
"grpcio"
]
docs = [
"mkdocs",
@@ -65,12 +65,11 @@ docs = [
"mkdocs-mermaid2-plugin"
]
polars = [
"polars",
"pandera[polars]",
"polars>=0.20.0",
"pyarrow"
]
pyspark = [
"pandera[pyspark]"
"grpcio"
]

[tool.setuptools]
109 changes: 109 additions & 0 deletions tests/test_pyspark_validator.py
@@ -0,0 +1,109 @@
+import importlib.util
+import os
+
+import pandas as pd
+import pytest
+
+
+@pytest.mark.skipif(
+    importlib.util.find_spec("pyspark") is None,
+    reason="pyspark is not installed",
+)
+class TestPysparkValidator:
+    def setup_method(self):
+        from pyspark.sql import SparkSession
+
+        self.spark = SparkSession.builder.master("local[1]").appName("Test").getOrCreate()
+
+    def test_pyspark_validator(self):
+        from datachecker.data_checkers.pyspark_validator import PySparkValidator
+
+        df = pd.DataFrame(
+            {
+                "id": [1, 2, 3, 2],
+                "name": ["Alice", "Bob", "Charlie", "Bob"],
+                "score": [90.5, 82.0, 95.25, 82.0],
+                "passed": [True, True, True, True],
+            }
+        )
+        spark_df = self.spark.createDataFrame(df)
+
+        schema = {
+            "check_duplicates": True,
+            "check_completeness": True,
+            "columns": {
+                "id": {"type": "int", "nullable": False},
+                "name": {"type": "str", "nullable": False},
+                "score": {"type": "float", "nullable": False, "min": 0, "max": 100},
+                "passed": {"type": "bool", "nullable": False},
+            },
+        }
+
+        new_validator = PySparkValidator(
+            schema=schema, data=spark_df, file="temp.html", format="html", hard_check=False
+        )
+        new_validator.validate()
+        new_validator.export()
+
+        assert isinstance(new_validator, PySparkValidator)
+        assert len(new_validator.log) > 0
+        assert os.path.exists("temp.html")
+
+        # Clean up
+        os.remove("temp.html")
+
+    def test_pyspark_all_dtypes(self):
+        from datachecker.data_checkers.pyspark_validator import PySparkValidator
+
+        df = pd.DataFrame(
+            {
+                "id": [1, 2, 3, 2],
+                "name": ["Alice", "Bob", "Charlie", "Bob"],
+                "score": [90.5, 82.0, 95.25, 82.0],
+                "passed": [True, True, True, True],
+            }
+        )
+        spark_df = self.spark.createDataFrame(df)
+
+        schema = {
+            "check_duplicates": True,
+            "check_completeness": True,
+            "columns": {
+                "id": {
+                    "type": "int",
+                    "allow_na": False,
+                    "max_val": 2,
+                    "min_val": 0,
+                    "optional": False,
+                },
+                "name": {
+                    "type": "str",
+                    "allow_na": False,
+                    "optional": False,
+                    "min_length": 4,
+                    "max_length": 10,
+                },
+                "score": {
+                    "type": "float",
+                    "allow_na": False,
+                    "min_val": 0,
+                    "max_val": 100,
+                    "max_decimal": 5,
+                    "min_decimal": 2,
+                    "optional": False,
+                },
+                "passed": {"type": "bool", "allow_na": False, "optional": False},
+            },
+        }
+        new_validator = PySparkValidator(
+            schema=schema, data=spark_df, file="temp.html", format="html", hard_check=False
+        )
+        new_validator.validate()
+        new_validator.export()
+
+        assert isinstance(new_validator, PySparkValidator)
+        assert len(new_validator.log) > 0
+        assert os.path.exists("temp.html")
+
+        # Clean up
+        os.remove("temp.html")