
SparkDatasetV2

SparkDatasetV2 loads and saves data using Apache Spark DataFrames with improved support for Spark Connect, Databricks Connect, and automatic Pandas DataFrame conversion.

kedro_datasets.spark.SparkDatasetV2

SparkDatasetV2(
    *,
    filepath,
    file_format="parquet",
    load_args=None,
    save_args=None,
    version=None,
    credentials=None,
    metadata=None
)

Bases: AbstractVersionedDataset

SparkDatasetV2 loads and saves Spark dataframes.

Examples:

Using the YAML API:

weather:
  type: spark.SparkDatasetV2
  filepath: s3a://your_bucket/data/01_raw/weather/*
  file_format: csv
  load_args:
    header: True
    inferSchema: True
  save_args:
    sep: '|'
    header: True

weather_with_schema:
  type: spark.SparkDatasetV2
  filepath: s3a://your_bucket/data/01_raw/weather/*
  file_format: csv
  load_args:
    header: True
    schema:
      filepath: path/to/schema.json
  save_args:
    sep: '|'
    header: True

weather_cleaned:
  type: spark.SparkDatasetV2
  filepath: data/02_intermediate/data.parquet
  file_format: parquet

# Databricks with Unity Catalog
unity_data:
  type: spark.SparkDatasetV2
  filepath: /Volumes/catalog/schema/volume/data.parquet

# Databricks with DBFS
dbfs_data:
  type: spark.SparkDatasetV2
  filepath: /dbfs/mnt/data/output.parquet

Using the Python API:

>>> import tempfile
>>> from kedro_datasets.spark import SparkDatasetV2
>>> from pyspark.sql import Row, SparkSession
>>> from pyspark.sql.types import IntegerType, StringType, StructField, StructType
>>>
>>> schema = StructType(
...     [StructField("name", StringType(), True), StructField("age", IntegerType(), True)]
... )
>>> data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)]
>>> spark_df = SparkSession.builder.getOrCreate().createDataFrame(data, schema)
>>>
>>> with tempfile.TemporaryDirectory() as tmp_dir:
...     filepath = f"{tmp_dir}/test_data"
...     dataset = SparkDatasetV2(filepath=filepath)
...     dataset.save(spark_df)
...     reloaded = dataset.load()
...     assert Row(name="Bob", age=12) in reloaded.take(4)

You can also save Pandas DataFrames directly; they are converted to Spark DataFrames before writing:

>>> import pandas as pd
>>> pandas_df = pd.DataFrame({"name": ["Alex", "Bob"], "age": [31, 12]})
>>> with tempfile.TemporaryDirectory() as tmp_dir:
...     dataset = SparkDatasetV2(filepath=f"{tmp_dir}/pandas_data")
...     dataset.save(pandas_df)  # converted with SparkSession.createDataFrame

Parameters:

  • filepath (str) –

    Filepath in POSIX format to a Spark dataframe. Supported forms:

      - Local paths: data/output.parquet or /absolute/path/data.parquet
      - S3: s3://bucket/path or s3a://bucket/path
      - GCS: gs://bucket/path
      - Azure: abfs://container@account.dfs.core.windows.net/path
      - Databricks DBFS: /dbfs/path or dbfs:/path
      - Unity Catalog: /Volumes/catalog/schema/volume/path

  • file_format (str, default: 'parquet' ) –

    File format used during load and save operations. Any format supported by the running SparkContext can be used, such as parquet, csv, or delta. For a list of supported formats, refer to the Apache Spark documentation at https://spark.apache.org/docs/latest/sql-programming-guide.html

  • load_args (dict[str, Any] | None, default: None ) –

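    Load arguments, passed as options to Spark's DataFrameReader. Valid options depend on the selected file format. The schema key is handled separately: it may be a StructType or a dict with a filepath key (as in the YAML example above), and it is applied with DataFrameReader.schema() instead of being passed as an option. You can find a list of read options for each supported format in the Spark DataFrame read documentation: https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_df.html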

  • save_args (dict[str, Any] | None, default: None ) –

    Save arguments, passed as options to Spark's DataFrameWriter. As with load_args, valid options depend on the selected file format. The mode and partitionBy keys are removed from the options and applied through DataFrameWriter.mode() and DataFrameWriter.partitionBy() to set the write mode and partitioning. You can find a list of options for each format in the Spark DataFrame write documentation: https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_df.html

  • version (Version | None, default: None ) –

    If specified, should be an instance of kedro.io.core.Version. If its load attribute is None, the latest version is loaded. If its save attribute is None, a save version is autogenerated.

  • credentials (dict[str, Any] | None, default: None ) –

    Credentials to access cloud storage, such as key, secret for S3, token for GCS, or account_key for Azure. Structure depends on the cloud provider.

  • metadata (dict[str, Any] | None, default: None ) –

    Any arbitrary metadata. This is ignored by Kedro, but may be consumed by users or external plugins.
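Relative local paths are resolved to absolute paths before Spark sees them: in the __init__ source below, an empty or "file" protocol with a non-absolute path is resolved with Path.resolve() and the protocol is normalised to "file". The stdlib sketch below reproduces only that step. The split_protocol and normalise_local helpers are hypothetical; split_protocol is a naive stand-in for parse_spark_filepath, whose real rules are not shown here and likely differ (it only understands scheme:// URLs and bare local paths, not dbfs:/ or /Volumes/ forms).

```python
import os
from pathlib import Path


def split_protocol(filepath: str) -> tuple[str, str]:
    """Hypothetical, naive stand-in for parse_spark_filepath: split on '://'."""
    if "://" in filepath:
        protocol, path = filepath.split("://", 1)
        return protocol, path
    return "", filepath


def normalise_local(filepath: str) -> tuple[str, str]:
    """Mirror the relative-path handling in SparkDatasetV2.__init__."""
    protocol, path = split_protocol(filepath)
    if protocol in ("file", "") and not os.path.isabs(path):
        path = str(Path(path).resolve())  # resolved against the working directory
        protocol = "file"  # an empty protocol becomes "file"
    return protocol, path


print(normalise_local("data/02_intermediate/data.parquet"))
print(normalise_local("s3a://your_bucket/data/01_raw/weather"))  # ('s3a', 'your_bucket/data/01_raw/weather')
```

Remote URLs pass through unchanged; only local relative paths are rewritten, so a catalog entry such as data/02_intermediate/data.parquet depends on the directory Kedro is run from.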

Source code in kedro_datasets/spark/spark_dataset_v2.py
def __init__(  # noqa: PLR0913
    self,
    *,
    filepath: str,
    file_format: str = "parquet",
    load_args: dict[str, Any] | None = None,
    save_args: dict[str, Any] | None = None,
    version: Version | None = None,
    credentials: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Creates a new instance of ``SparkDatasetV2``.

    Args:
        filepath: Filepath in POSIX format to a Spark dataframe. Supports:
            - Local paths: ``data/output.parquet`` or ``/absolute/path/data.parquet``
            - S3: ``s3://bucket/path`` or ``s3a://bucket/path``
            - GCS: ``gs://bucket/path``
            - Azure: ``abfs://container@account.dfs.core.windows.net/path``
            - Databricks DBFS: ``/dbfs/path`` or ``dbfs:/path``
            - Unity Catalog: ``/Volumes/catalog/schema/volume/path``
        file_format: File format used during load and save
            operations. Any format supported by the running SparkContext
            can be used, such as parquet, csv, or delta. For a list of supported
            formats, please refer to the Apache Spark documentation at
            https://spark.apache.org/docs/latest/sql-programming-guide.html
        load_args: Load args passed to Spark DataFrameReader load method.
            It is dependent on the selected file format. You can find
            a list of read options for each supported format
            in Spark DataFrame read documentation:
            https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_df.html
        save_args: Save args passed to Spark DataFrame write options.
            Similar to load_args, this depends on the selected file
            format. You can pass ``mode`` and ``partitionBy`` to specify
            your overwrite mode and partitioning respectively. You can find
            a list of options for each format in Spark DataFrame
            write documentation:
            https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_df.html
        version: If specified, should be an instance of
            ``kedro.io.core.Version``. If its ``load`` attribute is
            None, the latest version will be loaded. If its ``save``
            attribute is None, save version will be autogenerated.
        credentials: Credentials to access cloud storage, such as
            ``key``, ``secret`` for S3, ``token`` for GCS, or
            ``account_key`` for Azure. Structure depends on the
            cloud provider.
        metadata: Any arbitrary metadata.
            This is ignored by Kedro, but may be consumed by users or external plugins.
    """
    # Store original filepath for reference
    self._original_filepath = filepath

    # Process credentials
    credentials = deepcopy(credentials) or {}

    # Parse filepath and detect protocol
    protocol, path = parse_spark_filepath(filepath)

    # Handle relative paths for local files
    if protocol in ("file", "") and not os.path.isabs(path):
        path = str(Path(path).resolve())
        protocol = "file"  # Normalise empty to "file"

    self._protocol = protocol
    self._path = path

    # Validate Databricks paths
    validate_databricks_path(filepath)

    # Get filesystem operations
    exists_function, glob_function = self._get_filesystem_ops(
        protocol, filepath, credentials
    )

    # Initialise attributes
    self._file_format = file_format
    self._load_args = {
        **self.DEFAULT_LOAD_ARGS,
        **(deepcopy(load_args) if load_args is not None else {}),
    }
    self._save_args = {
        **self.DEFAULT_SAVE_ARGS,
        **(deepcopy(save_args) if save_args is not None else {}),
    }
    self._credentials = credentials
    self.metadata = metadata

    # Handle schema - can be a dict with filepath or a StructType directly
    self._schema: StructType | None = self._load_args.pop("schema", None)
    if self._schema is not None and isinstance(self._schema, dict):
        self._schema = load_spark_schema_from_file(self._schema)

    # Call parent constructor
    super().__init__(
        filepath=PurePosixPath(path),
        version=version,
        exists_function=exists_function,
        glob_function=glob_function,
    )

    # Validate delta format
    self._handle_delta_format()
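The argument handling in __init__ comes down to three steps: defaults first, user values layered on top of a deep copy, and the special schema key popped out of the load options so it is never passed to .options(). The stdlib sketch below reproduces just that merging. merge_load_args is a hypothetical helper; DEFAULT_LOAD_ARGS is empty, as on the class, and the step where a dict-valued schema is resolved by load_spark_schema_from_file is left out.

```python
from __future__ import annotations

from copy import deepcopy
from typing import Any

DEFAULT_LOAD_ARGS: dict[str, Any] = {}  # empty, as on SparkDatasetV2


def merge_load_args(load_args: dict[str, Any] | None) -> tuple[dict[str, Any], Any]:
    """Merge user load_args over the defaults and split off the schema entry."""
    merged = {
        **DEFAULT_LOAD_ARGS,
        **(deepcopy(load_args) if load_args is not None else {}),
    }
    # "schema" is applied via reader.schema(...), not passed as a reader option
    schema = merged.pop("schema", None)
    return merged, schema


user_args = {"header": True, "schema": {"filepath": "path/to/schema.json"}}
options, schema = merge_load_args(user_args)
print(options)  # {'header': True}
print(schema)  # {'filepath': 'path/to/schema.json'}
```

Because of the deep copy, the caller's dictionary keeps its schema entry. Because the schema is popped, it also does not appear in the load_args reported by _describe().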

DEFAULT_LOAD_ARGS class-attribute instance-attribute

DEFAULT_LOAD_ARGS = {}

DEFAULT_SAVE_ARGS class-attribute instance-attribute

DEFAULT_SAVE_ARGS = {}

_SINGLE_PROCESS class-attribute instance-attribute

_SINGLE_PROCESS = True

_credentials instance-attribute

_credentials = credentials

_file_format instance-attribute

_file_format = file_format

_load_args instance-attribute

_load_args = {
    **self.DEFAULT_LOAD_ARGS,
    **(deepcopy(load_args) if load_args is not None else {}),
}

_original_filepath instance-attribute

_original_filepath = filepath

_path instance-attribute

_path = path

_protocol instance-attribute

_protocol = protocol

_save_args instance-attribute

_save_args = {
    **self.DEFAULT_SAVE_ARGS,
    **(deepcopy(save_args) if save_args is not None else {}),
}

_schema instance-attribute

_schema = self._load_args.pop("schema", None)

_spark_path property

_spark_path

Get the Spark-compatible path for this dataset.

Returns:

  • Path formatted for Spark (e.g., 's3a://bucket/path', 'file:///path').

metadata instance-attribute

metadata = metadata

_describe

_describe()

Describe the dataset configuration.

Returns:

  • dict[str, Any]

    Dictionary with dataset configuration details.

Source code in kedro_datasets/spark/spark_dataset_v2.py
def _describe(self) -> dict[str, Any]:
    """Describe the dataset configuration.

    Returns:
        Dictionary with dataset configuration details.
    """
    return {
        "filepath": to_spark_path(self._protocol, self._path),
        "file_format": self._file_format,
        "load_args": self._load_args,
        "save_args": self._save_args,
        "version": self._version,
    }

_exists

_exists()

Check if the dataset exists.

Returns:

  • bool

    True if the dataset exists, False otherwise.

Source code in kedro_datasets/spark/spark_dataset_v2.py
def _exists(self) -> bool:
    """Check if the dataset exists.

    Returns:
        True if the dataset exists, False otherwise.
    """
    load_path = self._get_load_path()
    spark_load_path = to_spark_path(self._protocol, str(load_path))

    try:
        spark_session = get_spark_with_remote_support()
        # Try to read the metadata without loading data
        spark_session.read.format(self._file_format).load(spark_load_path).schema
        return True
    except Exception as exc:
        # Check for specific error messages indicating non-existence
        error_msg = str(exc).lower()
        non_existence_indicators = [
            "path does not exist",
            "file not found",
            "is not a delta table",
            "no such file",
            "pathnotfoundexception",
            "filenotfoundexception",
        ]
        if any(msg in error_msg for msg in non_existence_indicators):
            return False
        # Re-raise for unexpected errors
        logger.warning(f"Error checking existence of {spark_load_path}: {exc}")
        raise
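_exists() reports "missing" only when the schema probe fails with an error whose text contains one of a fixed set of substrings; any other error is re-raised (the method also logs a warning first). The sketch below isolates that classification, with plain callables standing in for the Spark read. exists_from_probe, missing and broken are hypothetical, and the error messages are made up for illustration.

```python
from typing import Callable

NON_EXISTENCE_INDICATORS = (
    "path does not exist",
    "file not found",
    "is not a delta table",
    "no such file",
    "pathnotfoundexception",
    "filenotfoundexception",
)


def exists_from_probe(probe: Callable[[], object]) -> bool:
    """True if probe() succeeds, False on a 'missing' error, otherwise re-raise."""
    try:
        probe()
        return True
    except Exception as exc:
        error_msg = str(exc).lower()  # matching is case-insensitive
        if any(msg in error_msg for msg in NON_EXISTENCE_INDICATORS):
            return False
        raise  # unexpected errors propagate to the caller


def missing() -> None:
    raise RuntimeError("Path does not exist: file:/tmp/nope")  # made-up message


def broken() -> None:
    raise PermissionError("Access Denied")  # made-up message


print(exists_from_probe(lambda: None))  # True
print(exists_from_probe(missing))  # False
```

Since the check relies on message text, an error worded outside this list (for example a permissions failure) propagates instead of being reported as a missing dataset.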

_get_filesystem_ops

_get_filesystem_ops(protocol, filepath, credentials)

Get filesystem operations for exists and glob.

Parameters:

  • protocol (str) –

    Filesystem protocol.

  • filepath (str) –

    Original filepath.

  • credentials (dict[str, Any]) –

    Credentials for filesystem access.

Returns:

  • tuple

    Tuple of (exists_function, glob_function).

Source code in kedro_datasets/spark/spark_dataset_v2.py
def _get_filesystem_ops(
    self, protocol: str, filepath: str, credentials: dict[str, Any]
) -> tuple:
    """Get filesystem operations for exists and glob.

    Args:
        protocol: Filesystem protocol.
        filepath: Original filepath.
        credentials: Credentials for filesystem access.

    Returns:
        Tuple of (exists_function, glob_function).
    """
    if protocol == "dbfs" and deployed_on_databricks():
        try:
            spark_session = get_spark_with_remote_support()
            dbutils = get_dbutils(spark_session)
            if dbutils:
                logger.debug("Using optimised DBFS operations via dbutils")
                return (
                    partial(dbfs_exists, dbutils=dbutils),
                    partial(dbfs_glob, dbutils=dbutils),
                )
        except Exception as exc:
            logger.warning(f"Failed to get dbutils, falling back to fsspec: {exc}")

    # Regular fsspec for everything else
    fs = get_spark_filesystem(protocol, credentials)
    return fs.exists, fs.glob
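The selection logic above is a "fast path with fallback": dbutils-backed functions for DBFS when running on Databricks, and a generic filesystem in every other case, including when obtaining dbutils fails. The sketch below shows the same control flow with injected callables in place of deployed_on_databricks, get_dbutils and get_spark_filesystem. pick_fs_ops, FakeFS, fake_dbfs_exists and fake_dbfs_glob are hypothetical stand-ins that only tag their output.

```python
from __future__ import annotations

import logging
from functools import partial
from typing import Any, Callable

logger = logging.getLogger(__name__)


# Hypothetical stand-ins for dbfs_exists / dbfs_glob: they just tag their output.
def fake_dbfs_exists(path: str, dbutils: Any) -> str:
    return f"dbutils-exists:{path}"


def fake_dbfs_glob(pattern: str, dbutils: Any) -> list[str]:
    return [f"dbutils-glob:{pattern}"]


class FakeFS:
    """Hypothetical stand-in for an fsspec filesystem."""

    def exists(self, path: str) -> str:
        return f"fsspec-exists:{path}"

    def glob(self, pattern: str) -> list[str]:
        return [f"fsspec-glob:{pattern}"]


def pick_fs_ops(
    protocol: str,
    on_databricks: bool,
    get_dbutils: Callable[[], Any],
    make_fs: Callable[[str], Any],
) -> tuple[Callable[..., Any], Callable[..., Any]]:
    """Prefer dbutils-backed ops for DBFS on Databricks, else the generic filesystem."""
    if protocol == "dbfs" and on_databricks:
        try:
            dbutils = get_dbutils()
            if dbutils:
                # Bind dbutils so callers pass only the path or pattern.
                return (
                    partial(fake_dbfs_exists, dbutils=dbutils),
                    partial(fake_dbfs_glob, dbutils=dbutils),
                )
        except Exception as exc:
            logger.warning("Failed to get dbutils, falling back: %s", exc)
    # Every other case, including a dbutils failure, uses the generic filesystem.
    fs = make_fs(protocol)
    return fs.exists, fs.glob


def failing_dbutils() -> Any:
    raise RuntimeError("dbutils unavailable")


exists, glob = pick_fs_ops("dbfs", True, lambda: object(), lambda p: FakeFS())
print(exists("/mnt/a"))  # dbutils-exists:/mnt/a
exists, glob = pick_fs_ops("dbfs", True, failing_dbutils, lambda p: FakeFS())
print(exists("/mnt/a"))  # fsspec-exists:/mnt/a
```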

_handle_delta_format

_handle_delta_format()

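Validate the save mode for the delta format. Raises DatasetError if a mode is set and is not one of append, overwrite, error, errorifexists, or ignore.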

Source code in kedro_datasets/spark/spark_dataset_v2.py
def _handle_delta_format(self) -> None:
    """Handle delta-specific configurations."""
    supported_modes = {"append", "overwrite", "error", "errorifexists", "ignore"}
    write_mode = self._save_args.get("mode")
    if (
        write_mode
        and self._file_format == "delta"
        and write_mode not in supported_modes
    ):
        raise DatasetError(
            f"It is not possible to perform 'save()' for file format 'delta' "
            f"with mode '{write_mode}' on 'SparkDataset'. "
            f"Please use 'spark.DeltaTableDataset' instead."
        )
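The validation only applies when a mode is set and the format is delta; non-delta formats are not checked here and are left to Spark. A stdlib sketch of the same rule follows. check_delta_mode is a hypothetical helper, DatasetError is a local stand-in for kedro.io.core.DatasetError so the sketch runs without Kedro, and "upsert" is just an example of a mode outside the supported set.

```python
SUPPORTED_DELTA_MODES = {"append", "overwrite", "error", "errorifexists", "ignore"}


class DatasetError(Exception):
    """Local stand-in for kedro.io.core.DatasetError."""


def check_delta_mode(file_format: str, save_args: dict) -> None:
    """Reject delta save modes outside the supported set."""
    write_mode = save_args.get("mode")
    # Checked only when a mode is set *and* the format is delta.
    if write_mode and file_format == "delta" and write_mode not in SUPPORTED_DELTA_MODES:
        raise DatasetError(
            f"Cannot save format 'delta' with mode '{write_mode}'; "
            "use 'spark.DeltaTableDataset' instead."
        )


check_delta_mode("delta", {"mode": "overwrite"})  # accepted
check_delta_mode("parquet", {"mode": "upsert"})  # not checked: format is not delta
try:
    check_delta_mode("delta", {"mode": "upsert"})
except DatasetError as err:
    print(err)
```

Because _handle_delta_format() runs at the end of __init__, a misconfigured catalog entry fails when the dataset is constructed rather than at the first save.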

_load

_load()

Loads data from filepath.

Returns:

  • DataFrame

    Data from filepath as pyspark dataframe.

Source code in kedro_datasets/spark/spark_dataset_v2.py
def _load(self) -> DataFrame:
    """Loads data from filepath.

    Returns:
        Data from filepath as pyspark dataframe.
    """
    load_path = self._get_load_path()
    spark_load_path = to_spark_path(self._protocol, str(load_path))

    spark_session = get_spark_with_remote_support()

    reader = spark_session.read
    if self._schema:
        reader = reader.schema(self._schema)

    return (
        reader.format(self._file_format)
        .options(**self._load_args)
        .load(spark_load_path)
    )
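_load() builds the read as a chain: an optional .schema(...), then .format(...), .options(**load_args) and .load(path). To show that order without a Spark session, the sketch below runs the same chain against a recording fake. FakeReader and build_load are hypothetical; FakeReader only mimics the builder methods of pyspark's DataFrameReader and returns its call log instead of a DataFrame.

```python
from __future__ import annotations

from typing import Any


class FakeReader:
    """Hypothetical stand-in for DataFrameReader that records builder calls."""

    def __init__(self) -> None:
        self.calls: list[tuple[str, Any]] = []

    def schema(self, schema: Any) -> FakeReader:
        self.calls.append(("schema", schema))
        return self

    def format(self, source: str) -> FakeReader:
        self.calls.append(("format", source))
        return self

    def options(self, **opts: Any) -> FakeReader:
        self.calls.append(("options", opts))
        return self

    def load(self, path: str) -> list[tuple[str, Any]]:
        self.calls.append(("load", path))
        return self.calls  # a real reader would return a DataFrame here


def build_load(
    reader: Any, file_format: str, load_args: dict[str, Any], schema: Any, path: str
) -> Any:
    """Same chain as _load: optional schema, then format, options, load."""
    if schema:
        reader = reader.schema(schema)
    return reader.format(file_format).options(**load_args).load(path)


calls = build_load(FakeReader(), "csv", {"header": True}, None, "s3a://bucket/data")
print([name for name, _ in calls])  # ['format', 'options', 'load']
```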

_save

_save(data)

Saves a PySpark DataFrame, converting a Pandas DataFrame to Spark first if needed.

Parameters:

  • data (pyspark.sql.DataFrame | pandas.DataFrame) –

    PySpark DataFrame or Pandas DataFrame to save. Pandas DataFrames will be automatically converted to Spark.

Source code in kedro_datasets/spark/spark_dataset_v2.py
def _save(self, data: DataFrame | pd.DataFrame) -> None:
    """Saves pyspark dataframe.

    Args:
        data: PySpark DataFrame or Pandas DataFrame to save.
              Pandas DataFrames will be automatically converted to Spark.
    """
    import pandas as pd  # noqa: PLC0415

    spark_session = get_spark_with_remote_support()

    # Convert Pandas DataFrame to Spark DataFrame if needed
    if isinstance(data, pd.DataFrame):
        data = spark_session.createDataFrame(data)

    save_path = self._get_save_path()
    spark_save_path = to_spark_path(self._protocol, str(save_path))

    # Create a copy of save_args to avoid mutation
    save_args = self._save_args.copy()

    # Extract mode and partitionBy (these are handled separately by Spark)
    mode = save_args.pop("mode", None)
    partition_by = save_args.pop("partitionBy", None)

    # Prepare writer
    writer = data.write

    if mode:
        writer = writer.mode(mode)
    if partition_by:
        writer = writer.partitionBy(partition_by)

    # Save with remaining options
    writer.format(self._file_format).options(**save_args).save(spark_save_path)
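Before writing, _save() splits save_args into the two keys that have dedicated writer methods (mode and partitionBy) and the remaining plain options, working on a copy so the dataset's own _save_args are never mutated between saves. The stdlib sketch below isolates that split. split_writer_args is a hypothetical helper, and the example arguments are illustrative.

```python
from __future__ import annotations

from typing import Any


def split_writer_args(save_args: dict[str, Any]) -> tuple[Any, Any, dict[str, Any]]:
    """Separate mode/partitionBy from plain writer options without mutating input."""
    options = save_args.copy()  # shallow copy, as in _save
    mode = options.pop("mode", None)
    partition_by = options.pop("partitionBy", None)
    return mode, partition_by, options


save_args = {"mode": "overwrite", "partitionBy": ["year"], "compression": "snappy"}
mode, partition_by, options = split_writer_args(save_args)
print(mode)  # overwrite
print(partition_by)  # ['year']
print(options)  # {'compression': 'snappy'}
```

In _save these pieces map to writer.mode(mode), writer.partitionBy(partition_by) and writer.format(...).options(**options).save(path); each dedicated call is skipped when its value is missing, so Spark's own default applies.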