SparkDataset

SparkDataset loads and saves data using Apache Spark DataFrames.

kedro_datasets.spark.SparkDataset

SparkDataset(
    *,
    filepath,
    file_format="parquet",
    load_args=None,
    save_args=None,
    version=None,
    credentials=None,
    metadata=None
)

Bases: AbstractVersionedDataset[DataFrame, DataFrame]

SparkDataset loads and saves Spark dataframes.

Examples:

Using the YAML API:

weather:
  type: spark.SparkDataset
  filepath: s3a://your_bucket/data/01_raw/weather/*
  file_format: csv
  load_args:
    header: True
    inferSchema: True
  save_args:
    sep: '|'
    header: True

weather_with_schema:
  type: spark.SparkDataset
  filepath: s3a://your_bucket/data/01_raw/weather/*
  file_format: csv
  load_args:
    header: True
    schema:
      filepath: path/to/schema.json
  save_args:
    sep: '|'
    header: True

weather_cleaned:
  type: spark.SparkDataset
  filepath: data/02_intermediate/data.parquet
  file_format: parquet

Using the Python API:

>>> from kedro_datasets.spark import SparkDataset
>>> from pyspark.sql import SparkSession
>>> from pyspark.sql.types import IntegerType, Row, StringType, StructField, StructType
>>>
>>> schema = StructType(
...     [StructField("name", StringType(), True), StructField("age", IntegerType(), True)]
... )
>>> data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)]
>>> spark_df = SparkSession.builder.getOrCreate().createDataFrame(data, schema)
>>>
>>> dataset = SparkDataset(filepath=tmp_path / "test_data")
>>> dataset.save(spark_df)
>>> reloaded = dataset.load()
>>> assert Row(name="Bob", age=12) in reloaded.take(4)

Parameters:

  • filepath (str) –

    Filepath in POSIX format to a Spark dataframe. When using Databricks specify filepaths starting with /dbfs/.

  • file_format (str, default: 'parquet' ) –

    File format used during load and save operations. Formats supported by the running SparkContext include parquet, csv, and delta. For a list of supported formats, refer to the Apache Spark documentation at https://spark.apache.org/docs/latest/sql-programming-guide.html

  • load_args (dict[str, Any] | None, default: None ) –

    Load args passed to Spark DataFrameReader load method. It is dependent on the selected file format. You can find a list of read options for each supported format in Spark DataFrame read documentation: https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_df.html

  • save_args (dict[str, Any] | None, default: None ) –

    Save args passed to Spark DataFrame write options. Similar to load_args this is dependent on the selected file format. You can pass mode and partitionBy to specify your overwrite mode and partitioning respectively. You can find a list of options for each format in Spark DataFrame write documentation: https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_df.html

  • version (Version | None, default: None ) –

    If specified, should be an instance of kedro.io.core.Version. If its load attribute is None, the latest version will be loaded. If its save attribute is None, save version will be autogenerated.

  • credentials (dict[str, Any] | None, default: None ) –

    Credentials to access the S3 bucket, such as key, secret, if filepath prefix is s3a:// or s3n://. Optional keyword arguments passed to hdfs.client.InsecureClient if filepath prefix is hdfs://. Ignored otherwise.

  • metadata (dict[str, Any] | None, default: None ) –

    Any arbitrary metadata. This is ignored by Kedro, but may be consumed by users or external plugins.

Source code in kedro_datasets/spark/spark_dataset.py
def __init__(  # noqa: PLR0913
    self,
    *,
    filepath: str,
    file_format: str = "parquet",
    load_args: dict[str, Any] | None = None,
    save_args: dict[str, Any] | None = None,
    version: Version | None = None,
    credentials: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Creates a new instance of ``SparkDataset``.

    Args:
        filepath: Filepath in POSIX format to a Spark dataframe. When using Databricks
            specify ``filepath``s starting with ``/dbfs/``.
        file_format: File format used during load and save
            operations. Formats supported by the running
            SparkContext include parquet, csv, and delta. For a list of supported
            formats please refer to Apache Spark documentation at
            https://spark.apache.org/docs/latest/sql-programming-guide.html
        load_args: Load args passed to Spark DataFrameReader load method.
            It is dependent on the selected file format. You can find
            a list of read options for each supported format
            in Spark DataFrame read documentation:
            https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_df.html
        save_args: Save args passed to Spark DataFrame write options.
            Similar to load_args this is dependent on the selected file
            format. You can pass ``mode`` and ``partitionBy`` to specify
            your overwrite mode and partitioning respectively. You can find
            a list of options for each format in Spark DataFrame
            write documentation:
            https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_df.html
        version: If specified, should be an instance of
            ``kedro.io.core.Version``. If its ``load`` attribute is
            None, the latest version will be loaded. If its ``save``
            attribute is None, save version will be autogenerated.
        credentials: Credentials to access the S3 bucket, such as
            ``key``, ``secret``, if ``filepath`` prefix is ``s3a://`` or ``s3n://``.
            Optional keyword arguments passed to ``hdfs.client.InsecureClient``
            if ``filepath`` prefix is ``hdfs://``. Ignored otherwise.
        metadata: Any arbitrary metadata.
            This is ignored by Kedro, but may be consumed by users or external plugins.
    """
    credentials = deepcopy(credentials) or {}
    fs_prefix, filepath = split_filepath(filepath)
    path = PurePosixPath(filepath)
    exists_function = None
    glob_function = None
    self.metadata = metadata

    if (
        not (filepath.startswith("/dbfs") or filepath.startswith("/Volumes"))
        and fs_prefix not in (protocol + "://" for protocol in CLOUD_PROTOCOLS)
        and deployed_on_databricks()
    ):
        logger.warning(
            "Using SparkDataset on Databricks without the `/dbfs/` or `/Volumes` prefix in the "
            "filepath is a known source of error. You must add this prefix to %s",
            filepath,
        )
    if fs_prefix in ("s3a://",):
        _s3 = S3FileSystem(**credentials)
        exists_function = _s3.exists
        # Ensure cache is not used so latest version is retrieved correctly.
        glob_function = partial(_s3.glob, refresh=True)

    elif fs_prefix == "hdfs://":
        if version:
            warn(
                f"HDFS filesystem support for versioned {self.__class__.__name__} is "
                f"in beta and uses 'hdfs.client.InsecureClient', please use with "
                f"caution"
            )

        # default namenode address
        credentials.setdefault("url", "http://localhost:9870")
        credentials.setdefault("user", "hadoop")

        _hdfs_client = KedroHdfsInsecureClient(**credentials)
        exists_function = _hdfs_client.hdfs_exists
        glob_function = _hdfs_client.hdfs_glob  # type: ignore

    elif filepath.startswith("/dbfs/"):
        # dbfs add prefix to Spark path by default
        # See https://github.com/kedro-org/kedro-plugins/issues/117
        dbutils = get_dbutils(get_spark())
        if dbutils:
            glob_function = partial(dbfs_glob, dbutils=dbutils)
            exists_function = partial(dbfs_exists, dbutils=dbutils)
    else:
        filesystem = fsspec.filesystem(fs_prefix.strip("://"), **credentials)
        exists_function = filesystem.exists
        glob_function = filesystem.glob

    super().__init__(
        filepath=path,
        version=version,
        exists_function=exists_function,
        glob_function=glob_function,
    )

    # Handle default load and save arguments
    self._load_args = {**self.DEFAULT_LOAD_ARGS, **(load_args or {})}
    self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}

    # Handle schema load argument
    self._schema = self._load_args.pop("schema", None)
    if self._schema is not None:
        if isinstance(self._schema, dict):
            self._schema = self._load_schema_from_file(self._schema)

    self._file_format = file_format
    self._fs_prefix = fs_prefix
    self._handle_delta_format()
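
The branch order in the constructor above can be summarised with a small, dependency-free sketch. It is not the library's code: `split_prefix` is a hypothetical stand-in for `split_filepath`, and `pick_backend` only returns a label for the branch that would be taken, assuming the prefix is split on the first `://`.

```python
def split_prefix(filepath: str) -> tuple[str, str]:
    """Hypothetical stand-in for split_filepath: separate 'proto://' from the rest."""
    head, sep, tail = filepath.partition("://")
    if sep:
        return head + sep, tail
    return "", filepath


def pick_backend(filepath: str) -> str:
    """Return a label for the constructor branch a given filepath would reach."""
    fs_prefix, path = split_prefix(filepath)
    if fs_prefix == "s3a://":
        return "s3fs"  # S3FileSystem with a cache-refreshing glob
    if fs_prefix == "hdfs://":
        return "hdfs"  # KedroHdfsInsecureClient
    if path.startswith("/dbfs/"):
        return "dbfs"  # dbutils-based helpers, when dbutils is available
    return "fsspec"  # generic fsspec filesystem for every other case


# Illustrative paths only; none of them need to exist.
for candidate in (
    "s3a://your_bucket/data/01_raw/weather",
    "hdfs://namenode/data/weather",
    "/dbfs/mnt/data/weather",
    "data/02_intermediate/data.parquet",
):
    print(candidate, "->", pick_backend(candidate))
```

Note that a relative local path falls through to the generic `fsspec` branch, because it has no prefix and does not start with `/dbfs/`.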

DEFAULT_LOAD_ARGS class-attribute instance-attribute

DEFAULT_LOAD_ARGS = {}

DEFAULT_SAVE_ARGS class-attribute instance-attribute

DEFAULT_SAVE_ARGS = {}

_SINGLE_PROCESS class-attribute instance-attribute

_SINGLE_PROCESS = True

_file_format instance-attribute

_file_format = file_format

_fs_prefix instance-attribute

_fs_prefix = fs_prefix

_load_args instance-attribute

_load_args = {**DEFAULT_LOAD_ARGS, **(load_args or {})}

_save_args instance-attribute

_save_args = {**DEFAULT_SAVE_ARGS, **(save_args or {})}

_schema instance-attribute

_schema = _load_args.pop('schema', None)

metadata instance-attribute

metadata = metadata
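
As a rough illustration of how the `_load_args` and `_schema` attributes relate, the sketch below merges user-supplied arguments over the defaults and then removes the `schema` key so it is not forwarded as a reader option. `resolve_load_args` is a hypothetical helper written for this example, not part of `SparkDataset`.

```python
from typing import Any, Optional

DEFAULT_LOAD_ARGS: dict = {}


def resolve_load_args(load_args: Optional[dict]) -> tuple[dict, Any]:
    """Merge defaults with user arguments, then split off the 'schema' entry."""
    # Later keys win, so user-supplied values override the defaults.
    merged = {**DEFAULT_LOAD_ARGS, **(load_args or {})}
    # 'schema' is handled separately and must not remain in the reader options.
    schema = merged.pop("schema", None)
    return merged, schema


args, schema = resolve_load_args(
    {"header": True, "schema": {"filepath": "path/to/schema.json"}}
)
print(args)
print(schema)
```

Because the merge builds a new dictionary, the caller's `load_args` is left unmodified by the `pop`.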

_describe

_describe()
Source code in kedro_datasets/spark/spark_dataset.py
def _describe(self) -> dict[str, Any]:
    return {
        "filepath": self._fs_prefix + str(self._filepath),
        "file_format": self._file_format,
        "load_args": self._load_args,
        "save_args": self._save_args,
        "version": self._version,
    }

_exists

_exists()
Source code in kedro_datasets/spark/spark_dataset.py
def _exists(self) -> bool:
    load_path = strip_dbfs_prefix(self._fs_prefix + str(self._get_load_path()))

    try:
        get_spark().read.load(load_path, self._file_format)
    except AnalysisException as exception:
        # `AnalysisException.desc` is deprecated with pyspark >= 3.4
        message = exception.desc if hasattr(exception, "desc") else str(exception)
        if "Path does not exist:" in message or "is not a Delta table" in message:
            return False
        raise
    return True
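
The decision made inside the `except` branch above comes down to a substring check on the exception message. A minimal sketch of that check, without Spark, might look like this; `indicates_missing_data` is a hypothetical name and the sample messages are illustrative rather than captured Spark output.

```python
MISSING_MARKERS = ("Path does not exist:", "is not a Delta table")


def indicates_missing_data(message: str) -> bool:
    """True when the message matches one of the 'data is absent' markers."""
    return any(marker in message for marker in MISSING_MARKERS)


# Illustrative messages only.
print(indicates_missing_data("Path does not exist: file:/tmp/missing"))
print(indicates_missing_data("Permission denied"))
```

Any message that does not match a marker corresponds to the bare `raise` in `_exists`, so unrelated failures are not silently reported as "does not exist".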

_handle_delta_format

_handle_delta_format()
Source code in kedro_datasets/spark/spark_dataset.py
def _handle_delta_format(self) -> None:
    supported_modes = {"append", "overwrite", "error", "errorifexists", "ignore"}
    write_mode = self._save_args.get("mode")
    if (
        write_mode
        and self._file_format == "delta"
        and write_mode not in supported_modes
    ):
        raise DatasetError(
            f"It is not possible to perform 'save()' for file format 'delta' "
            f"with mode '{write_mode}' on 'SparkDataset'. "
            f"Please use 'spark.DeltaTableDataset' instead."
        )
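
The condition that triggers the `DatasetError` above can be tried out in isolation. The sketch below mirrors the three-part check as a plain function; `delta_mode_is_rejected` is a hypothetical name, and `"upsert"` is used only as an example of a mode outside the supported set.

```python
SUPPORTED_DELTA_MODES = {"append", "overwrite", "error", "errorifexists", "ignore"}


def delta_mode_is_rejected(file_format: str, save_args: dict) -> bool:
    """True when a write mode is set, the format is delta, and the mode is unsupported."""
    mode = save_args.get("mode")
    return bool(mode) and file_format == "delta" and mode not in SUPPORTED_DELTA_MODES


print(delta_mode_is_rejected("delta", {"mode": "upsert"}))
print(delta_mode_is_rejected("delta", {"mode": "append"}))
print(delta_mode_is_rejected("parquet", {"mode": "upsert"}))
```

The check applies only to the `delta` format, and an absent `mode` is never rejected, so other formats pass their `mode` through to Spark unvalidated.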

_load_schema_from_file staticmethod

_load_schema_from_file(schema)
Source code in kedro_datasets/spark/spark_dataset.py
@staticmethod
def _load_schema_from_file(schema: dict[str, Any]) -> StructType:
    filepath = schema.get("filepath")
    if not filepath:
        raise DatasetError(
            "Schema load argument does not specify a 'filepath' attribute. Please "
            "include a path to a JSON-serialised 'pyspark.sql.types.StructType'."
        )

    credentials = deepcopy(schema.get("credentials")) or {}
    protocol, schema_path = get_protocol_and_path(filepath)
    file_system = fsspec.filesystem(protocol, **credentials)
    pure_posix_path = PurePosixPath(schema_path)
    load_path = get_filepath_str(pure_posix_path, protocol)

    # Open schema file
    with file_system.open(load_path) as fs_file:
        try:
            return StructType.fromJson(json.loads(fs_file.read()))
        except Exception as exc:
            raise DatasetError(
                f"Contents of 'schema.filepath' ({schema_path}) are invalid. Please "
                f"provide a valid JSON-serialised 'pyspark.sql.types.StructType'."
            ) from exc
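
For reference, a schema file of the kind this method reads could be produced as sketched below. The layout follows the JSON form of a `StructType` (a `struct` with a list of `fields`); the field names match the Python API example earlier on this page, and the temporary file location is an assumption made for the example. The sketch uses only the standard library and does not call Spark.

```python
import json
import tempfile
from pathlib import Path

# JSON form of a StructType with two nullable fields, "name" and "age".
schema_json = {
    "type": "struct",
    "fields": [
        {"name": "name", "type": "string", "nullable": True, "metadata": {}},
        {"name": "age", "type": "integer", "nullable": True, "metadata": {}},
    ],
}

# Write the schema to a temporary location (illustrative path only).
schema_path = Path(tempfile.mkdtemp()) / "schema.json"
schema_path.write_text(json.dumps(schema_json, indent=2))

# Reading it back yields the dictionary that would be handed to StructType.fromJson.
loaded = json.loads(schema_path.read_text())
field_names = [field["name"] for field in loaded["fields"]]
print(field_names)
```

The resulting file path is what a catalog entry would reference under `load_args.schema.filepath`.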

load

load()
Source code in kedro_datasets/spark/spark_dataset.py
def load(self) -> DataFrame:
    load_path = strip_dbfs_prefix(self._fs_prefix + str(self._get_load_path()))
    read_obj = get_spark().read

    # Pass schema if defined
    if self._schema:
        read_obj = read_obj.schema(self._schema)

    return read_obj.load(load_path, self._file_format, **self._load_args)

save

save(data)
Source code in kedro_datasets/spark/spark_dataset.py
def save(self, data: DataFrame) -> None:
    save_path = strip_dbfs_prefix(self._fs_prefix + str(self._get_save_path()))
    data.write.save(save_path, self._file_format, **self._save_args)
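
Both `load` and `save` build the path handed to Spark in the same way: the stored prefix is concatenated with the resolved path, and any leading `/dbfs` is removed. The sketch below illustrates that composition under the assumption that `strip_dbfs_prefix` simply drops a leading `/dbfs`; `strip_prefix` and `spark_path` are hypothetical stand-ins, not the library helpers.

```python
from pathlib import PurePosixPath


def strip_prefix(path: str, prefix: str = "/dbfs") -> str:
    """Hypothetical stand-in for strip_dbfs_prefix: drop a leading '/dbfs'."""
    return path[len(prefix):] if path.startswith(prefix) else path


def spark_path(fs_prefix: str, resolved_path: PurePosixPath) -> str:
    """Compose the string path that would be passed to Spark."""
    return strip_prefix(fs_prefix + str(resolved_path))


print(spark_path("s3a://", PurePosixPath("your_bucket/data/weather")))
print(spark_path("", PurePosixPath("/dbfs/mnt/data/weather")))
print(spark_path("", PurePosixPath("data/02_intermediate/data.parquet")))
```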