Skip to content

polars.PolarsDatabaseDataset

kedro_datasets_experimental.polars.PolarsDatabaseDataset

PolarsDatabaseDataset(
    *,
    sql=None,
    credentials=None,
    load_args=None,
    fs_args=None,
    filepath=None,
    table_name=None,
    save_args=None,
    metadata=None
)

Bases: AbstractDataset[None, DataFrame]

PolarsDatabaseDataset loads data from a provided SQL query or writes data to a table.

It supports all allowed polars options on read_database and write_database. Since Polars uses SQLAlchemy behind the scenes, when instantiating PolarsDatabaseDataset one needs to pass a compatible connection string either in credentials (see the example code snippet below) or in load_args. Connection string formats supported by SQLAlchemy can be found here: https://docs.sqlalchemy.org/core/engines.html#database-urls

Example usage for the YAML API:
shuttle_id_dataset:
    type: polars.PolarsDatabaseDataset
    sql: "select shuttle, shuttle_id from spaceflights.shuttles;"
    credentials: db_credentials

Sample database credentials entry in credentials.yml:

db_credentials:
    con: postgresql://scott:tiger@localhost/test  # pragma: allowlist secret
    pool_size: 10 # additional parameters
Example usage for the Python API:
>>> from pathlib import Path
>>> import polars as pl
>>> import sqlite3
>>>
>>> from kedro_datasets_experimental.polars import PolarsDatabaseDataset
>>>
>>> data = pl.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
>>> sql = "SELECT * FROM table_a"
>>> tmp_path = Path.cwd() / "tmp"
>>> tmp_path.mkdir(parents=True, exist_ok=True)
>>> credentials = {"con": f"sqlite:///{tmp_path / 'test.db'}"}
>>> dataset = PolarsDatabaseDataset(sql=sql, credentials=credentials, table_name="table_a")
>>>
>>> dataset.save(data)
>>> reloaded = dataset.load()
>>>
>>> assert data.equals(reloaded)
Source code in kedro_datasets_experimental/polars/polars_database_dataset.py
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
def __init__(  # noqa: PLR0913
    self,
    *,
    sql: str | None = None,
    credentials: dict[str, Any] | None = None,
    load_args: dict[str, Any] | None = None,
    fs_args: dict[str, Any] | None = None,
    filepath: str | None = None,
    table_name: str | None = None,
    save_args: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Creates a new ``PolarsDatabaseDataset``.

    Args:
        sql: SQL query used when loading; stored in the load arguments
            under the ``"sql"`` key. Cannot be combined with ``filepath``.
        credentials: Mapping that must contain a non-empty ``"con"`` entry
            (the connection string). Every other entry is kept as an
            additional connection argument.
        load_args: Extra load options, merged over the (empty) defaults.
        fs_args: Options for the ``fsspec`` filesystem used to read the
            query file. A nested ``"credentials"`` entry is popped and
            forwarded to the filesystem together with the remaining
            entries.
        filepath: Location of a file holding the SQL query. Cannot be
            combined with ``sql``.
        table_name: Target table used when saving.
        save_args: Extra save options, merged over the default
            ``{"if_exists": "replace"}``.
        metadata: Arbitrary metadata stored on the instance as-is.

    Raises:
        DatasetError: If both ``sql`` and ``filepath`` are given, if
            ``table_name``, ``sql`` and ``filepath`` are all empty, or if
            ``credentials`` has no non-empty ``"con"`` entry.
    """
    if sql and filepath:
        raise DatasetError(
            "'sql' and 'filepath' arguments cannot both be provided."
            "Please only provide one."
        )

    # Only reject the configuration when there is nothing to load from AND
    # nothing to save to. A query source combined with a table name (or
    # either one alone) is a valid setup.
    if not (table_name or sql or filepath):
        raise DatasetError(
            "Either 'table_name' or one of 'sql' or 'filepath' arguments cannot both be empty."
            "Please provide a sql query or path to a sql query file."
        )

    if not (credentials and "con" in credentials and credentials["con"]):
        raise DatasetError(
            "'con' argument cannot be empty. Please "
            "provide a SQLAlchemy connection string."
        )

    default_load_args: dict[str, Any] = {}
    default_save_args: dict[str, Any] = {
        "if_exists": "replace"
    }

    self._load_args = (
        {**default_load_args, **load_args}
        if load_args is not None
        else default_load_args
    )

    self.table_name = table_name
    self._save_args = (
        {**default_save_args, **save_args}
        if save_args is not None
        else default_save_args
    )

    self.metadata = metadata

    if sql:
        # Inline query: keep it with the other load arguments.
        self._load_args["sql"] = sql
        self._filepath = None
    else:
        # Query lives in a file: set up the filesystem used to read it.
        # NOTE(review): when only ``table_name`` is given, ``filepath`` is
        # None and ``str(filepath)`` yields "None" here - confirm whether
        # save-only datasets should skip this branch.
        _fs_args = copy.deepcopy(fs_args) or {}
        _fs_credentials = _fs_args.pop("credentials", {})
        protocol, path = get_protocol_and_path(str(filepath))

        self._protocol = protocol
        self._fs = fsspec.filesystem(self._protocol, **_fs_credentials, **_fs_args)
        self._filepath = path
    self._connection_str = credentials["con"]
    # Everything except the connection string is an extra connection option.
    self._connection_args = {
        key: value for key, value in credentials.items() if key != "con"
    }
    if "mssql" in self._connection_str:
        self.adapt_mssql_date_params()

_connection_args instance-attribute

_connection_args = {
    k: credentials[k] for k in credentials.keys() if k != "con"
}

_connection_str instance-attribute

_connection_str = credentials['con']

_filepath instance-attribute

_filepath = None

_fs instance-attribute

_fs = filesystem(_protocol, **_fs_credentials, **_fs_args)

_load_args instance-attribute

_load_args = (
    {**default_load_args, **load_args}
    if load_args is not None
    else default_load_args
)

_protocol instance-attribute

_protocol = protocol

_save_args instance-attribute

_save_args = (
    {**default_save_args, **save_args}
    if save_args is not None
    else default_save_args
)

engine property

engine

The Engine object for the dataset's connection string.

engines class-attribute instance-attribute

engines = {}

metadata instance-attribute

metadata = metadata

table_name instance-attribute

table_name = table_name

_describe

_describe()
Source code in kedro_datasets_experimental/polars/polars_database_dataset.py
239
240
241
242
243
244
245
246
247
def _describe(self) -> dict[str, Any]:
    load_args = copy.deepcopy(self._load_args)
    return {
        "sql": str(load_args.pop("sql", None)),
        "filepath": str(self._filepath),
        "load_args": str(load_args),
        "table_name": self.table_name,
        "save_args": str(self._save_args),
    }

adapt_mssql_date_params

adapt_mssql_date_params()

We need to change the format of datetime parameters. MSSQL expects datetime in the exact format %Y-%m-%dT%H:%M:%S. Here, we also accept plain dates. pyodbc does not accept named parameters; they must be provided as a list.

Source code in kedro_datasets_experimental/polars/polars_database_dataset.py
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
def adapt_mssql_date_params(self) -> None:
    """Rewrite date-like entries of the ``params`` load argument for MSSQL.

    MSSQL expects datetime in the exact format %Y-%m-%dT%H:%M:%S.
    Here, we also accept plain dates: every entry that parses as an ISO
    date is replaced by midnight of that day in the format above, and any
    other entry is kept unchanged.
    `pyodbc` does not accept named parameters, they must be provided as a list.

    Raises:
        DatasetError: If ``params`` is present but is not a ``list``.
    """
    params = self._load_args.get("params", [])
    if not isinstance(params, list):
        raise DatasetError(
            "Unrecognized `params` format. It can be only a `list`, "
            f"got {type(params)!r}"
        )
    adapted_params = []
    for value in params:
        try:
            as_date = dt.date.fromisoformat(value)
        except (TypeError, ValueError):
            # Not a string, or not an ISO date: pass it through untouched.
            adapted_params.append(value)
        else:
            midnight = dt.datetime.combine(as_date, dt.time.min)
            adapted_params.append(midnight.strftime("%Y-%m-%dT%H:%M:%S"))
    if adapted_params:
        # Store a list (not a tuple) so the value still satisfies the
        # `list` check above if this method runs again.
        self._load_args["params"] = adapted_params

create_connection classmethod

create_connection(connection_str, connection_args=None)

Given a connection string, create singleton connection to be used across all instances of PolarsDatabaseDataset that need to connect to the same source.

Source code in kedro_datasets_experimental/polars/polars_database_dataset.py
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
@classmethod
def create_connection(
    cls, connection_str: str, connection_args: dict | None = None
) -> None:
    """Build an engine for ``connection_str`` and register it on the class.

    The engine is stored in the class-level ``engines`` mapping under the
    connection string, so that all instances of `PolarsDatabaseDataset`
    that need to connect to the same source can use the same engine.

    Args:
        connection_str: Connection string handed to ``create_engine``.
        connection_args: Optional keyword arguments forwarded to
            ``create_engine``; an empty or missing value means none.
    """
    engine_kwargs = connection_args if connection_args else {}
    try:
        new_engine = create_engine(connection_str, **engine_kwargs)
    except ImportError as import_error:
        raise _get_missing_module_error(import_error) from import_error
    except NoSuchModuleError as exc:
        raise _get_sql_alchemy_missing_error() from exc
    else:
        cls.engines[connection_str] = new_engine

load

load()
Source code in kedro_datasets_experimental/polars/polars_database_dataset.py
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
def load(self) -> pl.DataFrame:
    """Run the configured SQL query and return the result.

    The query text is read from the file at ``self._filepath`` when one is
    set; otherwise it is taken from the ``"sql"`` load argument. All other
    load arguments are forwarded to ``pl.read_database``.

    Returns:
        The object returned by ``pl.read_database`` for the query.
    """
    # Work on a copy so popping "sql" does not alter the stored arguments.
    load_args = copy.deepcopy(self._load_args)

    if self._filepath:
        load_path = get_filepath_str(PurePosixPath(self._filepath), self._protocol)
        with self._fs.open(load_path, mode="r") as fs_file:
            query = fs_file.read()
    else:
        query = load_args.pop("sql")

    # `pl.read_database` executes against a connection/engine object rather
    # than a plain connection-URI string, so pass the dataset's engine
    # instead of `self._connection_str`.
    return pl.read_database(
        query=query,
        connection=self.engine,
        **load_args
    )

save

save(data)
Source code in kedro_datasets_experimental/polars/polars_database_dataset.py
265
266
267
268
269
270
271
272
273
274
275
def save(self, data: pl.DataFrame) -> None:
    """Write ``data`` to the configured table.

    Args:
        data: DataFrame whose ``write_database`` method is called with the
            table name, the connection string and the stored save
            arguments.

    Raises:
        DatasetError: If no ``table_name`` was configured.
    """
    if not self.table_name:
        raise DatasetError(
            "'table_name' argument is required to save datasets."
        )

    data.write_database(
        table_name=self.table_name,
        connection=self._connection_str,
        **self._save_args
    )