# Source code for kedro.extras.datasets.spark.spark_jdbc_dataset

"""SparkJDBCDataSet to load and save a PySpark DataFrame via JDBC."""

from copy import deepcopy
from typing import Any, Dict, Optional

from pyspark.sql import DataFrame, SparkSession

from kedro.io.core import AbstractDataset, DatasetError

# Public names exported by this module.
__all__ = [
    "SparkJDBCDataSet",
]

# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0.
# Any contribution to datasets should be made in kedro-datasets
# in kedro-plugins (https://github.com/kedro-org/kedro-plugins)

class SparkJDBCDataSet(AbstractDataset[DataFrame, DataFrame]):
    """``SparkJDBCDataSet`` loads data from a database table accessible
    via JDBC URL ``url`` and connection properties and saves the content of
    a PySpark DataFrame to an external database table via JDBC. It uses
    ``pyspark.sql.DataFrameReader`` and ``pyspark.sql.DataFrameWriter``
    internally, so it supports all allowed PySpark options on ``jdbc``.

    Example usage for the
    `YAML API <https://kedro.readthedocs.io/en/stable/data/\
    data_catalog_yaml_examples.html>`_:

    .. code-block:: yaml

        weather:
          type: spark.SparkJDBCDataSet
          table: weather_table
          url: jdbc:postgresql://localhost/test
          credentials: db_credentials
          load_args:
            properties:
              driver: org.postgresql.Driver
          save_args:
            properties:
              driver: org.postgresql.Driver

    Example usage for the
    `Python API <https://kedro.readthedocs.io/en/stable/data/\
    advanced_data_catalog_usage.html>`_:
    ::

        >>> import pandas as pd
        >>>
        >>> from pyspark.sql import SparkSession
        >>>
        >>> spark = SparkSession.builder.getOrCreate()
        >>> data = spark.createDataFrame(pd.DataFrame({'col1': [1, 2],
        >>>                                            'col2': [4, 5],
        >>>                                            'col3': [5, 6]}))
        >>> url = 'jdbc:postgresql://localhost/test'
        >>> table = 'table_a'
        >>> connection_properties = {'driver': 'org.postgresql.Driver'}
        >>> data_set = SparkJDBCDataSet(
        >>>   url=url, table=table, credentials={'user': 'scott',
        >>>                                      'password': 'tiger'},
        >>>   load_args={'properties': connection_properties},
        >>>   save_args={'properties': connection_properties})
        >>>
        >>> data_set.save(data)
        >>> reloaded = data_set.load()
        >>>
        >>> assert data.toPandas().equals(reloaded.toPandas())

    """

    DEFAULT_LOAD_ARGS = {}  # type: Dict[str, Any]
    DEFAULT_SAVE_ARGS = {}  # type: Dict[str, Any]

    def __init__(  # noqa: too-many-arguments
        self,
        url: str,
        table: str,
        credentials: Optional[Dict[str, Any]] = None,
        load_args: Optional[Dict[str, Any]] = None,
        save_args: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Creates a new ``SparkJDBCDataSet``.

        Args:
            url: A JDBC URL of the form ``jdbc:subprotocol:subname``.
            table: The name of the table to load or save data to.
            credentials: A dictionary of JDBC database connection arguments.
                Normally at least properties ``user`` and ``password`` with
                their corresponding values. When provided, it is merged into
                the ``properties`` entry of both ``load_args`` and
                ``save_args``; credential values take precedence over
                entries of the same name already in ``properties``.
            load_args: Provided to underlying PySpark ``jdbc`` function along
                with the JDBC URL and the name of the table. To find all
                supported arguments, see here:
                https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameReader.jdbc.html
            save_args: Provided to underlying PySpark ``jdbc`` function along
                with the JDBC URL and the name of the table. To find all
                supported arguments, see here:
                https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameWriter.jdbc.html

        Raises:
            DatasetError: When either ``url`` or ``table`` is empty or
                when a credential property is provided with a None value.
        """
        if not url:
            raise DatasetError(
                "'url' argument cannot be empty. Please "
                "provide a JDBC URL of the form "
                "'jdbc:subprotocol:subname'."
            )

        if not table:
            raise DatasetError(
                "'table' argument cannot be empty. Please "
                "provide the name of the table to load or save "
                "data to."
            )

        self._url = url
        self._table = table

        # Handle default load and save arguments. The class-level defaults
        # are deep-copied so instances never share (and mutate) them.
        self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
        if load_args is not None:
            self._load_args.update(load_args)
        self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
        if save_args is not None:
            self._save_args.update(save_args)

        # Update properties in load_args and save_args with credentials.
        if credentials is not None:
            # Check credentials for bad inputs.
            for cred_key, cred_value in credentials.items():
                if cred_value is None:
                    raise DatasetError(
                        f"Credential property '{cred_key}' cannot be None. "
                        f"Please provide a value."
                    )

            # ``or {}`` also covers an explicit ``properties: None`` entry,
            # which would otherwise make the ``**`` unpacking below fail.
            # New dicts are built so the caller's ``properties`` dict is
            # never mutated.
            load_properties = self._load_args.get("properties") or {}
            save_properties = self._save_args.get("properties") or {}
            self._load_args["properties"] = {**load_properties, **credentials}
            self._save_args["properties"] = {**save_properties, **credentials}

    @staticmethod
    def _mask_credentials(args: Dict[str, Any]) -> Dict[str, Any]:
        """Return ``args`` with ``user`` and ``password`` dropped from its
        ``properties`` entry; ``args`` itself is left unmodified.

        If there is no (or an empty/None) ``properties`` entry, ``args`` is
        returned unchanged.
        """
        properties = args.get("properties")
        if not properties:
            return args
        # NOTE(review): only ``user``/``password`` are hidden; any other
        # credential keys merged into ``properties`` remain visible.
        masked = {
            key: value
            for key, value in properties.items()
            if key not in ("user", "password")
        }
        return {**args, "properties": masked}

    def _describe(self) -> Dict[str, Any]:
        """Describe this dataset's configuration, with ``user`` and
        ``password`` removed from the load and save ``properties``."""
        return {
            "url": self._url,
            "table": self._table,
            "load_args": self._mask_credentials(self._load_args),
            "save_args": self._mask_credentials(self._save_args),
        }

    @staticmethod
    def _get_spark():  # pragma: no cover
        """Return the active ``SparkSession``, creating one if needed."""
        return SparkSession.builder.getOrCreate()

    def _load(self) -> DataFrame:
        """Read ``table`` from ``url`` via ``DataFrameReader.jdbc``."""
        return self._get_spark().read.jdbc(self._url, self._table, **self._load_args)

    def _save(self, data: DataFrame) -> None:
        """Write ``data`` to ``table`` at ``url`` via ``DataFrameWriter.jdbc``."""
        data.write.jdbc(self._url, self._table, **self._save_args)