Source code for kedro_datasets.snowflake.snowpark_dataset

"""``AbstractDataset`` implementation to access Snowflake using Snowpark dataframes
"""
import logging
import warnings
from copy import deepcopy
from typing import Any, Dict

import snowflake.snowpark as sp

from kedro_datasets import KedroDeprecationWarning
from kedro_datasets._io import AbstractDataset, DatasetError

logger = logging.getLogger(__name__)


class SnowparkTableDataset(AbstractDataset):
    """``SnowparkTableDataset`` loads and saves Snowpark dataframes.

    As of Mar-2023, the snowpark connector only works with Python 3.8.

    Example usage for the
    `YAML API <https://kedro.readthedocs.io/en/stable/data/\
    data_catalog_yaml_examples.html>`_:

    .. code-block:: yaml

        weather:
          type: kedro_datasets.snowflake.SnowparkTableDataset
          table_name: "weather_data"
          database: "meteorology"
          schema: "observations"
          credentials: db_credentials
          save_args:
            mode: overwrite
            column_order: name
            table_type: ''

    You can skip everything but "table_name" if the database and
    schema are provided via credentials. That way catalog entries can be shorter
    if, for example, all used Snowflake tables live in same database/schema.
    Values in the dataset definition take priority over those defined in credentials.

    Example:
    Credentials file provides all connection attributes, catalog entry
    "weather" reuses credentials parameters, "polygons" catalog entry reuses
    all credentials parameters except providing a different schema name.
    Second example of credentials file uses ``externalbrowser`` authentication.

    catalog.yml

    .. code-block:: yaml

        weather:
          type: kedro_datasets.snowflake.SnowparkTableDataset
          table_name: "weather_data"
          database: "meteorology"
          schema: "observations"
          credentials: snowflake_client
          save_args:
            mode: overwrite
            column_order: name
            table_type: ''

        polygons:
          type: kedro_datasets.snowflake.SnowparkTableDataset
          table_name: "geopolygons"
          credentials: snowflake_client
          schema: "geodata"

    credentials.yml

    .. code-block:: yaml

        snowflake_client:
          account: 'ab12345.eu-central-1'
          port: 443
          warehouse: "datascience_wh"
          database: "detailed_data"
          schema: "observations"
          user: "service_account_abc"
          password: "supersecret"

    credentials.yml (with externalbrowser authenticator)

    .. code-block:: yaml

        snowflake_client:
          account: 'ab12345.eu-central-1'
          port: 443
          warehouse: "datascience_wh"
          database: "detailed_data"
          schema: "observations"
          user: "john_doe@wdomain.com"
          authenticator: "externalbrowser"

    """

    # this dataset cannot be used with ``ParallelRunner``,
    # therefore it has the attribute ``_SINGLE_PROCESS = True``
    # for parallelism within a pipeline please consider
    # ``ThreadRunner`` instead
    _SINGLE_PROCESS = True
    DEFAULT_LOAD_ARGS: Dict[str, Any] = {}
    DEFAULT_SAVE_ARGS: Dict[str, Any] = {}

    def __init__(  # noqa: PLR0913
        self,
        table_name: str,
        schema: str = None,
        database: str = None,
        load_args: Dict[str, Any] = None,
        save_args: Dict[str, Any] = None,
        credentials: Dict[str, Any] = None,
        metadata: Dict[str, Any] = None,
    ) -> None:
        """Creates a new instance of ``SnowparkTableDataset``.

        Args:
            table_name: The table name to load or save data to.
            schema: Name of the schema where ``table_name`` is.
                Optional as can be provided as part of ``credentials``
                dictionary. Argument value takes priority over one provided
                in ``credentials`` if any.
            database: Name of the database where ``schema`` is.
                Optional as can be provided as part of ``credentials``
                dictionary. Argument value takes priority over one provided
                in ``credentials`` if any.
            load_args: Currently not used
            save_args: Provided to underlying snowpark ``save_as_table``
                To find all supported arguments, see here:
                https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/api/snowflake.snowpark.DataFrameWriter.saveAsTable.html
            credentials: A dictionary with a snowpark connection string.
                To find all supported arguments, see here:
                https://docs.snowflake.com/en/user-guide/python-connector-api.html#connect
                The dictionary is not modified; the effective connection
                parameters are stored in a separate copy.
            metadata: Any arbitrary metadata.
                This is ignored by Kedro, but may be consumed by users or external plugins.

        Raises:
            DatasetError: If ``table_name`` or ``credentials`` is empty, or if
                ``database`` / ``schema`` is given neither as an argument nor
                in ``credentials``.
        """
        if not table_name:
            raise DatasetError("'table_name' argument cannot be empty.")

        if not credentials:
            raise DatasetError("'credentials' argument cannot be empty.")

        # Explicit arguments win; otherwise fall back to the credentials entry.
        if not database:
            database = credentials.get("database")
            if not database:
                raise DatasetError(
                    "'database' must be provided by credentials or dataset."
                )

        if not schema:
            schema = credentials.get("schema")
            if not schema:
                raise DatasetError(
                    "'schema' must be provided by credentials or dataset."
                )

        # Handle default load and save arguments
        self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
        if load_args is not None:
            self._load_args.update(load_args)
        self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
        if save_args is not None:
            self._save_args.update(save_args)

        self._table_name = table_name
        self._database = database
        self._schema = schema

        # Build a new dict instead of updating ``credentials`` in place: the
        # same credentials mapping may be shared by several catalog entries,
        # and this dataset's database/schema overrides must not leak into them.
        self._connection_parameters = {
            **credentials,
            "database": self._database,
            "schema": self._schema,
        }
        self._session = self._get_session(self._connection_parameters)

        self.metadata = metadata

    def _describe(self) -> Dict[str, Any]:
        """Return the table name, database and schema of this dataset."""
        return {
            "table_name": self._table_name,
            "database": self._database,
            "schema": self._schema,
        }

    @staticmethod
    def _get_session(connection_parameters) -> sp.Session:
        """Given a connection string, create singleton connection
        to be used across all instances of `SnowparkTableDataset` that
        need to connect to the same source.

        An already active snowpark session is reused as is; in that case
        ``connection_parameters`` is not applied. A new session is built from
        ``connection_parameters`` only when fetching the active session raises
        ``SnowparkSessionException``.

        connection_parameters is a dictionary of any values
        supported by snowflake python connector:
        https://docs.snowflake.com/en/user-guide/python-connector-api.html#connect
            example:
            connection_parameters = {
                "account": "",
                "user": "",
                "password": "", (optional)
                "role": "", (optional)
                "warehouse": "", (optional)
                "database": "", (optional)
                "schema": "", (optional)
                "authenticator": "" (optional)
                }
        """
        try:
            logger.debug("Trying to reuse active snowpark session...")
            session = sp.context.get_active_session()
        except sp.exceptions.SnowparkSessionException:
            logger.debug("No active snowpark session found. Creating")
            session = sp.Session.builder.configs(connection_parameters).create()
        return session

    def _load(self) -> sp.DataFrame:
        """Return a Snowpark dataframe for ``database.schema.table_name``."""
        table_name = [
            self._database,
            self._schema,
            self._table_name,
        ]

        sp_df = self._session.table(".".join(table_name))
        return sp_df

    def _save(self, data: sp.DataFrame) -> None:
        """Write ``data`` to the configured table, forwarding ``save_args``."""
        # The name is passed as its [database, schema, table] parts.
        table_name = [
            self._database,
            self._schema,
            self._table_name,
        ]

        data.write.save_as_table(table_name, **self._save_args)

    def _exists(self) -> bool:
        """Return ``True`` if INFORMATION_SCHEMA.TABLES lists exactly one
        matching table for the configured schema and table name.
        """
        session = self._session
        # NOTE(review): identifiers are interpolated straight into the SQL
        # text; presumably they always come from trusted catalog configuration
        # -- confirm before exposing to user-supplied values.
        query = "SELECT COUNT(*) FROM {database}.INFORMATION_SCHEMA.TABLES \
            WHERE TABLE_SCHEMA = '{schema}' \
            AND TABLE_NAME = '{table_name}'"
        rows = session.sql(
            query.format(
                database=self._database,
                schema=self._schema,
                table_name=self._table_name,
            )
        ).collect()
        return rows[0][0] == 1
# Maps each deprecated class name to the class that replaces it.
_DEPRECATED_CLASSES = {
    "SnowparkTableDataSet": SnowparkTableDataset,
}


def __getattr__(name):
    """Resolve deprecated class aliases on module attribute access.

    Looks ``name`` up in ``_DEPRECATED_CLASSES``. A known alias emits a
    ``KedroDeprecationWarning`` and returns the renamed class; any other name
    raises ``AttributeError``.
    """
    replacement = _DEPRECATED_CLASSES.get(name)
    if replacement is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    message = (
        f"{name!r} has been renamed to {replacement.__name__!r}, "
        f"and the alias will be removed in Kedro-Datasets 2.0.0"
    )
    warnings.warn(message, KedroDeprecationWarning, stacklevel=2)
    return replacement