Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion datachecker/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .data_checkers.general_validator import Validator
from .data_checkers.pandas_validator import DataValidator
from .data_checkers.polars_validator import PolarsValidator
from .data_checkers.pyspark_validator import PySparkValidator
from .main import check_and_export

__all__ = ["DataValidator", "PolarsValidator", "Validator", "check_and_export"]
__all__ = ["DataValidator", "PolarsValidator", "PySparkValidator", "Validator", "check_and_export"]
10 changes: 6 additions & 4 deletions datachecker/data_checkers/pyspark_validator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import re

import pyspark.sql.types as T
from pyspark.sql import functions as F

from datachecker.data_checkers.general_validator import Validator


Expand All @@ -16,7 +13,6 @@ def __init__(
hard_check: bool = True,
custom_checks: dict = None,
):
# raise NotImplementedError("PySpark support is not implemented yet")
super().__init__(schema, data, file, format, hard_check, custom_checks)
self._convert_schema_dtypes()

Expand All @@ -43,6 +39,8 @@ def _convert_pyspark_error_messages(self):
continue

def _convert_schema_dtypes(self):
import pyspark.sql.types as T

mapping_dtypes = {
"int": T.IntegerType(),
"float": T.FloatType(),
Expand All @@ -65,6 +63,8 @@ def _convert_schema_dtypes(self):
)

def _check_duplicates(self):
from pyspark.sql import functions as F

# Check for duplicate rows in the dataframe
if self.schema.get("check_duplicates", False):
# Find duplicate rows (based on all columns)
Expand All @@ -81,6 +81,8 @@ def _check_duplicates(self):
)

def _check_completeness(self):
from pyspark.sql import functions as F

if self.schema.get("check_completeness", False):
cols_to_check = self.schema.get("completeness_columns", self.data.columns)

Expand Down
25 changes: 20 additions & 5 deletions datachecker/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import pandas as pd
import polars as pl

from datachecker.checks_loaders_and_exporters.checks import _type_id
from datachecker.data_checkers.general_validator import Validator
from datachecker.data_checkers.pandas_validator import DataValidator
from datachecker.data_checkers.polars_validator import PolarsValidator
from datachecker.data_checkers.pyspark_validator import PySparkValidator


def check_and_export(schema, data, file, format, hard_check=True, custom_checks=None) -> Validator:
Expand All @@ -30,7 +29,9 @@ def check_and_export(schema, data, file, format, hard_check=True, custom_checks=
DataValidator
Returns data validator object after validation and export.
"""
if type(data) is pl.DataFrame:
mod, name = _type_id(data)

if (mod, name) == ("polars.dataframe.frame", "DataFrame"):
validator = PolarsValidator(
schema=schema,
data=data,
Expand All @@ -39,7 +40,8 @@ def check_and_export(schema, data, file, format, hard_check=True, custom_checks=
hard_check=hard_check,
custom_checks=custom_checks,
)
elif type(data) is pd.DataFrame:

if (mod, name) == ("pandas.core.frame", "DataFrame"):
validator = DataValidator(
schema=schema,
data=data,
Expand All @@ -49,6 +51,19 @@ def check_and_export(schema, data, file, format, hard_check=True, custom_checks=
custom_checks=custom_checks,
)

if (mod, name) == ("pyspark.sql.dataframe", "DataFrame") or (mod, name) == (
"pyspark.sql.classic.dataframe",
"DataFrame",
):
validator = PySparkValidator(
schema=schema,
data=data,
file=file,
format=format,
hard_check=hard_check,
custom_checks=custom_checks,
)

validator.validate()
validator.export()
return validator
59 changes: 58 additions & 1 deletion docs/user_guide/examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ schema = {
}
}
}

```

Next we need to load our dataset; for this example we will instead create our dataframe within our script. Note that for this example we have actually created the `is_active` column as a boolean and not an integer as outlined in our schema. This should be picked up in our validation checks! Also, one email is slightly incorrect and one name contains a number.
Expand Down Expand Up @@ -188,3 +187,61 @@ From out yaml output we can see it failed 5 checks. These were:
4. an invalid email address was found in row 2
5. data type of is_active was a boolean when it was expecting an integer.

### Validating a Spark DataFrame

We can also perform validation checks on a Spark DataFrame using the `PySparkValidator` class. To do this we will first start a Spark session and use the existing `data` variable to create a Spark DataFrame.

```Python
from pyspark.sql import SparkSession

spark = (
SparkSession.builder.master("local")
.appName("local_session")
.getOrCreate()
)

sdf = spark.createDataFrame(data)
sdf.show()
```

```
+---+--------------------+---------+------------+
|age| email|is_active| name|
+---+--------------------+---------+------------+
| 30|john.doe@example.com| true| John Doe|
| 25|jane.smith@exampl...| false| Jane Smith|
| 40| alice.brown.com| true| Alice Brown|
|-22|bob.white@example...| false| Bob White|
| 35|carol.green@examp...| true|Carol Green1|
| 28|eve.black@example...| false| Eve Black|
+---+--------------------+---------+------------+
```

We can now create an instance of `PySparkValidator`, using the new Spark DataFrame and the existing schema as parameters.

```Python
from datachecker import PySparkValidator

validator = PySparkValidator(
schema=schema,
data=sdf,
file="pyspark_report.yaml",
format="yaml",
hard_check=False
)
validator.validate()
validator.export()
print(validator)
```

`check_and_export` is also able to take a Spark DataFrame as a parameter to export a validation report in a single function call.

```Python
check_and_export(
schema=schema,
data=sdf,
file="pyspark_report.html",
format="html",
hard_check=False,
)
```
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ classifiers = [
"Programming Language :: Python :: 3.12"
]
dependencies = [
"findspark",
"pyyaml",
"pandas<3.0.0",
"tomli",
Expand Down
6 changes: 5 additions & 1 deletion tests/test_pyspark_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@
)
class TestPysparkValidator:
def setup_method(self):
import findspark
from pyspark.sql import SparkSession

self.spark = SparkSession.builder.master("local[1]").appName("Test").getOrCreate()
# This helps for running local Spark sessions
findspark.init()

self.spark = SparkSession.builder.master("local").appName("Test").getOrCreate()

def test_pyspark_validator(self):
import pyspark.sql.types as T
Expand Down