"""``AbstractDataset`` implementation to access Snowflake using Snowpark dataframes"""
from __future__ import annotations
import logging
from typing import Any, cast
import pandas as pd
from kedro.io.core import AbstractDataset, DatasetError
from snowflake.snowpark import DataFrame, Session
from snowflake.snowpark import context as sp_context
from snowflake.snowpark import exceptions as sp_exceptions
logger = logging.getLogger(__name__)
class SnowparkTableDataset(AbstractDataset):
    """``SnowparkTableDataset`` loads and saves Snowpark DataFrames.

    As of October 2024, the Snowpark connector works with Python 3.9, 3.10, and 3.11.
    Python 3.12 is not supported yet.

    Example usage for the
    `YAML API <https://docs.kedro.org/en/stable/data/\
    data_catalog_yaml_examples.html>`_:

    .. code-block:: yaml

        weather:
          type: kedro_datasets.snowflake.SnowparkTableDataset
          table_name: "weather_data"
          database: "meteorology"
          schema: "observations"
          credentials: db_credentials
          save_args:
            mode: overwrite
            column_order: name
            table_type: ''

    You can skip everything but "table_name" if the database and schema are
    provided via credentials. This allows catalog entries to be shorter when
    all Snowflake tables are in the same database and schema. Values in the
    dataset definition take priority over those defined in credentials.

    Example:
    The credentials file provides all connection attributes. The catalog entry
    for "weather" reuses the credentials parameters, while the "polygons" catalog
    entry reuses all credentials parameters except for specifying a different
    schema. The second example demonstrates the use of ``externalbrowser`` authentication.

    catalog.yml:

    .. code-block:: yaml

        weather:
          type: kedro_datasets.snowflake.SnowparkTableDataset
          table_name: "weather_data"
          database: "meteorology"
          schema: "observations"
          credentials: snowflake_client
          save_args:
            mode: overwrite
            column_order: name
            table_type: ''

        polygons:
          type: kedro_datasets.snowflake.SnowparkTableDataset
          table_name: "geopolygons"
          credentials: snowflake_client
          schema: "geodata"

    credentials.yml:

    .. code-block:: yaml

        snowflake_client:
          account: 'ab12345.eu-central-1'
          port: 443
          warehouse: "datascience_wh"
          database: "detailed_data"
          schema: "observations"
          user: "service_account_abc"
          password: "supersecret"

    credentials.yml (with externalbrowser authentication):

    .. code-block:: yaml

        snowflake_client:
          account: 'ab12345.eu-central-1'
          port: 443
          warehouse: "datascience_wh"
          database: "detailed_data"
          schema: "observations"
          user: "john_doe@wdomain.com"
          authenticator: "externalbrowser"
    """

    # this dataset cannot be used with ``ParallelRunner``,
    # therefore it has the attribute ``_SINGLE_PROCESS = True``
    # for parallelism within a pipeline please consider
    # ``ThreadRunner`` instead
    _SINGLE_PROCESS = True
    DEFAULT_LOAD_ARGS: dict[str, Any] = {}
    DEFAULT_SAVE_ARGS: dict[str, Any] = {}

    def __init__(  # noqa: PLR0913
        self,
        *,
        table_name: str,
        schema: str | None = None,
        database: str | None = None,
        load_args: dict[str, Any] | None = None,
        save_args: dict[str, Any] | None = None,
        credentials: dict[str, Any] | None = None,
        session: Session | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """
        Creates a new instance of ``SnowparkTableDataset``.

        Args:
            table_name: The table name to load or save data to.
            schema: Name of the schema where ``table_name`` is.
                Optional as can be provided as part of ``credentials``
                dictionary. Argument value takes priority over one provided
                in ``credentials`` if any.
            database: Name of the database where ``schema`` is.
                Optional as can be provided as part of ``credentials``
                dictionary. Argument value takes priority over one provided
                in ``credentials`` if any.
            load_args: Currently not used
            save_args: Provided to underlying snowpark ``save_as_table``
                To find all supported arguments, see here:
                https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/api/snowflake.snowpark.DataFrameWriter.saveAsTable.html
            credentials: A dictionary with a snowpark connection string.
                To find all supported arguments, see here:
                https://docs.snowflake.com/en/user-guide/python-connector-api.html#connect
                The dictionary is copied, never modified in place.
            session: An existing Snowpark session to use. When omitted, a
                session is obtained lazily on first access of ``session``.
            metadata: Any arbitrary metadata.
                This is ignored by Kedro, but may be consumed by users or external plugins.

        Raises:
            DatasetError: If ``table_name`` or ``credentials`` is empty, or if
                ``database`` / ``schema`` is given neither as an argument nor
                in ``credentials``.
        """
        if not table_name:
            raise DatasetError("'table_name' argument cannot be empty.")

        if not credentials:
            raise DatasetError("'credentials' argument cannot be empty.")

        # Explicit arguments win; otherwise fall back to the credentials.
        database = database or credentials.get("database")
        if not database:
            raise DatasetError(
                "'database' must be provided by credentials or dataset."
            )

        schema = schema or credentials.get("schema")
        if not schema:
            raise DatasetError(
                "'schema' must be provided by credentials or dataset."
            )

        # Handle default load and save arguments
        self._load_args = {**self.DEFAULT_LOAD_ARGS, **(load_args or {})}
        self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}

        self._table_name = table_name
        self._database = database
        self._schema = schema

        # Build a new dict rather than updating ``credentials`` in place: the
        # caller may pass the same credentials object to several datasets,
        # and overriding database/schema here must not leak into the others.
        self._connection_parameters: dict[str, Any] = {
            **credentials,
            "database": self._database,
            "schema": self._schema,
        }
        self._session = session

        self.metadata = metadata

    def _describe(self) -> dict[str, Any]:
        """Return the attributes that identify the target table."""
        return {
            "table_name": self._table_name,
            "database": self._database,
            "schema": self._schema,
        }

    @staticmethod
    def _get_session(connection_parameters: dict[str, Any]) -> Session:
        """
        Return the active Snowpark session, creating one if there is none.

        The active session is looked up first; only when that lookup raises
        ``SnowparkSessionException`` is a new session built from
        ``connection_parameters``.

        connection_parameters is a dictionary of any values
        supported by snowflake python connector:
        https://docs.snowflake.com/en/user-guide/python-connector-api.html#connect

        example:
        connection_parameters = {
            "account": "",
            "user": "",
            "password": "", (optional)
            "role": "", (optional)
            "warehouse": "", (optional)
            "database": "", (optional)
            "schema": "", (optional)
            "authenticator": "" (optional)
        }
        """
        try:
            logger.debug("Trying to reuse active snowpark session...")
            session = sp_context.get_active_session()
        except sp_exceptions.SnowparkSessionException:
            logger.debug("No active snowpark session found. Creating...")
            session = Session.builder.configs(connection_parameters).create()
        return session

    @property
    def session(self) -> Session:
        """
        Retrieve or create a session.

        The session passed to the constructor is returned when there is one;
        otherwise a session is obtained via ``_get_session`` on first access
        and stored for subsequent calls.

        Returns:
            Session: The current session associated with the object.
        """
        if self._session is None:
            self._session = self._get_session(self._connection_parameters)
        return self._session

    def load(self) -> DataFrame:
        """
        Load data from a specified database table.

        Returns:
            DataFrame: The loaded data as a Snowpark DataFrame.
        """
        # Go through the ``session`` property so a dataset built from
        # credentials only (no injected session) can still connect.
        return self.session.table(self._validate_and_get_table_name())

    def save(self, data: pd.DataFrame | DataFrame) -> None:
        """
        Check if the data is a Snowpark DataFrame or a Pandas DataFrame,
        convert it to a Snowpark DataFrame if needed, and save it to the specified table.

        Args:
            data (pd.DataFrame | DataFrame): The data to save.

        Raises:
            DatasetError: If ``data`` is neither a Pandas nor a Snowpark DataFrame.
        """
        if isinstance(data, pd.DataFrame):
            # A session is only needed to convert pandas input.
            snowpark_df = self.session.create_dataframe(data)
        elif isinstance(data, DataFrame):
            snowpark_df = data
        else:
            raise DatasetError(
                f"Data of type {type(data)} is not supported for saving."
            )

        snowpark_df.write.save_as_table(
            self._validate_and_get_table_name(), **self._save_args
        )

    def _exists(self) -> bool:
        """
        Check if a specified table exists in the database.

        Returns:
            bool: True if the table could be queried, False otherwise.
        """
        # Resolve the session and the table name outside the ``try`` so that a
        # failure to connect is not misreported as "table does not exist".
        session = self.session
        full_table_name = self._validate_and_get_table_name()
        try:
            # ``limit(1).collect()`` forces the query to run without printing
            # rows to stdout the way ``show()`` does.
            session.table(full_table_name).limit(1).collect()
        except Exception as e:  # noqa: BLE001 - deliberate best-effort probe
            logger.debug("Table %s does not exist: %s", self._table_name, e)
            return False
        return True

    def _validate_and_get_table_name(self) -> str:
        """
        Validate that database, schema and table name are set and join them.

        Returns:
            str: The joined table name in the format 'database.schema.table'.

        Raises:
            DatasetError: If any part of the table name is None or empty.
        """
        parts: list[str | None] = [self._database, self._schema, self._table_name]
        if any(part is None or part == "" for part in parts):
            raise DatasetError("Database, schema or table name cannot be None or empty")
        parts_str = cast(list[str], parts)  # make linting happy
        return ".".join(parts_str)