diff --git a/docs/examples/adapters/adbc_postgres_ingest.py b/docs/examples/adapters/adbc_postgres_ingest.py new file mode 100644 index 00000000..ca054435 --- /dev/null +++ b/docs/examples/adapters/adbc_postgres_ingest.py @@ -0,0 +1,161 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "sqlspec[adbc]", +# "pyarrow", +# "rich", +# "rich-click", +# ] +# /// +"""ADBC Postgres ingestion workflow leveraging the storage bridge. + +This example exports arbitrary SELECT statements to a Parquet or Arrow artifact, +then loads the staged data back into a target table using the same ADBC driver. +Use it as a template for warehouse ↔ object-store fan-outs. +""" + +from pathlib import Path +from typing import Any + +import rich_click as click +from rich.console import Console +from rich.table import Table + +from sqlspec import SQLSpec +from sqlspec.adapters.adbc import AdbcConfig +from sqlspec.storage import StorageTelemetry +from sqlspec.utils.serializers import to_json + +__all__ = ("main",) + + +def _build_partitioner(rows_per_chunk: int | None, partitions: int | None) -> "dict[str, Any] | None": + if rows_per_chunk and partitions: + msg = "Use either --rows-per-chunk or --partitions, not both." + raise click.BadParameter(msg, param_hint="--rows-per-chunk / --partitions") + if rows_per_chunk: + return {"kind": "rows_per_chunk", "rows_per_chunk": rows_per_chunk} + if partitions: + return {"kind": "fixed", "partitions": partitions} + return None + + +def _write_telemetry(payload: "dict[str, Any]", output_path: Path) -> None: + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(to_json(payload), encoding="utf-8") + + +def _format_job(stage: str, telemetry: StorageTelemetry) -> str: + rows = telemetry.get("rows_processed", 0) + bytes_processed = telemetry.get("bytes_processed", 0) + destination = telemetry.get("destination", "") + return f"[{stage}] rows={rows} bytes={bytes_processed} destination={destination}" + + +def _render_capabilities(console: Console, config: AdbcConfig) -> None: + capabilities = config.storage_capabilities() + table = Table(title="Storage Capabilities", highlight=True) + table.add_column("Capability", style="cyan") + table.add_column("Enabled", style="green") + for key, value in capabilities.items(): + table.add_row(str(key), str(value)) + console.print(table) + + +@click.command(context_settings={"help_option_names": ["-h", "--help"], "max_content_width": 100}) +@click.option( + "--uri", + required=True, + envvar="SQLSPEC_ADBC_URI", + help="ADBC connection URI (e.g. postgres://user:pass@host:port/dbname)", +) +@click.option("--source-sql", required=True, help="SELECT statement to export") +@click.option("--target-table", required=True, help="Fully qualified destination table name") +@click.option( + "--destination", + type=click.Path(path_type=Path, dir_okay=False, writable=True, resolve_path=True), + default=Path("./tmp/adbc_export.parquet"), + show_default=True, + help="Local path or mounted volume for the staged artifact", +) +@click.option( + "--format", + "file_format", + type=click.Choice(["parquet", "arrow-ipc"], case_sensitive=False), + default="parquet", + show_default=True, + help="Storage format used for export/import", +) +@click.option( + "--rows-per-chunk", + type=int, + help="Rows per partition chunk. Combine with SQL predicates (e.g. `WHERE id BETWEEN ...`) per worker.", +) +@click.option( + "--partitions", + type=int, + help="Fixed number of partitions. 
Pair with predicates like `MOD(id, N) = worker_id` when parallelizing.", +) +@click.option( + "--overwrite/--no-overwrite", default=False, show_default=True, help="Overwrite the target table before load" +) +@click.option("--skip-load", is_flag=True, default=False, help="Export only and skip the load stage") +@click.option( + "--output-telemetry", + type=click.Path(path_type=Path, dir_okay=False, writable=True, resolve_path=True), + help="Optional path to persist telemetry JSON", +) +@click.version_option(message="%(version)s") +def main( + *, + uri: str, + source_sql: str, + target_table: str, + destination: Path, + file_format: str, + rows_per_chunk: int | None, + partitions: int | None, + overwrite: bool, + skip_load: bool, + output_telemetry: Path | None, +) -> None: + """ADBC-powered export/import demo for Postgres-compatible backends.""" + + console = Console() + partitioner = _build_partitioner(rows_per_chunk, partitions) + destination.parent.mkdir(parents=True, exist_ok=True) + + db_manager = SQLSpec() + adbc_config = AdbcConfig(connection_config={"uri": uri}) + adbc_key = db_manager.add_config(adbc_config) + + _render_capabilities(console, db_manager.get_config(adbc_key)) + telemetry_records: list[dict[str, Any]] = [] + + with db_manager.provide_session(adbc_key) as session: + export_job = session.select_to_storage( + source_sql, str(destination), format_hint=file_format, partitioner=partitioner + ) + console.print(_format_job("export", export_job.telemetry)) + telemetry_records.append({"stage": "export", "metrics": export_job.telemetry}) + + if not skip_load: + load_job = session.load_from_storage( + target_table, str(destination), file_format=file_format, overwrite=overwrite, partitioner=partitioner + ) + console.print(_format_job("load", load_job.telemetry)) + telemetry_records.append({"stage": "load", "metrics": load_job.telemetry}) + + if partitioner: + console.print( + "[dim]Tip:[/] launch multiple workers with mutually exclusive WHERE clauses (" + "for example, `MOD(id, N) = worker_id`) so each process writes a distinct partition." + ) + + if output_telemetry: + payload: dict[str, Any] = {"telemetry": telemetry_records} + _write_telemetry(payload, output_telemetry) + + +if __name__ == "__main__": + main() diff --git a/docs/examples/arrow_basic_usage.py b/docs/examples/arrow_basic_usage.py index 0570b2c4..04b4fa5a 100644 --- a/docs/examples/arrow_basic_usage.py +++ b/docs/examples/arrow_basic_usage.py @@ -28,11 +28,12 @@ async def example_adbc_native() -> None: from sqlspec import SQLSpec from sqlspec.adapters.adbc import AdbcConfig - sql = SQLSpec() - config = AdbcConfig(connection_config={"driver": "adbc_driver_sqlite", "uri": "file::memory:?cache=shared"}) - sql.add_config(config) + db_manager = SQLSpec() + adbc_db = db_manager.add_config( + AdbcConfig(connection_config={"driver": "adbc_driver_sqlite", "uri": "file::memory:?cache=shared"}) + ) - with config.provide_session() as session: + with db_manager.provide_session(adbc_db) as session: # Create test table session.execute( """ @@ -56,7 +57,7 @@ async def example_adbc_native() -> None: ) # Native Arrow fetch - zero-copy! 
- result = session.select_to_arrow("SELECT * FROM users WHERE age > ?", (25,)) + result = session.select_to_arrow("SELECT * FROM users WHERE age > :min_age", min_age=25) print("ADBC Native Arrow Results:") print(f" Rows: {len(result)}") @@ -77,11 +78,10 @@ async def example_postgres_conversion() -> None: from sqlspec import SQLSpec from sqlspec.adapters.asyncpg import AsyncpgConfig - sql = SQLSpec() - config = AsyncpgConfig(pool_config={"dsn": "postgresql://localhost/test"}) - sql.add_config(config) + db_manager = SQLSpec() + asyncpg_db = db_manager.add_config(AsyncpgConfig(pool_config={"dsn": "postgresql://localhost/test"})) - async with config.provide_session() as session: + async with db_manager.provide_session(asyncpg_db) as session: # Create test table with PostgreSQL-specific types await session.execute( """ @@ -95,14 +95,16 @@ async def example_postgres_conversion() -> None: ) # Insert test data - await session.execute( - "INSERT INTO products (name, price, tags) VALUES ($1, $2, $3)", - [("Widget", 19.99, ["gadget", "tool"]), ("Gadget", 29.99, ["electronics", "new"])], - many=True, + await session.execute_many( + "INSERT INTO products (name, price, tags) VALUES (:name, :price, :tags)", + [ + {"name": "Widget", "price": 19.99, "tags": ["gadget", "tool"]}, + {"name": "Gadget", "price": 29.99, "tags": ["electronics", "new"]}, + ], ) # Conversion path: dict → Arrow - result = await session.select_to_arrow("SELECT * FROM products WHERE price < $1", (25.00,)) + result = await session.select_to_arrow("SELECT * FROM products WHERE price < :price_limit", price_limit=25.00) print("PostgreSQL Conversion Path Results:") print(f" Rows: {len(result)}") @@ -116,11 +118,10 @@ async def example_pandas_integration() -> None: from sqlspec import SQLSpec from sqlspec.adapters.sqlite import SqliteConfig - sql = SQLSpec() - config = SqliteConfig(pool_config={"database": ":memory:"}) - sql.add_config(config) + db_manager = SQLSpec() + sqlite_db = db_manager.add_config(SqliteConfig(pool_config={"database": ":memory:"})) - with config.provide_session() as session: + with db_manager.provide_session(sqlite_db) as session: # Create and populate table session.execute( """ @@ -133,15 +134,14 @@ async def example_pandas_integration() -> None: """ ) - session.execute( - "INSERT INTO sales VALUES (?, ?, ?, ?)", + session.execute_many( + "INSERT INTO sales (id, region, amount, sale_date) VALUES (:id, :region, :amount, :sale_date)", [ - (1, "North", 1000.00, "2024-01-15"), - (2, "South", 1500.00, "2024-01-20"), - (3, "North", 2000.00, "2024-02-10"), - (4, "East", 1200.00, "2024-02-15"), + {"id": 1, "region": "North", "amount": 1000.00, "sale_date": "2024-01-15"}, + {"id": 2, "region": "South", "amount": 1500.00, "sale_date": "2024-01-20"}, + {"id": 3, "region": "North", "amount": 2000.00, "sale_date": "2024-02-10"}, + {"id": 4, "region": "East", "amount": 1200.00, "sale_date": "2024-02-15"}, ], - many=True, ) # Query to Arrow @@ -164,11 +164,10 @@ async def example_polars_integration() -> None: from sqlspec import SQLSpec from sqlspec.adapters.duckdb import DuckDBConfig - sql = SQLSpec() - config = DuckDBConfig(pool_config={"database": ":memory:"}) - sql.add_config(config) + db_manager = SQLSpec() + duckdb = db_manager.add_config(DuckDBConfig(pool_config={"database": ":memory:"})) - with config.provide_session() as session: + with db_manager.provide_session(duckdb) as session: # Create and populate table session.execute( """ @@ -181,15 +180,14 @@ async def example_polars_integration() -> None: """ ) - 
session.execute( - "INSERT INTO events VALUES (?, ?, ?, ?)", + session.execute_many( + "INSERT INTO events (id, event_type, user_id, timestamp) VALUES (:id, :event_type, :user_id, :ts)", [ - (1, "login", 100, "2024-01-01 10:00:00"), - (2, "click", 100, "2024-01-01 10:05:00"), - (3, "login", 101, "2024-01-01 10:10:00"), - (4, "purchase", 100, "2024-01-01 10:15:00"), + {"id": 1, "event_type": "login", "user_id": 100, "ts": "2024-01-01 10:00:00"}, + {"id": 2, "event_type": "click", "user_id": 100, "ts": "2024-01-01 10:05:00"}, + {"id": 3, "event_type": "login", "user_id": 101, "ts": "2024-01-01 10:10:00"}, + {"id": 4, "event_type": "purchase", "user_id": 100, "ts": "2024-01-01 10:15:00"}, ], - many=True, ) # Query to Arrow (native DuckDB path) @@ -209,15 +207,19 @@ async def example_return_formats() -> None: from sqlspec import SQLSpec from sqlspec.adapters.duckdb import DuckDBConfig - sql = SQLSpec() - config = DuckDBConfig(pool_config={"database": ":memory:"}) - sql.add_config(config) + db_manager = SQLSpec() + duckdb = db_manager.add_config(DuckDBConfig(pool_config={"database": ":memory:"})) - with config.provide_session() as session: + with db_manager.provide_session(duckdb) as session: # Create test data session.execute("CREATE TABLE items (id INTEGER, name VARCHAR, quantity INTEGER)") - session.execute( - "INSERT INTO items VALUES (?, ?, ?)", [(1, "Apple", 10), (2, "Banana", 20), (3, "Orange", 15)], many=True + session.execute_many( + "INSERT INTO items (id, name, quantity) VALUES (:id, :name, :qty)", + [ + {"id": 1, "name": "Apple", "qty": 10}, + {"id": 2, "name": "Banana", "qty": 20}, + {"id": 3, "name": "Orange", "qty": 15}, + ], ) # Table format (default) @@ -241,16 +243,13 @@ async def example_return_formats() -> None: # Example 6: Export to Parquet async def example_parquet_export() -> None: """Demonstrate exporting Arrow results to Parquet.""" - import pyarrow.parquet as pq - from sqlspec import SQLSpec from sqlspec.adapters.duckdb import DuckDBConfig - sql = SQLSpec() - config = DuckDBConfig(pool_config={"database": ":memory:"}) - sql.add_config(config) + db_manager = SQLSpec() + duckdb = db_manager.add_config(DuckDBConfig(pool_config={"database": ":memory:"})) - with config.provide_session() as session: + with db_manager.provide_session(duckdb) as session: # Create and populate table session.execute( """ @@ -263,25 +262,26 @@ async def example_parquet_export() -> None: """ ) - session.execute( - "INSERT INTO logs VALUES (?, ?, ?, ?)", + session.execute_many( + "INSERT INTO logs (id, timestamp, level, message) VALUES (:id, :ts, :level, :message)", [ - (1, "2024-01-01 10:00:00", "INFO", "Application started"), - (2, "2024-01-01 10:05:00", "WARN", "High memory usage"), - (3, "2024-01-01 10:10:00", "ERROR", "Database connection failed"), + {"id": 1, "ts": "2024-01-01 10:00:00", "level": "INFO", "message": "Application started"}, + {"id": 2, "ts": "2024-01-01 10:05:00", "level": "WARN", "message": "High memory usage"}, + {"id": 3, "ts": "2024-01-01 10:10:00", "level": "ERROR", "message": "Database connection failed"}, ], - many=True, ) # Query to Arrow result = session.select_to_arrow("SELECT * FROM logs") - # Export to Parquet - output_path = Path("/tmp/logs.parquet") - pq.write_table(result.data, output_path) + # Export to Parquet using the storage bridge + output_path = Path("/tmp/arrow_basic_usage_logs.parquet") + telemetry = result.write_to_storage_sync(str(output_path), format_hint="parquet") print("Parquet Export:") print(f" Exported to: {output_path}") + print(f" Rows: 
{telemetry['rows_processed']}") + print(f" Bytes processed: {telemetry['bytes_processed']}") print(f" File size: {output_path.stat().st_size} bytes") print() @@ -294,11 +294,10 @@ async def example_native_only_mode() -> None: from sqlspec.adapters.sqlite import SqliteConfig # ADBC has native Arrow support - sql = SQLSpec() - config = AdbcConfig(connection_config={"uri": "sqlite://:memory:"}) - sql.add_config(config) + db_manager = SQLSpec() + adbc_sqlite = db_manager.add_config(AdbcConfig(connection_config={"uri": "sqlite://:memory:"})) - with config.provide_session() as session: + with db_manager.provide_session(adbc_sqlite) as session: session.execute("CREATE TABLE test (id INTEGER, name TEXT)") session.execute("INSERT INTO test VALUES (1, 'test')") @@ -309,10 +308,9 @@ async def example_native_only_mode() -> None: print() # SQLite does not have native Arrow support - sqlite_config = SqliteConfig(pool_config={"database": ":memory:"}) - sql.add_config(sqlite_config) + sqlite_db = db_manager.add_config(SqliteConfig(pool_config={"database": ":memory:"})) - with sqlite_config.provide_session() as session: + with db_manager.provide_session(sqlite_db) as session: session.execute("CREATE TABLE test (id INTEGER, name TEXT)") session.execute("INSERT INTO test VALUES (1, 'test')") result = session.select_to_arrow("SELECT * FROM test", native_only=True) diff --git a/docs/examples/index.rst b/docs/examples/index.rst index 46765f9e..121af0bd 100644 --- a/docs/examples/index.rst +++ b/docs/examples/index.rst @@ -60,6 +60,9 @@ Adapters * - File - Adapter - Highlights + * - ``adapters/adbc_postgres_ingest.py`` + - ADBC (Postgres) + - Rich Click CLI that exports SELECT queries to Parquet/Arrow and loads them via the storage bridge. * - ``adapters/asyncpg/connect_pool.py`` - AsyncPG - Minimal pool configuration plus a version probe. @@ -124,6 +127,7 @@ Shared Utilities frameworks/litestar/aiosqlite_app frameworks/litestar/duckdb_app frameworks/litestar/sqlite_app + adapters/adbc_postgres_ingest adapters/asyncpg/connect_pool adapters/psycopg/connect_sync adapters/oracledb/connect_async diff --git a/docs/guides/adapters/adbc.md b/docs/guides/adapters/adbc.md index 92ce4717..55e3ed5e 100644 --- a/docs/guides/adapters/adbc.md +++ b/docs/guides/adapters/adbc.md @@ -23,6 +23,34 @@ This guide provides specific instructions for the `adbc` adapter. - **Driver Installation:** Each database requires a separate ADBC driver to be installed (e.g., `pip install adbc_driver_postgresql`). - **Data Types:** Be aware of how database types are mapped to Arrow types. Use `Context7` to research the specific ADBC driver's documentation for type mapping details. +## Parallel exports and loads + +The storage bridge intentionally supports multi-worker fan-outs. The [ADBC Postgres CLI example](../../examples/adapters/adbc_postgres_ingest.py) exposes `--rows-per-chunk` and `--partitions` so orchestration layers (GNU Parallel, Airflow, Dagster, etc.) can run many exporters in parallel without stepping on each other: + +1. **Pick a partition mode.** `--partitions N` spreads artifacts across `N` numbered partitions, whereas `--rows-per-chunk K` rolls to a new artifact every `K` rows. +2. **Add mutually-exclusive predicates.** The CLI does not rewrite your SQL; give each worker a unique filter such as `WHERE MOD(id, 4) = worker_id` or `WHERE id BETWEEN 1_000 AND 1_999`. +3. **Write to unique destinations.** Point each worker at its own artifact path (`alias://bucket/job-42/worker-00.parquet`). 
Aliases can be local (file://) or remote (MinIO/S3). +4. **Load concurrently.** Either let each worker call `load_from_storage()` (set `--overwrite` on worker 0 only) or perform a final merge step that iterates over the staged artifacts. + +Example shell loop: + +```bash +export SQLSPEC_ADBC_URI="postgresql://user:pass@host:5432/db" +PARTITIONS=4 +for worker in $(seq 0 $((PARTITIONS-1))); do + uv run python docs/examples/adapters/adbc_postgres_ingest.py \ + --source-sql "SELECT id, amount FROM fact_sales WHERE MOD(id, $PARTITIONS) = $worker" \ + --target-table staging.sales_ingest \ + --destination "alias://local-job/run-42/worker-$worker.parquet" \ + --partitions $PARTITIONS \ + --overwrite=$([[ $worker -eq 0 ]] && echo true || echo false) \ + --skip-load=false & +done +wait +``` + +When either partition option is provided the CLI prints a reminder to include per-worker predicates. The same approach applies if you drive the storage bridge programmatically: route each worker’s query to a disjoint slice of the dataset and the bridge will keep artifacts isolated for you. + ## Arrow Support (Native) The ADBC adapter provides **native Apache Arrow support** through `select_to_arrow()`, offering zero-copy data transfer for exceptional performance. @@ -35,18 +63,19 @@ ADBC uses the driver's built-in `fetch_arrow_table()` method for direct Arrow re from sqlspec import SQLSpec from sqlspec.adapters.adbc import AdbcConfig -sql = SQLSpec() -config = AdbcConfig( - driver="adbc_driver_postgresql", - pool_config={"uri": "postgresql://localhost/mydb"} +db_manager = SQLSpec() +adbc_db = db_manager.add_config( + AdbcConfig( + driver="adbc_driver_postgresql", + connection_config={"uri": "postgresql://localhost/mydb"}, + ) ) -sql.add_config(config) -async with sql.provide_session() as session: +async with db_manager.provide_session(adbc_db) as session: # Native Arrow fetch - zero-copy! 
result = await session.select_to_arrow( "SELECT * FROM users WHERE age > :age", - {"age": 18} + age=18, ) print(f"Rows: {len(result)}") @@ -56,12 +85,14 @@ async with sql.provide_session() as session: ### Performance Characteristics **Native Arrow Benefits**: + - **5-10x faster** than dict conversion for large datasets - **Zero-copy data transfer** - no intermediate representations - **Native type preservation** - database types mapped directly to Arrow - **Memory efficient** - columnar format reduces memory usage **Benchmarks** (10K rows, 10 columns): + - Native Arrow: ~5ms - Dict conversion: ~25ms - **Speedup**: 5x @@ -148,49 +179,61 @@ ADBC preserves native database types in Arrow format: **PostgreSQL** (`adbc_driver_postgresql`): ```python -config = AdbcConfig( - driver="adbc_driver_postgresql", - pool_config={"uri": "postgresql://localhost/db"} +from sqlspec import SQLSpec +from sqlspec.adapters.adbc import AdbcConfig + +db_manager = SQLSpec() +postgres_db = db_manager.add_config( + AdbcConfig( + driver="adbc_driver_postgresql", + connection_config={"uri": "postgresql://localhost/db"}, + ) ) -async with sql.provide_session(config) as session: +async with db_manager.provide_session(postgres_db) as session: # PostgreSQL arrays preserved in Arrow list type - result = await session.select_to_arrow( - "SELECT id, tags FROM articles" - ) + result = await session.select_to_arrow("SELECT id, tags FROM articles") ``` **SQLite** (`adbc_driver_sqlite`): ```python -config = AdbcConfig( - driver="adbc_driver_sqlite", - pool_config={"uri": "file:app.db"} +from sqlspec import SQLSpec +from sqlspec.adapters.adbc import AdbcConfig + +db_manager = SQLSpec() +sqlite_db = db_manager.add_config( + AdbcConfig( + driver="adbc_driver_sqlite", + connection_config={"uri": "file:app.db"}, + ) ) -with sql.provide_session(config) as session: +with db_manager.provide_session(sqlite_db) as session: # SQLite types mapped to Arrow - result = session.select_to_arrow( - "SELECT * FROM users" - ) + result = session.select_to_arrow("SELECT * FROM users") ``` **Snowflake** (`adbc_driver_snowflake`): ```python -config = AdbcConfig( - driver="adbc_driver_snowflake", - pool_config={ - "uri": "snowflake://account.region.snowflakecomputing.com/database/schema", - "adbc.snowflake.sql.account": "your_account", - "adbc.snowflake.sql.user": "your_user" - } -) +from sqlspec import SQLSpec +from sqlspec.adapters.adbc import AdbcConfig -async with sql.provide_session(config) as session: - result = await session.select_to_arrow( - "SELECT * FROM large_table" +db_manager = SQLSpec() +snowflake_db = db_manager.add_config( + AdbcConfig( + driver="adbc_driver_snowflake", + connection_config={ + "uri": "snowflake://account.region.snowflakecomputing.com/database/schema", + "adbc.snowflake.sql.account": "your_account", + "adbc.snowflake.sql.user": "your_user", + }, ) +) + +async with db_manager.provide_session(snowflake_db) as session: + result = await session.select_to_arrow("SELECT * FROM large_table") ``` ### Best Practices diff --git a/pyproject.toml b/pyproject.toml index e41c8faa..837ab49e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -281,6 +281,7 @@ filterwarnings = [ "ignore::DeprecationWarning:pkg_resources.*", "ignore:pkg_resources is deprecated as an API:DeprecationWarning", "ignore::DeprecationWarning:pkg_resources", + "ignore:You are using a Python version .+ google\\.api_core:FutureWarning", "ignore::DeprecationWarning:google.rpc", "ignore::DeprecationWarning:google.gcloud", "ignore::DeprecationWarning:google.iam", diff 
--git a/sqlspec/adapters/adbc/config.py b/sqlspec/adapters/adbc/config.py index aa533bc1..de514f5c 100644 --- a/sqlspec/adapters/adbc/config.py +++ b/sqlspec/adapters/adbc/config.py @@ -101,6 +101,11 @@ class AdbcConfig(NoPoolSyncConfig[AdbcConnection, AdbcDriver]): driver_type: ClassVar[type[AdbcDriver]] = AdbcDriver connection_type: "ClassVar[type[AdbcConnection]]" = AdbcConnection supports_transactional_ddl: ClassVar[bool] = False + supports_native_arrow_export: "ClassVar[bool]" = True + supports_native_arrow_import: "ClassVar[bool]" = True + supports_native_parquet_export: "ClassVar[bool]" = True + supports_native_parquet_import: "ClassVar[bool]" = True + storage_partition_strategies: "ClassVar[tuple[str, ...]]" = ("fixed", "rows_per_chunk") def __init__( self, diff --git a/sqlspec/adapters/adbc/driver.py b/sqlspec/adapters/adbc/driver.py index 8cfa54b5..f32f2acd 100644 --- a/sqlspec/adapters/adbc/driver.py +++ b/sqlspec/adapters/adbc/driver.py @@ -7,7 +7,7 @@ import contextlib import datetime import decimal -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Literal, cast from sqlspec.adapters.adbc.data_dictionary import AdbcDataDictionary from sqlspec.adapters.adbc.type_converter import ADBCTypeConverter @@ -51,6 +51,13 @@ from sqlspec.core import ArrowResult, SQLResult, Statement, StatementFilter from sqlspec.driver import ExecutionResult from sqlspec.driver._sync import SyncDataDictionaryBase + from sqlspec.storage import ( + StorageBridgeJob, + StorageDestination, + StorageFormat, + StorageTelemetry, + SyncStoragePipeline, + ) from sqlspec.typing import ArrowReturnFormat, StatementParameters __all__ = ("AdbcCursor", "AdbcDriver", "AdbcExceptionHandler", "get_adbc_statement_config") @@ -681,6 +688,66 @@ def select_to_arrow( # Create ArrowResult return create_arrow_result(statement=prepared_statement, data=arrow_data, rows_affected=arrow_data.num_rows) + def select_to_storage( + self, + statement: "Statement | QueryBuilder | SQL | str", + destination: "StorageDestination", + /, + *parameters: "StatementParameters | StatementFilter", + statement_config: "StatementConfig | None" = None, + partitioner: "dict[str, Any] | None" = None, + format_hint: "StorageFormat | None" = None, + telemetry: "StorageTelemetry | None" = None, + **kwargs: Any, + ) -> "StorageBridgeJob": + """Stream query results to storage via the Arrow fast path.""" + + _ = kwargs + self._require_capability("arrow_export_enabled") + arrow_result = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) + sync_pipeline: SyncStoragePipeline = cast("SyncStoragePipeline", self._storage_pipeline()) + telemetry_payload = arrow_result.write_to_storage_sync( + destination, format_hint=format_hint, pipeline=sync_pipeline + ) + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload, telemetry) + + def load_from_arrow( + self, + table: str, + source: "ArrowResult | Any", + *, + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + telemetry: "StorageTelemetry | None" = None, + ) -> "StorageBridgeJob": + """Ingest an Arrow payload directly through the ADBC cursor.""" + + self._require_capability("arrow_import_enabled") + arrow_table = self._coerce_arrow_table(source) + ingest_mode: Literal["append", "create", "replace", "create_append"] + ingest_mode = "replace" if overwrite else "create_append" + with self.with_cursor(self.connection) as cursor, self.handle_database_exceptions(): + 
cursor.adbc_ingest(table, arrow_table, mode=ingest_mode) + telemetry_payload = self._build_ingest_telemetry(arrow_table) + telemetry_payload["destination"] = table + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload, telemetry) + + def load_from_storage( + self, + table: str, + source: "StorageDestination", + *, + file_format: "StorageFormat", + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + ) -> "StorageBridgeJob": + """Read an artifact from storage and ingest it via ADBC.""" + + arrow_table, inbound = self._read_arrow_from_storage_sync(source, file_format=file_format) + return self.load_from_arrow(table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound) + def get_type_coercion_map(dialect: str) -> "dict[type, Any]": """Return dialect-aware type coercion mapping for Arrow parameter handling.""" diff --git a/sqlspec/adapters/aiosqlite/config.py b/sqlspec/adapters/aiosqlite/config.py index 240e0703..8235968e 100644 --- a/sqlspec/adapters/aiosqlite/config.py +++ b/sqlspec/adapters/aiosqlite/config.py @@ -74,6 +74,10 @@ class AiosqliteConfig(AsyncDatabaseConfig["AiosqliteConnection", AiosqliteConnec driver_type: "ClassVar[type[AiosqliteDriver]]" = AiosqliteDriver connection_type: "ClassVar[type[AiosqliteConnection]]" = AiosqliteConnection supports_transactional_ddl: "ClassVar[bool]" = True + supports_native_arrow_export: "ClassVar[bool]" = True + supports_native_arrow_import: "ClassVar[bool]" = True + supports_native_parquet_export: "ClassVar[bool]" = True + supports_native_parquet_import: "ClassVar[bool]" = True def __init__( self, diff --git a/sqlspec/adapters/aiosqlite/driver.py b/sqlspec/adapters/aiosqlite/driver.py index 24338367..8e118dc9 100644 --- a/sqlspec/adapters/aiosqlite/driver.py +++ b/sqlspec/adapters/aiosqlite/driver.py @@ -4,11 +4,12 @@ import contextlib from datetime import date, datetime from decimal import Decimal -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import aiosqlite from sqlspec.core import ( + ArrowResult, DriverParameterProfile, ParameterStyle, build_statement_config_from_profile, @@ -38,6 +39,13 @@ from sqlspec.core import SQL, SQLResult, StatementConfig from sqlspec.driver import ExecutionResult from sqlspec.driver._async import AsyncDataDictionaryBase + from sqlspec.storage import ( + AsyncStoragePipeline, + StorageBridgeJob, + StorageDestination, + StorageFormat, + StorageTelemetry, + ) __all__ = ("AiosqliteCursor", "AiosqliteDriver", "AiosqliteExceptionHandler", "aiosqlite_statement_config") @@ -273,6 +281,72 @@ async def _execute_statement(self, cursor: "aiosqlite.Cursor", statement: "SQL") affected_rows = cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 return self.create_execution_result(cursor, rowcount_override=affected_rows) + async def select_to_storage( + self, + statement: "SQL | str", + destination: "StorageDestination", + /, + *parameters: Any, + statement_config: "StatementConfig | None" = None, + partitioner: "dict[str, Any] | None" = None, + format_hint: "StorageFormat | None" = None, + telemetry: "StorageTelemetry | None" = None, + **kwargs: Any, + ) -> "StorageBridgeJob": + """Execute a query and stream Arrow results into storage.""" + + self._require_capability("arrow_export_enabled") + arrow_result = await self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) + async_pipeline: AsyncStoragePipeline = cast("AsyncStoragePipeline", 
self._storage_pipeline()) + telemetry_payload = await arrow_result.write_to_storage_async( + destination, format_hint=format_hint, pipeline=async_pipeline + ) + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload, telemetry) + + async def load_from_arrow( + self, + table: str, + source: "ArrowResult | Any", + *, + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + telemetry: "StorageTelemetry | None" = None, + ) -> "StorageBridgeJob": + """Load Arrow data into SQLite using batched inserts.""" + + self._require_capability("arrow_import_enabled") + arrow_table = self._coerce_arrow_table(source) + if overwrite: + await self._truncate_table_async(table) + + columns, records = self._arrow_table_to_rows(arrow_table) + if records: + insert_sql = _build_sqlite_insert_statement(table, columns) + async with self.handle_database_exceptions(), self.with_cursor(self.connection) as cursor: + await cursor.executemany(insert_sql, records) + + telemetry_payload = self._build_ingest_telemetry(arrow_table) + telemetry_payload["destination"] = table + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload, telemetry) + + async def load_from_storage( + self, + table: str, + source: "StorageDestination", + *, + file_format: "StorageFormat", + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + ) -> "StorageBridgeJob": + """Load staged artifacts from storage into SQLite.""" + + arrow_table, inbound = await self._read_arrow_from_storage_async(source, file_format=file_format) + return await self.load_from_arrow( + table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound + ) + async def begin(self) -> None: """Begin a database transaction.""" try: @@ -311,6 +385,11 @@ async def commit(self) -> None: msg = f"Failed to commit transaction: {e}" raise SQLSpecError(msg) from e + async def _truncate_table_async(self, table: str) -> None: + statement = f"DELETE FROM {_format_sqlite_identifier(table)}" + async with self.handle_database_exceptions(), self.with_cursor(self.connection) as cursor: + await cursor.execute(statement) + @property def data_dictionary(self) -> "AsyncDataDictionaryBase": """Get the data dictionary for this driver. @@ -329,6 +408,27 @@ def _bool_to_int(value: bool) -> int: return int(value) +def _quote_sqlite_identifier(identifier: str) -> str: + normalized = identifier.replace('"', '""') + return f'"{normalized}"' + + +def _format_sqlite_identifier(identifier: str) -> str: + cleaned = identifier.strip() + if not cleaned: + msg = "Table name must not be empty" + raise SQLSpecError(msg) + parts = [part for part in cleaned.split(".") if part] + formatted = ".".join(_quote_sqlite_identifier(part) for part in parts) + return formatted or _quote_sqlite_identifier(cleaned) + + +def _build_sqlite_insert_statement(table: str, columns: "list[str]") -> str: + column_clause = ", ".join(_quote_sqlite_identifier(column) for column in columns) + placeholders = ", ".join("?" 
for _ in columns) + return f"INSERT INTO {_format_sqlite_identifier(table)} ({column_clause}) VALUES ({placeholders})" + + def _build_aiosqlite_profile() -> DriverParameterProfile: """Create the AIOSQLite driver parameter profile.""" diff --git a/sqlspec/adapters/asyncmy/config.py b/sqlspec/adapters/asyncmy/config.py index 2735e18a..1b0e5977 100644 --- a/sqlspec/adapters/asyncmy/config.py +++ b/sqlspec/adapters/asyncmy/config.py @@ -89,6 +89,10 @@ class AsyncmyConfig(AsyncDatabaseConfig[AsyncmyConnection, "AsyncmyPool", Asyncm driver_type: ClassVar[type[AsyncmyDriver]] = AsyncmyDriver connection_type: "ClassVar[type[AsyncmyConnection]]" = AsyncmyConnection # pyright: ignore supports_transactional_ddl: ClassVar[bool] = False + supports_native_arrow_export: ClassVar[bool] = True + supports_native_parquet_export: ClassVar[bool] = True + supports_native_arrow_import: ClassVar[bool] = True + supports_native_parquet_import: ClassVar[bool] = True def __init__( self, diff --git a/sqlspec/adapters/asyncmy/driver.py b/sqlspec/adapters/asyncmy/driver.py index d2794617..aaa3d446 100644 --- a/sqlspec/adapters/asyncmy/driver.py +++ b/sqlspec/adapters/asyncmy/driver.py @@ -5,13 +5,14 @@ """ import logging -from typing import TYPE_CHECKING, Any, Final +from typing import TYPE_CHECKING, Any, Final, cast import asyncmy.errors # pyright: ignore from asyncmy.constants import FIELD_TYPE as ASYNC_MY_FIELD_TYPE # pyright: ignore from asyncmy.cursors import Cursor, DictCursor # pyright: ignore from sqlspec.core import ( + ArrowResult, DriverParameterProfile, ParameterStyle, build_statement_config_from_profile, @@ -41,6 +42,13 @@ from sqlspec.core import SQL, SQLResult, StatementConfig from sqlspec.driver import ExecutionResult from sqlspec.driver._async import AsyncDataDictionaryBase + from sqlspec.storage import ( + AsyncStoragePipeline, + StorageBridgeJob, + StorageDestination, + StorageFormat, + StorageTelemetry, + ) __all__ = ( "AsyncmyCursor", "AsyncmyDriver", @@ -432,6 +440,72 @@ async def _execute_statement(self, cursor: Any, statement: "SQL") -> "ExecutionR last_id = getattr(cursor, "lastrowid", None) if cursor.rowcount and cursor.rowcount > 0 else None return self.create_execution_result(cursor, rowcount_override=affected_rows, last_inserted_id=last_id) + async def select_to_storage( + self, + statement: "SQL | str", + destination: "StorageDestination", + /, + *parameters: Any, + statement_config: "StatementConfig | None" = None, + partitioner: "dict[str, Any] | None" = None, + format_hint: "StorageFormat | None" = None, + telemetry: "StorageTelemetry | None" = None, + **kwargs: Any, + ) -> "StorageBridgeJob": + """Execute a query and stream Arrow-formatted results into storage.""" + + self._require_capability("arrow_export_enabled") + arrow_result = await self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) + async_pipeline: AsyncStoragePipeline = cast("AsyncStoragePipeline", self._storage_pipeline()) + telemetry_payload = await arrow_result.write_to_storage_async( + destination, format_hint=format_hint, pipeline=async_pipeline + ) + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload, telemetry) + + async def load_from_arrow( + self, + table: str, + source: "ArrowResult | Any", + *, + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + telemetry: "StorageTelemetry | None" = None, + ) -> "StorageBridgeJob": + """Load Arrow data into MySQL using batched inserts.""" + + 
self._require_capability("arrow_import_enabled") + arrow_table = self._coerce_arrow_table(source) + if overwrite: + await self._truncate_table_async(table) + + columns, records = self._arrow_table_to_rows(arrow_table) + if records: + insert_sql = _build_asyncmy_insert_statement(table, columns) + async with self.handle_database_exceptions(), self.with_cursor(self.connection) as cursor: + await cursor.executemany(insert_sql, records) + + telemetry_payload = self._build_ingest_telemetry(arrow_table) + telemetry_payload["destination"] = table + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload, telemetry) + + async def load_from_storage( + self, + table: str, + source: "StorageDestination", + *, + file_format: "StorageFormat", + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + ) -> "StorageBridgeJob": + """Load staged artifacts from storage into MySQL.""" + + arrow_table, inbound = await self._read_arrow_from_storage_async(source, file_format=file_format) + return await self.load_from_arrow( + table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound + ) + async def begin(self) -> None: """Begin a database transaction. @@ -471,6 +545,11 @@ async def commit(self) -> None: msg = f"Failed to commit MySQL transaction: {e}" raise SQLSpecError(msg) from e + async def _truncate_table_async(self, table: str) -> None: + statement = f"TRUNCATE TABLE {_format_mysql_identifier(table)}" + async with self.handle_database_exceptions(), self.with_cursor(self.connection) as cursor: + await cursor.execute(statement) + @property def data_dictionary(self) -> "AsyncDataDictionaryBase": """Get the data dictionary for this driver. @@ -489,6 +568,27 @@ def _bool_to_int(value: bool) -> int: return int(value) +def _quote_mysql_identifier(identifier: str) -> str: + normalized = identifier.replace("`", "``") + return f"`{normalized}`" + + +def _format_mysql_identifier(identifier: str) -> str: + cleaned = identifier.strip() + if not cleaned: + msg = "Table name must not be empty" + raise SQLSpecError(msg) + parts = [part for part in cleaned.split(".") if part] + formatted = ".".join(_quote_mysql_identifier(part) for part in parts) + return formatted or _quote_mysql_identifier(cleaned) + + +def _build_asyncmy_insert_statement(table: str, columns: "list[str]") -> str: + column_clause = ", ".join(_quote_mysql_identifier(column) for column in columns) + placeholders = ", ".join("%s" for _ in columns) + return f"INSERT INTO {_format_mysql_identifier(table)} ({column_clause}) VALUES ({placeholders})" + + def _build_asyncmy_profile() -> DriverParameterProfile: """Create the AsyncMy driver parameter profile.""" diff --git a/sqlspec/adapters/asyncpg/config.py b/sqlspec/adapters/asyncpg/config.py index 5ec0f938..499bb941 100644 --- a/sqlspec/adapters/asyncpg/config.py +++ b/sqlspec/adapters/asyncpg/config.py @@ -102,6 +102,10 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async driver_type: "ClassVar[type[AsyncpgDriver]]" = AsyncpgDriver connection_type: "ClassVar[type[AsyncpgConnection]]" = type(AsyncpgConnection) # type: ignore[assignment] supports_transactional_ddl: "ClassVar[bool]" = True + supports_native_arrow_export: "ClassVar[bool]" = True + supports_native_arrow_import: "ClassVar[bool]" = True + supports_native_parquet_export: "ClassVar[bool]" = True + supports_native_parquet_import: "ClassVar[bool]" = True def __init__( self, diff --git a/sqlspec/adapters/asyncpg/driver.py 
b/sqlspec/adapters/asyncpg/driver.py index 3c36c874..fd662ba0 100644 --- a/sqlspec/adapters/asyncpg/driver.py +++ b/sqlspec/adapters/asyncpg/driver.py @@ -2,7 +2,7 @@ import datetime import re -from typing import TYPE_CHECKING, Any, Final +from typing import TYPE_CHECKING, Any, Final, cast import asyncpg @@ -37,8 +37,15 @@ from contextlib import AbstractAsyncContextManager from sqlspec.adapters.asyncpg._types import AsyncpgConnection - from sqlspec.core import SQL, ParameterStyleConfig, SQLResult, StatementConfig + from sqlspec.core import SQL, ArrowResult, ParameterStyleConfig, SQLResult, StatementConfig from sqlspec.driver import AsyncDataDictionaryBase, ExecutionResult + from sqlspec.storage import ( + AsyncStoragePipeline, + StorageBridgeJob, + StorageDestination, + StorageFormat, + StorageTelemetry, + ) __all__ = ( "AsyncpgCursor", @@ -335,6 +342,68 @@ async def _execute_statement(self, cursor: "AsyncpgConnection", statement: "SQL" return self.create_execution_result(cursor, rowcount_override=affected_rows) + async def select_to_storage( + self, + statement: "SQL | str", + destination: "StorageDestination", + /, + *parameters: Any, + statement_config: "StatementConfig | None" = None, + partitioner: "dict[str, Any] | None" = None, + format_hint: "StorageFormat | None" = None, + telemetry: "StorageTelemetry | None" = None, + **kwargs: Any, + ) -> "StorageBridgeJob": + """Execute a query and persist results to storage once native COPY is available.""" + + self._require_capability("arrow_export_enabled") + arrow_result = await self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) + async_pipeline: AsyncStoragePipeline = cast("AsyncStoragePipeline", self._storage_pipeline()) + telemetry_payload = await arrow_result.write_to_storage_async( + destination, format_hint=format_hint, pipeline=async_pipeline + ) + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload, telemetry) + + async def load_from_arrow( + self, + table: str, + source: "ArrowResult | Any", + *, + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + telemetry: "StorageTelemetry | None" = None, + ) -> "StorageBridgeJob": + """Load Arrow data into a PostgreSQL table via COPY.""" + + self._require_capability("arrow_import_enabled") + arrow_table = self._coerce_arrow_table(source) + if overwrite: + await self._truncate_table(table) + columns, records = self._arrow_table_to_rows(arrow_table) + if records: + await self.connection.copy_records_to_table(table, records=records, columns=columns) + telemetry_payload = self._build_ingest_telemetry(arrow_table) + telemetry_payload["destination"] = table + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload, telemetry) + + async def load_from_storage( + self, + table: str, + source: "StorageDestination", + *, + file_format: "StorageFormat", + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + ) -> "StorageBridgeJob": + """Read an artifact from storage and ingest it via COPY.""" + + arrow_table, inbound = await self._read_arrow_from_storage_async(source, file_format=file_format) + return await self.load_from_arrow( + table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound + ) + @staticmethod def _parse_asyncpg_status(status: str) -> int: """Parse AsyncPG status string to extract row count. 
@@ -399,6 +468,13 @@ def data_dictionary(self) -> "AsyncDataDictionaryBase": self._data_dictionary = PostgresAsyncDataDictionary() return self._data_dictionary + async def _truncate_table(self, table: str) -> None: + try: + await self.connection.execute(f"TRUNCATE TABLE {table}") + except asyncpg.PostgresError as exc: + msg = f"Failed to truncate table '{table}': {exc}" + raise SQLSpecError(msg) from exc + def _convert_datetime_param(value: Any) -> Any: """Convert datetime parameter, handling ISO strings.""" diff --git a/sqlspec/adapters/bigquery/config.py b/sqlspec/adapters/bigquery/config.py index ce2e9ed0..72f58617 100644 --- a/sqlspec/adapters/bigquery/config.py +++ b/sqlspec/adapters/bigquery/config.py @@ -101,6 +101,11 @@ class BigQueryConfig(NoPoolSyncConfig[BigQueryConnection, BigQueryDriver]): driver_type: ClassVar[type[BigQueryDriver]] = BigQueryDriver connection_type: "ClassVar[type[BigQueryConnection]]" = BigQueryConnection supports_transactional_ddl: ClassVar[bool] = False + supports_native_parquet_import: ClassVar[bool] = True + supports_native_arrow_export: ClassVar[bool] = True + supports_native_parquet_export: ClassVar[bool] = True + requires_staging_for_load: ClassVar[bool] = True + staging_protocols: ClassVar[tuple[str, ...]] = ("gs://",) def __init__( self, diff --git a/sqlspec/adapters/bigquery/driver.py b/sqlspec/adapters/bigquery/driver.py index 5910c275..26d84b4f 100644 --- a/sqlspec/adapters/bigquery/driver.py +++ b/sqlspec/adapters/bigquery/driver.py @@ -5,13 +5,14 @@ """ import datetime +import io import logging from collections.abc import Callable from decimal import Decimal from typing import TYPE_CHECKING, Any, cast import sqlglot -from google.cloud.bigquery import ArrayQueryParameter, QueryJob, QueryJobConfig, ScalarQueryParameter +from google.cloud.bigquery import ArrayQueryParameter, LoadJobConfig, QueryJob, QueryJobConfig, ScalarQueryParameter from google.cloud.exceptions import GoogleCloudError from sqlspec.adapters.bigquery._types import BigQueryConnection @@ -34,6 +35,7 @@ OperationalError, SQLParsingError, SQLSpecError, + StorageCapabilityError, UniqueViolationError, ) from sqlspec.utils.serializers import to_json @@ -45,6 +47,13 @@ from sqlspec.builder import QueryBuilder from sqlspec.core import SQL, ArrowResult, SQLResult, Statement, StatementFilter from sqlspec.driver import SyncDataDictionaryBase + from sqlspec.storage import ( + StorageBridgeJob, + StorageDestination, + StorageFormat, + StorageTelemetry, + SyncStoragePipeline, + ) from sqlspec.typing import ArrowReturnFormat, StatementParameters logger = logging.getLogger(__name__) @@ -435,6 +444,37 @@ def _copy_job_config_attrs(self, source_config: QueryJobConfig, target_config: Q except (AttributeError, TypeError): continue + def _map_source_format(self, file_format: "StorageFormat") -> str: + if file_format == "parquet": + return "PARQUET" + if file_format in {"json", "jsonl"}: + return "NEWLINE_DELIMITED_JSON" + msg = f"BigQuery does not support loading '{file_format}' artifacts via the storage bridge" + raise StorageCapabilityError(msg, capability="parquet_import_enabled") + + def _build_load_job_config(self, file_format: "StorageFormat", overwrite: bool) -> LoadJobConfig: + job_config = LoadJobConfig() + job_config.source_format = self._map_source_format(file_format) + job_config.write_disposition = "WRITE_TRUNCATE" if overwrite else "WRITE_APPEND" + return job_config + + def _build_load_job_telemetry(self, job: QueryJob, table: str, *, format_label: str) -> "StorageTelemetry": + properties 
= getattr(job, "_properties", {}) + load_stats = properties.get("statistics", {}).get("load", {}) + rows_processed = int(load_stats.get("outputRows") or getattr(job, "output_rows", 0) or 0) + bytes_processed = int(load_stats.get("outputBytes") or load_stats.get("inputFileBytes", 0) or 0) + duration = 0.0 + if job.ended and job.started: + duration = (job.ended - job.started).total_seconds() + telemetry: StorageTelemetry = { + "destination": table, + "rows_processed": rows_processed, + "bytes_processed": bytes_processed, + "duration_s": duration, + "format": format_label, + } + return telemetry + def _run_query_job( self, sql_str: str, @@ -755,6 +795,82 @@ def select_to_arrow( # Create ArrowResult return create_arrow_result(statement=prepared_statement, data=arrow_data, rows_affected=arrow_data.num_rows) + def select_to_storage( + self, + statement: "Statement | QueryBuilder | SQL | str", + destination: "StorageDestination", + /, + *parameters: "StatementParameters | StatementFilter", + statement_config: "StatementConfig | None" = None, + partitioner: "dict[str, Any] | None" = None, + format_hint: "StorageFormat | None" = None, + telemetry: "StorageTelemetry | None" = None, + **kwargs: Any, + ) -> "StorageBridgeJob": + """Execute a query and persist Arrow results to a storage backend.""" + + self._require_capability("arrow_export_enabled") + arrow_result = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) + sync_pipeline: SyncStoragePipeline = cast("SyncStoragePipeline", self._storage_pipeline()) + telemetry_payload = arrow_result.write_to_storage_sync( + destination, format_hint=format_hint, pipeline=sync_pipeline + ) + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload, telemetry) + + def load_from_arrow( + self, + table: str, + source: "ArrowResult | Any", + *, + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + telemetry: "StorageTelemetry | None" = None, + ) -> "StorageBridgeJob": + """Load Arrow data by uploading a temporary Parquet payload to BigQuery.""" + + self._require_capability("parquet_import_enabled") + arrow_table = self._coerce_arrow_table(source) + from sqlspec.utils.module_loader import ensure_pyarrow + + ensure_pyarrow() + + import pyarrow.parquet as pq + + buffer = io.BytesIO() + pq.write_table(arrow_table, buffer) + buffer.seek(0) + job_config = self._build_load_job_config("parquet", overwrite) + job = self.connection.load_table_from_file(buffer, table, job_config=job_config) + job.result() + telemetry_payload = self._build_load_job_telemetry(job, table, format_label="parquet") + if telemetry: + telemetry_payload.setdefault("extra", {}) + telemetry_payload["extra"]["arrow_rows"] = telemetry.get("rows_processed") + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload) + + def load_from_storage( + self, + table: str, + source: "StorageDestination", + *, + file_format: "StorageFormat", + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + ) -> "StorageBridgeJob": + """Load staged artifacts from storage into BigQuery.""" + + if file_format != "parquet": + msg = "BigQuery storage bridge currently supports Parquet ingest only" + raise StorageCapabilityError(msg, capability="parquet_import_enabled") + job_config = self._build_load_job_config(file_format, overwrite) + job = self.connection.load_table_from_uri(source, table, job_config=job_config) + job.result() + 
telemetry_payload = self._build_load_job_telemetry(job, table, format_label=file_format) + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload) + def _build_bigquery_profile() -> DriverParameterProfile: """Create the BigQuery driver parameter profile.""" diff --git a/sqlspec/adapters/duckdb/config.py b/sqlspec/adapters/duckdb/config.py index 3767a3f2..6591d085 100644 --- a/sqlspec/adapters/duckdb/config.py +++ b/sqlspec/adapters/duckdb/config.py @@ -178,6 +178,11 @@ class DuckDBConfig(SyncDatabaseConfig[DuckDBConnection, DuckDBConnectionPool, Du driver_type: "ClassVar[type[DuckDBDriver]]" = DuckDBDriver connection_type: "ClassVar[type[DuckDBConnection]]" = DuckDBConnection supports_transactional_ddl: "ClassVar[bool]" = True + supports_native_arrow_export: "ClassVar[bool]" = True + supports_native_arrow_import: "ClassVar[bool]" = True + supports_native_parquet_export: "ClassVar[bool]" = True + supports_native_parquet_import: "ClassVar[bool]" = True + storage_partition_strategies: "ClassVar[tuple[str, ...]]" = ("fixed", "rows_per_chunk", "manifest") def __init__( self, diff --git a/sqlspec/adapters/duckdb/driver.py b/sqlspec/adapters/duckdb/driver.py index 39b467f6..f93d22c7 100644 --- a/sqlspec/adapters/duckdb/driver.py +++ b/sqlspec/adapters/duckdb/driver.py @@ -1,9 +1,11 @@ """DuckDB driver implementation.""" +import contextlib import typing from datetime import date, datetime from decimal import Decimal -from typing import TYPE_CHECKING, Any, Final +from typing import TYPE_CHECKING, Any, Final, cast +from uuid import uuid4 import duckdb @@ -44,6 +46,13 @@ from sqlspec.core import ArrowResult, SQLResult, Statement, StatementFilter from sqlspec.driver import ExecutionResult from sqlspec.driver._sync import SyncDataDictionaryBase + from sqlspec.storage import ( + StorageBridgeJob, + StorageDestination, + StorageFormat, + StorageTelemetry, + SyncStoragePipeline, + ) from sqlspec.typing import ArrowReturnFormat, StatementParameters __all__ = ( @@ -477,6 +486,72 @@ def select_to_arrow( # Create ArrowResult return create_arrow_result(statement=prepared_statement, data=arrow_data, rows_affected=arrow_data.num_rows) + def select_to_storage( + self, + statement: "Statement | QueryBuilder | SQL | str", + destination: "StorageDestination", + /, + *parameters: "StatementParameters | StatementFilter", + statement_config: "StatementConfig | None" = None, + partitioner: "dict[str, Any] | None" = None, + format_hint: "StorageFormat | None" = None, + telemetry: "StorageTelemetry | None" = None, + **kwargs: Any, + ) -> "StorageBridgeJob": + """Persist DuckDB query output to a storage backend using Arrow fast paths.""" + + _ = kwargs + self._require_capability("arrow_export_enabled") + arrow_result = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) + sync_pipeline: SyncStoragePipeline = cast("SyncStoragePipeline", self._storage_pipeline()) + telemetry_payload = arrow_result.write_to_storage_sync( + destination, format_hint=format_hint, pipeline=sync_pipeline + ) + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload, telemetry) + + def load_from_arrow( + self, + table: str, + source: "ArrowResult | Any", + *, + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + telemetry: "StorageTelemetry | None" = None, + ) -> "StorageBridgeJob": + """Load Arrow data into DuckDB using temporary table registration.""" + + 
self._require_capability("arrow_import_enabled") + arrow_table = self._coerce_arrow_table(source) + temp_view = f"_sqlspec_arrow_{uuid4().hex}" + if overwrite: + self.connection.execute(f"TRUNCATE TABLE {table}") + self.connection.register(temp_view, arrow_table) + try: + self.connection.execute(f"INSERT INTO {table} SELECT * FROM {temp_view}") + finally: + with contextlib.suppress(Exception): + self.connection.unregister(temp_view) + + telemetry_payload = self._build_ingest_telemetry(arrow_table) + telemetry_payload["destination"] = table + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload, telemetry) + + def load_from_storage( + self, + table: str, + source: "StorageDestination", + *, + file_format: "StorageFormat", + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + ) -> "StorageBridgeJob": + """Read an artifact from storage and load it into DuckDB.""" + + arrow_table, inbound = self._read_arrow_from_storage_sync(source, file_format=file_format) + return self.load_from_arrow(table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound) + def _bool_to_int(value: bool) -> int: return int(value) diff --git a/sqlspec/adapters/oracledb/config.py b/sqlspec/adapters/oracledb/config.py index 95f75cbe..2b31fd68 100644 --- a/sqlspec/adapters/oracledb/config.py +++ b/sqlspec/adapters/oracledb/config.py @@ -119,6 +119,10 @@ class OracleSyncConfig(SyncDatabaseConfig[OracleSyncConnection, "OracleSyncConne connection_type: "ClassVar[type[OracleSyncConnection]]" = OracleSyncConnection migration_tracker_type: "ClassVar[type[OracleSyncMigrationTracker]]" = OracleSyncMigrationTracker supports_transactional_ddl: ClassVar[bool] = False + supports_native_arrow_export: ClassVar[bool] = True + supports_native_arrow_import: ClassVar[bool] = True + supports_native_parquet_export: ClassVar[bool] = True + supports_native_parquet_import: ClassVar[bool] = True def __init__( self, @@ -277,6 +281,10 @@ class OracleAsyncConfig(AsyncDatabaseConfig[OracleAsyncConnection, "OracleAsyncC driver_type: ClassVar[type[OracleAsyncDriver]] = OracleAsyncDriver migration_tracker_type: "ClassVar[type[OracleAsyncMigrationTracker]]" = OracleAsyncMigrationTracker supports_transactional_ddl: ClassVar[bool] = False + supports_native_arrow_export: ClassVar[bool] = True + supports_native_arrow_import: ClassVar[bool] = True + supports_native_parquet_export: ClassVar[bool] = True + supports_native_parquet_import: ClassVar[bool] = True def __init__( self, diff --git a/sqlspec/adapters/oracledb/driver.py b/sqlspec/adapters/oracledb/driver.py index 031f822d..5de971fd 100644 --- a/sqlspec/adapters/oracledb/driver.py +++ b/sqlspec/adapters/oracledb/driver.py @@ -3,7 +3,7 @@ import contextlib import logging import re -from typing import TYPE_CHECKING, Any, Final +from typing import TYPE_CHECKING, Any, Final, cast import oracledb from oracledb import AsyncCursor, Cursor @@ -47,8 +47,16 @@ from contextlib import AbstractAsyncContextManager, AbstractContextManager from sqlspec.builder import QueryBuilder - from sqlspec.core import SQLResult, Statement, StatementFilter + from sqlspec.core import ArrowResult, SQLResult, Statement, StatementFilter from sqlspec.driver import ExecutionResult + from sqlspec.storage import ( + AsyncStoragePipeline, + StorageBridgeJob, + StorageDestination, + StorageFormat, + StorageTelemetry, + SyncStoragePipeline, + ) from sqlspec.typing import ArrowReturnFormat, StatementParameters logger = 
logging.getLogger(__name__) @@ -83,6 +91,16 @@ def _normalize_column_names(column_names: "list[str]", driver_features: "dict[st return normalized +def _oracle_insert_statement(table: str, columns: "list[str]") -> str: + column_list = ", ".join(columns) + placeholders = ", ".join(f":{idx + 1}" for idx in range(len(columns))) + return f"INSERT INTO {table} ({column_list}) VALUES ({placeholders})" + + +def _oracle_truncate_statement(table: str) -> str: + return f"TRUNCATE TABLE {table}" + + def _coerce_sync_row_values(row: "tuple[Any, ...]") -> "list[Any]": """Coerce LOB handles to concrete values for synchronous execution. @@ -471,6 +489,68 @@ def _execute_statement(self, cursor: Any, statement: "SQL") -> "ExecutionResult" affected_rows = cursor.rowcount if cursor.rowcount is not None else 0 return self.create_execution_result(cursor, rowcount_override=affected_rows) + def select_to_storage( + self, + statement: "Statement | QueryBuilder | SQL | str", + destination: "StorageDestination", + /, + *parameters: "StatementParameters | StatementFilter", + statement_config: "StatementConfig | None" = None, + partitioner: "dict[str, Any] | None" = None, + format_hint: "StorageFormat | None" = None, + telemetry: "StorageTelemetry | None" = None, + **kwargs: Any, + ) -> "StorageBridgeJob": + """Execute a query and stream Arrow-formatted output to storage (sync).""" + + self._require_capability("arrow_export_enabled") + arrow_result = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) + sync_pipeline: SyncStoragePipeline = cast("SyncStoragePipeline", self._storage_pipeline()) + telemetry_payload = arrow_result.write_to_storage_sync( + destination, format_hint=format_hint, pipeline=sync_pipeline + ) + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload, telemetry) + + def load_from_arrow( + self, + table: str, + source: "ArrowResult | Any", + *, + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + telemetry: "StorageTelemetry | None" = None, + ) -> "StorageBridgeJob": + """Load Arrow data into Oracle using batched executemany calls.""" + + self._require_capability("arrow_import_enabled") + arrow_table = self._coerce_arrow_table(source) + if overwrite: + self._truncate_table_sync(table) + columns, records = self._arrow_table_to_rows(arrow_table) + if records: + statement = _oracle_insert_statement(table, columns) + with self.with_cursor(self.connection) as cursor, self.handle_database_exceptions(): + cursor.executemany(statement, records) + telemetry_payload = self._build_ingest_telemetry(arrow_table) + telemetry_payload["destination"] = table + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload, telemetry) + + def load_from_storage( + self, + table: str, + source: "StorageDestination", + *, + file_format: "StorageFormat", + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + ) -> "StorageBridgeJob": + """Load staged artifacts into Oracle.""" + + arrow_table, inbound = self._read_arrow_from_storage_sync(source, file_format=file_format) + return self.load_from_arrow(table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound) + # Oracle transaction management def begin(self) -> None: """Begin a database transaction. 
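A minimal usage sketch for the synchronous Oracle storage hooks above (not part of the patch): the driver instance, table names, and the local path are placeholders, and the Parquet round trip assumes pyarrow is installed.

from sqlspec.adapters.oracledb.driver import OracleSyncDriver

def archive_orders(driver: OracleSyncDriver) -> None:
    # Export a query to Parquet, then reload the staged file into another table.
    export_job = driver.select_to_storage(
        "SELECT * FROM orders", "/tmp/orders.parquet", format_hint="parquet"
    )
    print(export_job.telemetry.get("rows_processed", 0))
    driver.load_from_storage(
        "orders_archive", "/tmp/orders.parquet", file_format="parquet", overwrite=True
    )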
@@ -602,6 +682,11 @@ def data_dictionary(self) -> "SyncDataDictionaryBase": self._data_dictionary = OracleSyncDataDictionary() return self._data_dictionary + def _truncate_table_sync(self, table: str) -> None: + statement = _oracle_truncate_statement(table) + with self.handle_database_exceptions(): + self.connection.execute(statement) + class OracleAsyncDriver(AsyncDriverAdapterBase): """Asynchronous Oracle Database driver. @@ -757,6 +842,70 @@ async def _execute_statement(self, cursor: Any, statement: "SQL") -> "ExecutionR affected_rows = cursor.rowcount if cursor.rowcount is not None else 0 return self.create_execution_result(cursor, rowcount_override=affected_rows) + async def select_to_storage( + self, + statement: "Statement | QueryBuilder | SQL | str", + destination: "StorageDestination", + /, + *parameters: "StatementParameters | StatementFilter", + statement_config: "StatementConfig | None" = None, + partitioner: "dict[str, Any] | None" = None, + format_hint: "StorageFormat | None" = None, + telemetry: "StorageTelemetry | None" = None, + **kwargs: Any, + ) -> "StorageBridgeJob": + """Execute a query and write Arrow-compatible output to storage (async).""" + + self._require_capability("arrow_export_enabled") + arrow_result = await self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) + async_pipeline: AsyncStoragePipeline = cast("AsyncStoragePipeline", self._storage_pipeline()) + telemetry_payload = await arrow_result.write_to_storage_async( + destination, format_hint=format_hint, pipeline=async_pipeline + ) + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload, telemetry) + + async def load_from_arrow( + self, + table: str, + source: "ArrowResult | Any", + *, + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + telemetry: "StorageTelemetry | None" = None, + ) -> "StorageBridgeJob": + """Asynchronously load Arrow data into Oracle.""" + + self._require_capability("arrow_import_enabled") + arrow_table = self._coerce_arrow_table(source) + if overwrite: + await self._truncate_table_async(table) + columns, records = self._arrow_table_to_rows(arrow_table) + if records: + statement = _oracle_insert_statement(table, columns) + async with self.with_cursor(self.connection) as cursor, self.handle_database_exceptions(): + await cursor.executemany(statement, records) + telemetry_payload = self._build_ingest_telemetry(arrow_table) + telemetry_payload["destination"] = table + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload, telemetry) + + async def load_from_storage( + self, + table: str, + source: "StorageDestination", + *, + file_format: "StorageFormat", + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + ) -> "StorageBridgeJob": + """Asynchronously load staged artifacts into Oracle.""" + + arrow_table, inbound = await self._read_arrow_from_storage_async(source, file_format=file_format) + return await self.load_from_arrow( + table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound + ) + # Oracle transaction management async def begin(self) -> None: """Begin a database transaction. 
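The asynchronous driver mirrors the same surface; a hypothetical sketch feeding an in-memory pyarrow.Table through load_from_arrow (the table layout and target name are invented):

import pyarrow as pa

from sqlspec.adapters.oracledb.driver import OracleAsyncDriver

async def load_events(driver: OracleAsyncDriver) -> None:
    table = pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
    job = await driver.load_from_arrow("events", table)
    # rows_processed and bytes_processed come from the ingest telemetry payload.
    print(job.status, job.telemetry.get("rows_processed"))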
@@ -888,6 +1037,11 @@ def data_dictionary(self) -> "AsyncDataDictionaryBase": self._data_dictionary = OracleAsyncDataDictionary() return self._data_dictionary + async def _truncate_table_async(self, table: str) -> None: + statement = _oracle_truncate_statement(table) + async with self.handle_database_exceptions(): + await self.connection.execute(statement) + def _build_oracledb_profile() -> DriverParameterProfile: """Create the OracleDB driver parameter profile.""" diff --git a/sqlspec/adapters/psqlpy/config.py b/sqlspec/adapters/psqlpy/config.py index 2e692971..aec2b6e5 100644 --- a/sqlspec/adapters/psqlpy/config.py +++ b/sqlspec/adapters/psqlpy/config.py @@ -100,6 +100,10 @@ class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, ConnectionPool, PsqlpyD driver_type: ClassVar[type[PsqlpyDriver]] = PsqlpyDriver connection_type: "ClassVar[type[PsqlpyConnection]]" = PsqlpyConnection supports_transactional_ddl: "ClassVar[bool]" = True + supports_native_arrow_export: ClassVar[bool] = True + supports_native_arrow_import: ClassVar[bool] = True + supports_native_parquet_export: ClassVar[bool] = True + supports_native_parquet_import: ClassVar[bool] = True def __init__( self, diff --git a/sqlspec/adapters/psqlpy/driver.py b/sqlspec/adapters/psqlpy/driver.py index 553915c6..1254a6f0 100644 --- a/sqlspec/adapters/psqlpy/driver.py +++ b/sqlspec/adapters/psqlpy/driver.py @@ -6,9 +6,11 @@ import datetime import decimal +import inspect +import io import re import uuid -from typing import TYPE_CHECKING, Any, Final +from typing import TYPE_CHECKING, Any, Final, cast import psqlpy.exceptions from psqlpy.extra_types import JSONB @@ -49,9 +51,16 @@ from contextlib import AbstractAsyncContextManager from sqlspec.adapters.psqlpy._types import PsqlpyConnection - from sqlspec.core import SQLResult + from sqlspec.core import ArrowResult, SQLResult from sqlspec.driver import ExecutionResult from sqlspec.driver._async import AsyncDataDictionaryBase + from sqlspec.storage import ( + AsyncStoragePipeline, + StorageBridgeJob, + StorageDestination, + StorageFormat, + StorageTelemetry, + ) __all__ = ( "PsqlpyCursor", @@ -487,6 +496,86 @@ def _parse_command_tag(self, tag: str) -> int: return int(match.group(3)) return -1 + async def select_to_storage( + self, + statement: "SQL | str", + destination: "StorageDestination", + /, + *parameters: Any, + statement_config: "StatementConfig | None" = None, + partitioner: "dict[str, Any] | None" = None, + format_hint: "StorageFormat | None" = None, + telemetry: "StorageTelemetry | None" = None, + **kwargs: Any, + ) -> "StorageBridgeJob": + """Execute a query and stream Arrow results to a storage backend.""" + + self._require_capability("arrow_export_enabled") + arrow_result = await self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) + async_pipeline: AsyncStoragePipeline = cast("AsyncStoragePipeline", self._storage_pipeline()) + telemetry_payload = await arrow_result.write_to_storage_async( + destination, format_hint=format_hint, pipeline=async_pipeline + ) + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload, telemetry) + + async def load_from_arrow( + self, + table: str, + source: "ArrowResult | Any", + *, + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + telemetry: "StorageTelemetry | None" = None, + ) -> "StorageBridgeJob": + """Load Arrow-formatted data into PostgreSQL via psqlpy binary COPY.""" + + self._require_capability("arrow_import_enabled") 
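+        # Strategy: attempt psqlpy's binary_copy_to_table with an encoded COPY
+        # payload first; if the call signature or the payload is rejected, fall
+        # back to a generated INSERT executed through execute_many with coerced
+        # row values.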
+ arrow_table = self._coerce_arrow_table(source) + if overwrite: + await self._truncate_table_async(table) + + columns, records = self._arrow_table_to_rows(arrow_table) + if records: + schema_name, table_name = _split_schema_and_table(table) + async with self.handle_database_exceptions(), self.with_cursor(self.connection) as cursor: + copy_kwargs: dict[str, Any] = {"columns": columns} + if schema_name: + copy_kwargs["schema_name"] = schema_name + try: + copy_payload = _encode_records_for_binary_copy(records) + copy_operation = cursor.binary_copy_to_table(copy_payload, table_name, **copy_kwargs) + if inspect.isawaitable(copy_operation): + await copy_operation + except (TypeError, psqlpy.exceptions.DatabaseError) as exc: + logger.debug("Binary COPY not available for psqlpy; falling back to INSERT statements: %s", exc) + insert_sql = _build_psqlpy_insert_statement(table, columns) + formatted_records = _coerce_records_for_execute_many(records) + insert_operation = cursor.execute_many(insert_sql, formatted_records) + if inspect.isawaitable(insert_operation): + await insert_operation + + telemetry_payload = self._build_ingest_telemetry(arrow_table) + telemetry_payload["destination"] = table + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload, telemetry) + + async def load_from_storage( + self, + table: str, + source: "StorageDestination", + *, + file_format: "StorageFormat", + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + ) -> "StorageBridgeJob": + """Load staged artifacts from storage using the storage bridge pipeline.""" + + arrow_table, inbound = await self._read_arrow_from_storage_async(source, file_format=file_format) + return await self.load_from_arrow( + table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound + ) + async def begin(self) -> None: """Begin a database transaction.""" try: @@ -511,6 +600,11 @@ async def commit(self) -> None: msg = f"Failed to commit psqlpy transaction: {e}" raise SQLSpecError(msg) from e + async def _truncate_table_async(self, table: str) -> None: + qualified = _format_table_identifier(table) + async with self.handle_database_exceptions(), self.with_cursor(self.connection) as cursor: + await cursor.execute(f"TRUNCATE TABLE {qualified}") + @property def data_dictionary(self) -> "AsyncDataDictionaryBase": """Get the data dictionary for this driver. @@ -662,6 +756,91 @@ def _coerce_numeric_for_write(value: Any) -> Any: return value +def _escape_copy_text(value: str) -> str: + return value.replace("\\", "\\\\").replace("\t", "\\t").replace("\n", "\\n").replace("\r", "\\r") + + +def _format_copy_value(value: Any) -> str: + if value is None: + return r"\N" + if isinstance(value, bool): + return "t" if value else "f" + if isinstance(value, (datetime.date, datetime.datetime, datetime.time)): + return value.isoformat() + if isinstance(value, (list, tuple, dict)): + return to_json(value) + if isinstance(value, (bytes, bytearray)): + return value.decode("utf-8") + return str(_coerce_numeric_for_write(value)) + + +def _encode_records_for_binary_copy(records: "list[tuple[Any, ...]]") -> bytes: + """Encode row tuples into a bytes payload compatible with binary_copy_to_table. + + Args: + records: Sequence of row tuples extracted from the Arrow table. + + Returns: + UTF-8 encoded bytes buffer representing the COPY payload. 
+ """ + + buffer = io.StringIO() + for record in records: + encoded_columns = [_escape_copy_text(_format_copy_value(value)) for value in record] + buffer.write("\t".join(encoded_columns)) + buffer.write("\n") + return buffer.getvalue().encode("utf-8") + + +def _split_schema_and_table(identifier: str) -> "tuple[str | None, str]": + cleaned = identifier.strip() + if not cleaned: + msg = "Table name must not be empty" + raise SQLSpecError(msg) + if "." not in cleaned: + return None, cleaned.strip('"') + parts = [part for part in cleaned.split(".") if part] + if len(parts) == 1: + return None, parts[0].strip('"') + schema_name = ".".join(parts[:-1]).strip('"') + table_name = parts[-1].strip('"') + if not table_name: + msg = "Table name must not be empty" + raise SQLSpecError(msg) + return schema_name or None, table_name + + +def _quote_identifier(identifier: str) -> str: + normalized = identifier.replace('"', '""') + return f'"{normalized}"' + + +def _format_table_identifier(identifier: str) -> str: + schema_name, table_name = _split_schema_and_table(identifier) + if schema_name: + return f"{_quote_identifier(schema_name)}.{_quote_identifier(table_name)}" + return _quote_identifier(table_name) + + +def _build_psqlpy_insert_statement(table: str, columns: "list[str]") -> str: + column_clause = ", ".join(_quote_identifier(column) for column in columns) + placeholders = ", ".join(f"${index}" for index in range(1, len(columns) + 1)) + return f"INSERT INTO {_format_table_identifier(table)} ({column_clause}) VALUES ({placeholders})" + + +def _coerce_records_for_execute_many(records: "list[tuple[Any, ...]]") -> "list[list[Any]]": + formatted_records: list[list[Any]] = [] + for record in records: + coerced = _coerce_numeric_for_write(record) + if isinstance(coerced, tuple): + formatted_records.append(list(coerced)) + elif isinstance(coerced, list): + formatted_records.append(coerced) + else: + formatted_records.append([coerced]) + return formatted_records + + def _build_psqlpy_profile() -> DriverParameterProfile: """Create the psqlpy driver parameter profile.""" diff --git a/sqlspec/adapters/psycopg/config.py b/sqlspec/adapters/psycopg/config.py index 01cfb184..7ec1d0e7 100644 --- a/sqlspec/adapters/psycopg/config.py +++ b/sqlspec/adapters/psycopg/config.py @@ -110,6 +110,10 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool driver_type: "ClassVar[type[PsycopgSyncDriver]]" = PsycopgSyncDriver connection_type: "ClassVar[type[PsycopgSyncConnection]]" = PsycopgSyncConnection supports_transactional_ddl: "ClassVar[bool]" = True + supports_native_arrow_export: "ClassVar[bool]" = True + supports_native_arrow_import: "ClassVar[bool]" = True + supports_native_parquet_export: "ClassVar[bool]" = True + supports_native_parquet_import: "ClassVar[bool]" = True def __init__( self, @@ -301,6 +305,10 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec driver_type: ClassVar[type[PsycopgAsyncDriver]] = PsycopgAsyncDriver connection_type: "ClassVar[type[PsycopgAsyncConnection]]" = PsycopgAsyncConnection supports_transactional_ddl: "ClassVar[bool]" = True + supports_native_arrow_export: ClassVar[bool] = True + supports_native_arrow_import: ClassVar[bool] = True + supports_native_parquet_export: ClassVar[bool] = True + supports_native_parquet_import: ClassVar[bool] = True def __init__( self, diff --git a/sqlspec/adapters/psycopg/driver.py b/sqlspec/adapters/psycopg/driver.py index 434c11ef..ee867409 100644 --- a/sqlspec/adapters/psycopg/driver.py +++ 
b/sqlspec/adapters/psycopg/driver.py @@ -16,9 +16,11 @@ import datetime import io -from typing import TYPE_CHECKING, Any +from contextlib import AsyncExitStack, ExitStack +from typing import TYPE_CHECKING, Any, cast import psycopg +from psycopg import sql as psycopg_sql from sqlspec.adapters.psycopg._types import PsycopgAsyncConnection, PsycopgSyncConnection from sqlspec.core import ( @@ -57,9 +59,18 @@ from collections.abc import Callable from contextlib import AbstractAsyncContextManager, AbstractContextManager + from sqlspec.core import ArrowResult from sqlspec.driver._async import AsyncDataDictionaryBase from sqlspec.driver._common import ExecutionResult from sqlspec.driver._sync import SyncDataDictionaryBase + from sqlspec.storage import ( + AsyncStoragePipeline, + StorageBridgeJob, + StorageDestination, + StorageFormat, + StorageTelemetry, + SyncStoragePipeline, + ) __all__ = ( "PsycopgAsyncCursor", @@ -82,6 +93,25 @@ TRANSACTION_STATUS_UNKNOWN = 4 +def _compose_table_identifier(table: str) -> "psycopg_sql.Composed": + parts = [part for part in table.split(".") if part] + if not parts: + msg = "Table name must not be empty" + raise SQLSpecError(msg) + identifiers = [psycopg_sql.Identifier(part) for part in parts] + return psycopg_sql.SQL(".").join(identifiers) + + +def _build_copy_from_command(table: str, columns: "list[str]") -> "psycopg_sql.Composed": + table_identifier = _compose_table_identifier(table) + column_sql = psycopg_sql.SQL(", ").join(psycopg_sql.Identifier(column) for column in columns) + return psycopg_sql.SQL("COPY {} ({}) FROM STDIN").format(table_identifier, column_sql) + + +def _build_truncate_command(table: str) -> "psycopg_sql.Composed": + return psycopg_sql.SQL("TRUNCATE TABLE {}").format(_compose_table_identifier(table)) + + class PsycopgSyncCursor: """Context manager for PostgreSQL psycopg cursor management.""" @@ -436,6 +466,72 @@ def _execute_statement(self, cursor: Any, statement: "SQL") -> "ExecutionResult" affected_rows = cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 return self.create_execution_result(cursor, rowcount_override=affected_rows) + def select_to_storage( + self, + statement: "SQL | str", + destination: "StorageDestination", + /, + *parameters: Any, + statement_config: "StatementConfig | None" = None, + partitioner: "dict[str, Any] | None" = None, + format_hint: "StorageFormat | None" = None, + telemetry: "StorageTelemetry | None" = None, + **kwargs: Any, + ) -> "StorageBridgeJob": + """Execute a query and stream Arrow results to storage (sync).""" + + self._require_capability("arrow_export_enabled") + arrow_result = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) + sync_pipeline: SyncStoragePipeline = cast("SyncStoragePipeline", self._storage_pipeline()) + telemetry_payload = arrow_result.write_to_storage_sync( + destination, format_hint=format_hint, pipeline=sync_pipeline + ) + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload, telemetry) + + def load_from_arrow( + self, + table: str, + source: "ArrowResult | Any", + *, + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + telemetry: "StorageTelemetry | None" = None, + ) -> "StorageBridgeJob": + """Load Arrow data into PostgreSQL using COPY.""" + + self._require_capability("arrow_import_enabled") + arrow_table = self._coerce_arrow_table(source) + if overwrite: + self._truncate_table_sync(table) + columns, records = 
self._arrow_table_to_rows(arrow_table) + if records: + copy_sql = _build_copy_from_command(table, columns) + with ExitStack() as stack: + stack.enter_context(self.handle_database_exceptions()) + cursor = stack.enter_context(self.with_cursor(self.connection)) + copy_ctx = stack.enter_context(cursor.copy(copy_sql)) + for record in records: + copy_ctx.write_row(record) + telemetry_payload = self._build_ingest_telemetry(arrow_table) + telemetry_payload["destination"] = table + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload, telemetry) + + def load_from_storage( + self, + table: str, + source: "StorageDestination", + *, + file_format: "StorageFormat", + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + ) -> "StorageBridgeJob": + """Load staged artifacts into PostgreSQL via COPY.""" + + arrow_table, inbound = self._read_arrow_from_storage_sync(source, file_format=file_format) + return self.load_from_arrow(table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound) + @property def data_dictionary(self) -> "SyncDataDictionaryBase": """Get the data dictionary for this driver. @@ -449,6 +545,11 @@ def data_dictionary(self) -> "SyncDataDictionaryBase": self._data_dictionary = PostgresSyncDataDictionary() return self._data_dictionary + def _truncate_table_sync(self, table: str) -> None: + truncate_sql = _build_truncate_command(table) + with self.with_cursor(self.connection) as cursor, self.handle_database_exceptions(): + cursor.execute(truncate_sql) + class PsycopgAsyncCursor: """Async context manager for PostgreSQL psycopg cursor management.""" @@ -807,6 +908,74 @@ async def _execute_statement(self, cursor: Any, statement: "SQL") -> "ExecutionR affected_rows = cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 return self.create_execution_result(cursor, rowcount_override=affected_rows) + async def select_to_storage( + self, + statement: "SQL | str", + destination: "StorageDestination", + /, + *parameters: Any, + statement_config: "StatementConfig | None" = None, + partitioner: "dict[str, Any] | None" = None, + format_hint: "StorageFormat | None" = None, + telemetry: "StorageTelemetry | None" = None, + **kwargs: Any, + ) -> "StorageBridgeJob": + """Execute a query and stream Arrow data to storage asynchronously.""" + + self._require_capability("arrow_export_enabled") + arrow_result = await self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) + async_pipeline: AsyncStoragePipeline = cast("AsyncStoragePipeline", self._storage_pipeline()) + telemetry_payload = await arrow_result.write_to_storage_async( + destination, format_hint=format_hint, pipeline=async_pipeline + ) + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload, telemetry) + + async def load_from_arrow( + self, + table: str, + source: "ArrowResult | Any", + *, + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + telemetry: "StorageTelemetry | None" = None, + ) -> "StorageBridgeJob": + """Load Arrow data into PostgreSQL asynchronously via COPY.""" + + self._require_capability("arrow_import_enabled") + arrow_table = self._coerce_arrow_table(source) + if overwrite: + await self._truncate_table_async(table) + columns, records = self._arrow_table_to_rows(arrow_table) + if records: + copy_sql = _build_copy_from_command(table, columns) + async with AsyncExitStack() as stack: + await 
stack.enter_async_context(self.handle_database_exceptions()) + cursor = await stack.enter_async_context(self.with_cursor(self.connection)) + copy_ctx = await stack.enter_async_context(cursor.copy(copy_sql)) + for record in records: + await copy_ctx.write_row(record) + telemetry_payload = self._build_ingest_telemetry(arrow_table) + telemetry_payload["destination"] = table + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload, telemetry) + + async def load_from_storage( + self, + table: str, + source: "StorageDestination", + *, + file_format: "StorageFormat", + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + ) -> "StorageBridgeJob": + """Load staged artifacts asynchronously.""" + + arrow_table, inbound = await self._read_arrow_from_storage_async(source, file_format=file_format) + return await self.load_from_arrow( + table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound + ) + @property def data_dictionary(self) -> "AsyncDataDictionaryBase": """Get the data dictionary for this driver. @@ -820,6 +989,11 @@ def data_dictionary(self) -> "AsyncDataDictionaryBase": self._data_dictionary = PostgresAsyncDataDictionary() return self._data_dictionary + async def _truncate_table_async(self, table: str) -> None: + truncate_sql = _build_truncate_command(table) + async with self.with_cursor(self.connection) as cursor, self.handle_database_exceptions(): + await cursor.execute(truncate_sql) + def _identity(value: Any) -> Any: return value @@ -844,11 +1018,7 @@ def _build_psycopg_profile() -> DriverParameterProfile: ParameterStyle.QMARK, }, default_execution_style=ParameterStyle.POSITIONAL_PYFORMAT, - supported_execution_styles={ - ParameterStyle.POSITIONAL_PYFORMAT, - ParameterStyle.NAMED_PYFORMAT, - ParameterStyle.NUMERIC, - }, + supported_execution_styles={ParameterStyle.POSITIONAL_PYFORMAT, ParameterStyle.NAMED_PYFORMAT}, has_native_list_expansion=True, preserve_parameter_format=True, needs_static_script_compilation=False, diff --git a/sqlspec/adapters/sqlite/config.py b/sqlspec/adapters/sqlite/config.py index 177ac529..07e3d10f 100644 --- a/sqlspec/adapters/sqlite/config.py +++ b/sqlspec/adapters/sqlite/config.py @@ -63,6 +63,10 @@ class SqliteConfig(SyncDatabaseConfig[SqliteConnection, SqliteConnectionPool, Sq driver_type: "ClassVar[type[SqliteDriver]]" = SqliteDriver connection_type: "ClassVar[type[SqliteConnection]]" = SqliteConnection supports_transactional_ddl: "ClassVar[bool]" = True + supports_native_arrow_export: "ClassVar[bool]" = True + supports_native_arrow_import: "ClassVar[bool]" = True + supports_native_parquet_export: "ClassVar[bool]" = True + supports_native_parquet_import: "ClassVar[bool]" = True def __init__( self, diff --git a/sqlspec/adapters/sqlite/driver.py b/sqlspec/adapters/sqlite/driver.py index 28481b54..31216e89 100644 --- a/sqlspec/adapters/sqlite/driver.py +++ b/sqlspec/adapters/sqlite/driver.py @@ -4,9 +4,10 @@ import sqlite3 from datetime import date, datetime from decimal import Decimal -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from sqlspec.core import ( + ArrowResult, DriverParameterProfile, ParameterStyle, build_statement_config_from_profile, @@ -36,6 +37,13 @@ from sqlspec.core import SQL, SQLResult, StatementConfig from sqlspec.driver import ExecutionResult from sqlspec.driver._sync import SyncDataDictionaryBase + from sqlspec.storage import ( + StorageBridgeJob, + StorageDestination, + StorageFormat, + 
StorageTelemetry, + SyncStoragePipeline, + ) __all__ = ("SqliteCursor", "SqliteDriver", "SqliteExceptionHandler", "sqlite_statement_config") @@ -340,6 +348,70 @@ def _execute_statement(self, cursor: "sqlite3.Cursor", statement: "SQL") -> "Exe affected_rows = cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 return self.create_execution_result(cursor, rowcount_override=affected_rows) + def select_to_storage( + self, + statement: "SQL | str", + destination: "StorageDestination", + /, + *parameters: Any, + statement_config: "StatementConfig | None" = None, + partitioner: "dict[str, Any] | None" = None, + format_hint: "StorageFormat | None" = None, + telemetry: "StorageTelemetry | None" = None, + **kwargs: Any, + ) -> "StorageBridgeJob": + """Execute a query and write Arrow-compatible output to storage (sync).""" + + self._require_capability("arrow_export_enabled") + arrow_result = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) + sync_pipeline: SyncStoragePipeline = cast("SyncStoragePipeline", self._storage_pipeline()) + telemetry_payload = arrow_result.write_to_storage_sync( + destination, format_hint=format_hint, pipeline=sync_pipeline + ) + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload, telemetry) + + def load_from_arrow( + self, + table: str, + source: "ArrowResult | Any", + *, + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + telemetry: "StorageTelemetry | None" = None, + ) -> "StorageBridgeJob": + """Load Arrow data into SQLite using batched inserts.""" + + self._require_capability("arrow_import_enabled") + arrow_table = self._coerce_arrow_table(source) + if overwrite: + self._truncate_table_sync(table) + + columns, records = self._arrow_table_to_rows(arrow_table) + if records: + insert_sql = _build_sqlite_insert_statement(table, columns) + with self.handle_database_exceptions(), self.with_cursor(self.connection) as cursor: + cursor.executemany(insert_sql, records) + + telemetry_payload = self._build_ingest_telemetry(arrow_table) + telemetry_payload["destination"] = table + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload, telemetry) + + def load_from_storage( + self, + table: str, + source: "StorageDestination", + *, + file_format: "StorageFormat", + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + ) -> "StorageBridgeJob": + """Load staged artifacts from storage into SQLite.""" + + arrow_table, inbound = self._read_arrow_from_storage_sync(source, file_format=file_format) + return self.load_from_arrow(table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound) + def begin(self) -> None: """Begin a database transaction. @@ -365,6 +437,11 @@ def rollback(self) -> None: msg = f"Failed to rollback transaction: {e}" raise SQLSpecError(msg) from e + def _truncate_table_sync(self, table: str) -> None: + statement = f"DELETE FROM {_format_sqlite_identifier(table)}" + with self.handle_database_exceptions(), self.with_cursor(self.connection) as cursor: + cursor.execute(statement) + def commit(self) -> None: """Commit the current transaction. 
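SQLite implements the same surface with batched executemany inserts, and overwrite clears the target with DELETE FROM since SQLite has no TRUNCATE. A hypothetical round trip through an Arrow IPC file (paths and table names invented):

from sqlspec.adapters.sqlite.driver import SqliteDriver

def snapshot_users(driver: SqliteDriver) -> None:
    driver.select_to_storage("SELECT * FROM users", "/tmp/users.arrow", format_hint="arrow-ipc")
    driver.load_from_storage("users_copy", "/tmp/users.arrow", file_format="arrow-ipc", overwrite=True)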
@@ -395,6 +472,27 @@ def _bool_to_int(value: bool) -> int: return int(value) +def _quote_sqlite_identifier(identifier: str) -> str: + normalized = identifier.replace('"', '""') + return f'"{normalized}"' + + +def _format_sqlite_identifier(identifier: str) -> str: + cleaned = identifier.strip() + if not cleaned: + msg = "Table name must not be empty" + raise SQLSpecError(msg) + parts = [part for part in cleaned.split(".") if part] + formatted = ".".join(_quote_sqlite_identifier(part) for part in parts) + return formatted or _quote_sqlite_identifier(cleaned) + + +def _build_sqlite_insert_statement(table: str, columns: "list[str]") -> str: + column_clause = ", ".join(_quote_sqlite_identifier(column) for column in columns) + placeholders = ", ".join("?" for _ in columns) + return f"INSERT INTO {_format_sqlite_identifier(table)} ({column_clause}) VALUES ({placeholders})" + + def _build_sqlite_profile() -> DriverParameterProfile: """Create the SQLite driver parameter profile.""" diff --git a/sqlspec/config.py b/sqlspec/config.py index 7a3f9ecd..acb9f2f6 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -6,8 +6,10 @@ from typing_extensions import NotRequired, TypedDict from sqlspec.core import ParameterStyle, ParameterStyleConfig, StatementConfig +from sqlspec.exceptions import MissingDependencyError from sqlspec.migrations.tracker import AsyncMigrationTracker, SyncMigrationTracker from sqlspec.utils.logging import get_logger +from sqlspec.utils.module_loader import ensure_pyarrow if TYPE_CHECKING: from collections.abc import Awaitable @@ -16,6 +18,7 @@ from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase from sqlspec.loader import SQLFileLoader from sqlspec.migrations.commands import AsyncMigrationCommands, SyncMigrationCommands + from sqlspec.storage import StorageCapabilities __all__ = ( @@ -389,6 +392,7 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]): __slots__ = ( "_migration_commands", "_migration_loader", + "_storage_capabilities", "bind_key", "driver_features", "migration_config", @@ -407,11 +411,16 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]): supports_native_arrow_export: "ClassVar[bool]" = False supports_native_parquet_import: "ClassVar[bool]" = False supports_native_parquet_export: "ClassVar[bool]" = False + requires_staging_for_load: "ClassVar[bool]" = False + staging_protocols: "ClassVar[tuple[str, ...]]" = () + default_storage_profile: "ClassVar[str | None]" = None + storage_partition_strategies: "ClassVar[tuple[str, ...]]" = ("fixed",) bind_key: "str | None" statement_config: "StatementConfig" pool_instance: "PoolT | None" migration_config: "dict[str, Any] | MigrationConfig" driver_features: "dict[str, Any]" + _storage_capabilities: "StorageCapabilities | None" def __hash__(self) -> int: return id(self) @@ -425,6 +434,46 @@ def __repr__(self) -> str: parts = ", ".join([f"pool_instance={self.pool_instance!r}", f"migration_config={self.migration_config!r}"]) return f"{type(self).__name__}({parts})" + def storage_capabilities(self) -> "StorageCapabilities": + """Return cached storage capabilities for this configuration.""" + + if self._storage_capabilities is None: + self._storage_capabilities = self._build_storage_capabilities() + return cast("StorageCapabilities", dict(self._storage_capabilities)) + + def reset_storage_capabilities_cache(self) -> None: + """Clear the cached capability snapshot.""" + + self._storage_capabilities = None + + def _build_storage_capabilities(self) -> 
"StorageCapabilities": + arrow_dependency_needed = self.supports_native_arrow_export or self.supports_native_arrow_import + parquet_dependency_needed = self.supports_native_parquet_export or self.supports_native_parquet_import + + arrow_dependency_ready = self._dependency_available(ensure_pyarrow) if arrow_dependency_needed else False + parquet_dependency_ready = self._dependency_available(ensure_pyarrow) if parquet_dependency_needed else False + + capabilities: StorageCapabilities = { + "arrow_export_enabled": bool(self.supports_native_arrow_export and arrow_dependency_ready), + "arrow_import_enabled": bool(self.supports_native_arrow_import and arrow_dependency_ready), + "parquet_export_enabled": bool(self.supports_native_parquet_export and parquet_dependency_ready), + "parquet_import_enabled": bool(self.supports_native_parquet_import and parquet_dependency_ready), + "requires_staging_for_load": self.requires_staging_for_load, + "staging_protocols": list(self.staging_protocols), + "partition_strategies": list(self.storage_partition_strategies), + } + if self.default_storage_profile is not None: + capabilities["default_storage_profile"] = self.default_storage_profile + return capabilities + + @staticmethod + def _dependency_available(checker: "Callable[[], None]") -> bool: + try: + checker() + except MissingDependencyError: + return False + return True + @abstractmethod def create_connection(self) -> "ConnectionT | Awaitable[ConnectionT]": """Create and return a new database connection.""" @@ -480,7 +529,7 @@ def _initialize_migration_components(self) -> None: at runtime when needed. """ from sqlspec.loader import SQLFileLoader - from sqlspec.migrations.commands import create_migration_commands + from sqlspec.migrations import create_migration_commands self._migration_loader = SQLFileLoader() self._migration_commands = create_migration_commands(self) # pyright: ignore @@ -660,6 +709,8 @@ def __init__( else: self.statement_config = statement_config self.driver_features = driver_features or {} + self._storage_capabilities = None + self.driver_features.setdefault("storage_capabilities", self.storage_capabilities()) def create_connection(self) -> ConnectionT: """Create a database connection.""" @@ -939,6 +990,8 @@ def __init__( else: self.statement_config = statement_config self.driver_features = driver_features or {} + self._storage_capabilities = None + self.driver_features.setdefault("storage_capabilities", self.storage_capabilities()) def create_pool(self) -> PoolT: """Create and return the connection pool. @@ -1103,6 +1156,8 @@ def __init__( else: self.statement_config = statement_config self.driver_features = driver_features or {} + self._storage_capabilities = None + self.driver_features.setdefault("storage_capabilities", self.storage_capabilities()) async def create_pool(self) -> PoolT: """Create and return the connection pool. 
diff --git a/sqlspec/core/result.py b/sqlspec/core/result.py index db6bd0d4..4745e9df 100644 --- a/sqlspec/core/result.py +++ b/sqlspec/core/result.py @@ -16,6 +16,13 @@ from typing_extensions import TypeVar from sqlspec.core.compiler import OperationType +from sqlspec.storage import ( + AsyncStoragePipeline, + StorageDestination, + StorageFormat, + StorageTelemetry, + SyncStoragePipeline, +) from sqlspec.utils.module_loader import ensure_pandas, ensure_polars, ensure_pyarrow from sqlspec.utils.schema import to_schema @@ -564,6 +571,32 @@ def scalar_or_none(self) -> Any: return next(iter(row.values())) + def write_to_storage_sync( + self, + destination: "StorageDestination", + *, + format_hint: "StorageFormat | None" = None, + storage_options: "dict[str, Any] | None" = None, + pipeline: "SyncStoragePipeline | None" = None, + ) -> "StorageTelemetry": + active_pipeline = pipeline or SyncStoragePipeline() + rows = self.get_data() + return active_pipeline.write_rows(rows, destination, format_hint=format_hint, storage_options=storage_options) + + async def write_to_storage_async( + self, + destination: "StorageDestination", + *, + format_hint: "StorageFormat | None" = None, + storage_options: "dict[str, Any] | None" = None, + pipeline: "AsyncStoragePipeline | None" = None, + ) -> "StorageTelemetry": + active_pipeline = pipeline or AsyncStoragePipeline() + rows = self.get_data() + return await active_pipeline.write_rows( + rows, destination, format_hint=format_hint, storage_options=storage_options + ) + @mypyc_attr(allow_interpreted_subclasses=False) class ArrowResult(StatementResult): @@ -769,6 +802,36 @@ def to_dict(self) -> "list[dict[str, Any]]": return cast("list[dict[str, Any]]", self.data.to_pylist()) + def write_to_storage_sync( + self, + destination: "StorageDestination", + *, + format_hint: "StorageFormat | None" = None, + storage_options: "dict[str, Any] | None" = None, + compression: str | None = None, + pipeline: "SyncStoragePipeline | None" = None, + ) -> "StorageTelemetry": + table = self.get_data() + active_pipeline = pipeline or SyncStoragePipeline() + return active_pipeline.write_arrow( + table, destination, format_hint=format_hint, storage_options=storage_options, compression=compression + ) + + async def write_to_storage_async( + self, + destination: "StorageDestination", + *, + format_hint: "StorageFormat | None" = None, + storage_options: "dict[str, Any] | None" = None, + compression: str | None = None, + pipeline: "AsyncStoragePipeline | None" = None, + ) -> "StorageTelemetry": + table = self.get_data() + active_pipeline = pipeline or AsyncStoragePipeline() + return await active_pipeline.write_arrow( + table, destination, format_hint=format_hint, storage_options=storage_options, compression=compression + ) + def __len__(self) -> int: """Return number of rows in the Arrow table. 
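The result-level helpers expose the same pipeline once a query has been materialized; a minimal sketch assuming an existing ArrowResult and a writable local path:

from sqlspec.core.result import ArrowResult

def persist(result: ArrowResult) -> None:
    telemetry = result.write_to_storage_sync(
        "/tmp/result.parquet", format_hint="parquet", compression="zstd"
    )
    print(telemetry.get("rows_processed", 0), telemetry.get("bytes_processed", 0))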
diff --git a/sqlspec/driver/_async.py b/sqlspec/driver/_async.py index 5a44295d..708c719d 100644 --- a/sqlspec/driver/_async.py +++ b/sqlspec/driver/_async.py @@ -11,7 +11,7 @@ VersionInfo, handle_single_row_error, ) -from sqlspec.driver.mixins import SQLTranslatorMixin +from sqlspec.driver.mixins import SQLTranslatorMixin, StorageDriverMixin from sqlspec.exceptions import ImproperConfigurationError from sqlspec.utils.arrow_helpers import convert_dict_to_arrow from sqlspec.utils.logging import get_logger @@ -25,21 +25,22 @@ from sqlspec.core import ArrowResult, SQLResult, StatementConfig, StatementFilter from sqlspec.typing import ArrowReturnFormat, SchemaT, StatementParameters -_LOGGER_NAME: Final[str] = "sqlspec" -logger = get_logger(_LOGGER_NAME) __all__ = ("AsyncDataDictionaryBase", "AsyncDriverAdapterBase", "AsyncDriverT") EMPTY_FILTERS: Final["list[StatementFilter]"] = [] +_LOGGER_NAME: Final[str] = "sqlspec" +logger = get_logger(_LOGGER_NAME) AsyncDriverT = TypeVar("AsyncDriverT", bound="AsyncDriverAdapterBase") -class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin): +class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, StorageDriverMixin): """Base class for asynchronous database drivers.""" __slots__ = () + is_async: bool = True @property @abstractmethod diff --git a/sqlspec/driver/_sync.py b/sqlspec/driver/_sync.py index c9dcec2f..aed4efa7 100644 --- a/sqlspec/driver/_sync.py +++ b/sqlspec/driver/_sync.py @@ -11,7 +11,7 @@ VersionInfo, handle_single_row_error, ) -from sqlspec.driver.mixins import SQLTranslatorMixin +from sqlspec.driver.mixins import SQLTranslatorMixin, StorageDriverMixin from sqlspec.exceptions import ImproperConfigurationError from sqlspec.utils.arrow_helpers import convert_dict_to_arrow from sqlspec.utils.logging import get_logger @@ -36,10 +36,11 @@ SyncDriverT = TypeVar("SyncDriverT", bound="SyncDriverAdapterBase") -class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin): +class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, StorageDriverMixin): """Base class for synchronous database drivers.""" __slots__ = () + is_async: bool = False @property @abstractmethod diff --git a/sqlspec/driver/mixins/__init__.py b/sqlspec/driver/mixins/__init__.py index 3376c7d8..f5fdb756 100644 --- a/sqlspec/driver/mixins/__init__.py +++ b/sqlspec/driver/mixins/__init__.py @@ -2,5 +2,6 @@ from sqlspec.driver.mixins._result_tools import ToSchemaMixin from sqlspec.driver.mixins._sql_translator import SQLTranslatorMixin +from sqlspec.driver.mixins.storage import StorageDriverMixin -__all__ = ("SQLTranslatorMixin", "ToSchemaMixin") +__all__ = ("SQLTranslatorMixin", "StorageDriverMixin", "ToSchemaMixin") diff --git a/sqlspec/driver/mixins/storage.py b/sqlspec/driver/mixins/storage.py new file mode 100644 index 00000000..7e5926dc --- /dev/null +++ b/sqlspec/driver/mixins/storage.py @@ -0,0 +1,224 @@ +"""Storage bridge mixin shared by sync and async drivers.""" + +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, cast + +from mypy_extensions import trait + +from sqlspec.exceptions import StorageCapabilityError +from sqlspec.storage import ( + AsyncStoragePipeline, + StorageBridgeJob, + StorageCapabilities, + StorageDestination, + StorageFormat, + StorageTelemetry, + SyncStoragePipeline, + create_storage_bridge_job, +) +from sqlspec.utils.module_loader import ensure_pyarrow + +if TYPE_CHECKING: + from collections.abc import Awaitable + + from sqlspec.core import 
StatementConfig, StatementFilter + from sqlspec.core.result import ArrowResult + from sqlspec.core.statement import SQL + from sqlspec.typing import ArrowTable, StatementParameters + +__all__ = ("StorageDriverMixin",) + + +CAPABILITY_HINTS: dict[str, str] = { + "arrow_export_enabled": "native Arrow export", + "arrow_import_enabled": "native Arrow import", + "parquet_export_enabled": "native Parquet export", + "parquet_import_enabled": "native Parquet import", +} + + +@trait +class StorageDriverMixin: + """Mixin providing capability-aware storage bridge helpers.""" + + __slots__ = () + storage_pipeline_factory: "type[SyncStoragePipeline | AsyncStoragePipeline] | None" = None + driver_features: dict[str, Any] + + def storage_capabilities(self) -> StorageCapabilities: + """Return cached storage capabilities for the active driver.""" + + capabilities = self.driver_features.get("storage_capabilities") + if capabilities is None: + msg = "Storage capabilities are not configured for this driver." + raise StorageCapabilityError(msg, capability="storage_capabilities") + return cast("StorageCapabilities", dict(capabilities)) + + def select_to_storage( + self, + statement: "SQL | str", + destination: StorageDestination, + /, + *parameters: "StatementParameters | StatementFilter", + statement_config: "StatementConfig | None" = None, + partitioner: "dict[str, Any] | None" = None, + format_hint: StorageFormat | None = None, + telemetry: StorageTelemetry | None = None, + ) -> "StorageBridgeJob | Awaitable[StorageBridgeJob]": + """Stream a SELECT statement directly into storage.""" + + self._raise_not_implemented("select_to_storage") + raise NotImplementedError + + def select_to_arrow( + self, + statement: "SQL | str", + /, + *parameters: "StatementParameters | StatementFilter", + partitioner: "dict[str, Any] | None" = None, + memory_pool: Any | None = None, + statement_config: "StatementConfig | None" = None, + ) -> "ArrowResult | Awaitable[ArrowResult]": + """Execute a SELECT that returns an ArrowResult.""" + + self._raise_not_implemented("select_to_arrow") + raise NotImplementedError + + def load_from_arrow( + self, + table: str, + source: "ArrowResult | Any", + *, + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + ) -> "StorageBridgeJob | Awaitable[StorageBridgeJob]": + """Load Arrow data into the target table.""" + + self._raise_not_implemented("load_from_arrow") + raise NotImplementedError + + def load_from_storage( + self, + table: str, + source: StorageDestination, + *, + file_format: StorageFormat, + partitioner: "dict[str, Any] | None" = None, + overwrite: bool = False, + ) -> "StorageBridgeJob | Awaitable[StorageBridgeJob]": + """Load artifacts from storage into the target table.""" + + self._raise_not_implemented("load_from_storage") + raise NotImplementedError + + def stage_artifact(self, request: "dict[str, Any]") -> "dict[str, Any]": + """Provision staging metadata for adapters that require remote URIs.""" + + self._raise_not_implemented("stage_artifact") + raise NotImplementedError + + def flush_staging_artifacts(self, artifacts: "list[dict[str, Any]]", *, error: Exception | None = None) -> None: + """Clean up staged artifacts after a job completes.""" + + if artifacts: + self._raise_not_implemented("flush_staging_artifacts") + + def get_storage_job(self, job_id: str) -> StorageBridgeJob | None: + """Fetch a previously created job handle.""" + + return None + + def _storage_pipeline(self) -> "SyncStoragePipeline | AsyncStoragePipeline": + factory = 
self.storage_pipeline_factory + if factory is None: + if getattr(self, "is_async", False): + return AsyncStoragePipeline() + return SyncStoragePipeline() + return factory() + + def _raise_not_implemented(self, capability: str) -> None: + msg = f"{capability} is not implemented for this driver" + remediation = "Override StorageDriverMixin methods on the adapter to enable this capability." + raise StorageCapabilityError(msg, capability=capability, remediation=remediation) + + def _require_capability(self, capability_flag: str) -> None: + capabilities = self.storage_capabilities() + if capabilities.get(capability_flag, False): + return + human_label = CAPABILITY_HINTS.get(capability_flag, capability_flag) + remediation = "Check adapter supports this capability or stage artifacts via storage pipeline." + msg = f"{human_label} is not available for this adapter" + raise StorageCapabilityError(msg, capability=capability_flag, remediation=remediation) + + def _attach_partition_telemetry(self, telemetry: StorageTelemetry, partitioner: "dict[str, Any] | None") -> None: + if not partitioner: + return + extra = dict(telemetry.get("extra", {})) + extra["partitioner"] = partitioner + telemetry["extra"] = extra + + def _create_storage_job( + self, produced: StorageTelemetry, provided: StorageTelemetry | None = None, *, status: str = "completed" + ) -> StorageBridgeJob: + merged = cast("StorageTelemetry", dict(produced)) + if provided: + source_bytes = provided.get("bytes_processed") + if source_bytes is not None: + merged["bytes_processed"] = int(merged.get("bytes_processed", 0)) + int(source_bytes) + extra = dict(merged.get("extra", {})) + extra["source"] = provided + merged["extra"] = extra + return create_storage_bridge_job(status, merged) + + def _read_arrow_from_storage_sync( + self, source: StorageDestination, *, file_format: StorageFormat, storage_options: "dict[str, Any] | None" = None + ) -> "tuple[ArrowTable, StorageTelemetry]": + pipeline = cast("SyncStoragePipeline", self._storage_pipeline()) + return pipeline.read_arrow(source, file_format=file_format, storage_options=storage_options) + + async def _read_arrow_from_storage_async( + self, source: StorageDestination, *, file_format: StorageFormat, storage_options: "dict[str, Any] | None" = None + ) -> "tuple[ArrowTable, StorageTelemetry]": + pipeline = cast("AsyncStoragePipeline", self._storage_pipeline()) + return await pipeline.read_arrow_async(source, file_format=file_format, storage_options=storage_options) + + @staticmethod + def _build_ingest_telemetry(table: "ArrowTable", *, format_label: str = "arrow") -> StorageTelemetry: + rows = int(getattr(table, "num_rows", 0)) + bytes_processed = int(getattr(table, "nbytes", 0)) + return {"rows_processed": rows, "bytes_processed": bytes_processed, "format": format_label} + + def _coerce_arrow_table(self, source: "ArrowResult | Any") -> "ArrowTable": + ensure_pyarrow() + import pyarrow as pa + + if hasattr(source, "get_data"): + table = source.get_data() + if isinstance(table, pa.Table): + return table + msg = "ArrowResult did not return a pyarrow.Table instance" + raise TypeError(msg) + if isinstance(source, pa.Table): + return source + if isinstance(source, pa.RecordBatch): + return pa.Table.from_batches([source]) + if isinstance(source, Iterable): + return pa.Table.from_pylist(list(source)) + msg = f"Unsupported Arrow source type: {type(source).__name__}" + raise TypeError(msg) + + @staticmethod + def _arrow_table_to_rows( + table: "ArrowTable", columns: "list[str] | None" = None + ) -> 
"tuple[list[str], list[tuple[Any, ...]]]": + ensure_pyarrow() + resolved_columns = columns or list(table.column_names) + if not resolved_columns: + msg = "Arrow table has no columns to import" + raise ValueError(msg) + batches = table.to_pylist() + records: list[tuple[Any, ...]] = [] + for row in batches: + record = tuple(row.get(col) for col in resolved_columns) + records.append(record) + return resolved_columns, records diff --git a/sqlspec/exceptions.py b/sqlspec/exceptions.py index 9f89eab9..2ebfa53a 100644 --- a/sqlspec/exceptions.py +++ b/sqlspec/exceptions.py @@ -28,6 +28,7 @@ "SQLParsingError", "SQLSpecError", "SerializationError", + "StorageCapabilityError", "StorageOperationFailedError", "TransactionError", "UniqueViolationError", @@ -177,6 +178,21 @@ class StorageOperationFailedError(SQLSpecError): """Raised when a storage backend operation fails (e.g., network, permission, API error).""" +class StorageCapabilityError(SQLSpecError): + """Raised when a requested storage bridge capability is unavailable.""" + + def __init__(self, message: str, *, capability: str | None = None, remediation: str | None = None) -> None: + parts = [message] + if capability: + parts.append(f"(capability: {capability})") + if remediation: + parts.append(remediation) + detail = " ".join(parts) + super().__init__(detail) + self.capability = capability + self.remediation = remediation + + class FileNotFoundInStorageError(StorageOperationFailedError): """Raised when a file or object is not found in the storage backend.""" diff --git a/sqlspec/storage/__init__.py b/sqlspec/storage/__init__.py index d2c405ae..ddf056bd 100644 --- a/sqlspec/storage/__init__.py +++ b/sqlspec/storage/__init__.py @@ -8,6 +8,39 @@ - Capability-based backend selection """ +from sqlspec.storage.pipeline import ( + AsyncStoragePipeline, + PartitionStrategyConfig, + StagedArtifact, + StorageBridgeJob, + StorageCapabilities, + StorageDestination, + StorageFormat, + StorageLoadRequest, + StorageTelemetry, + SyncStoragePipeline, + create_storage_bridge_job, + get_storage_bridge_diagnostics, + get_storage_bridge_metrics, + reset_storage_bridge_metrics, +) from sqlspec.storage.registry import StorageRegistry, storage_registry -__all__ = ("StorageRegistry", "storage_registry") +__all__ = ( + "AsyncStoragePipeline", + "PartitionStrategyConfig", + "StagedArtifact", + "StorageBridgeJob", + "StorageCapabilities", + "StorageDestination", + "StorageFormat", + "StorageLoadRequest", + "StorageRegistry", + "StorageTelemetry", + "SyncStoragePipeline", + "create_storage_bridge_job", + "get_storage_bridge_diagnostics", + "get_storage_bridge_metrics", + "reset_storage_bridge_metrics", + "storage_registry", +) diff --git a/sqlspec/storage/pipeline.py b/sqlspec/storage/pipeline.py new file mode 100644 index 00000000..98b17965 --- /dev/null +++ b/sqlspec/storage/pipeline.py @@ -0,0 +1,574 @@ +"""Storage pipeline scaffolding for driver-aware storage bridge.""" + +from functools import partial +from pathlib import Path +from time import perf_counter, time +from typing import TYPE_CHECKING, Any, Literal, NamedTuple, TypeAlias, cast +from uuid import uuid4 + +from mypy_extensions import mypyc_attr +from typing_extensions import NotRequired, TypedDict + +from sqlspec.exceptions import ImproperConfigurationError +from sqlspec.storage._utils import import_pyarrow, import_pyarrow_parquet +from sqlspec.storage.errors import execute_async_storage_operation, execute_sync_storage_operation +from sqlspec.storage.registry import StorageRegistry, storage_registry +from 
sqlspec.utils.serializers import from_json, get_serializer_metrics, serialize_collection, to_json +from sqlspec.utils.sync_tools import async_ + +if TYPE_CHECKING: + from sqlspec.protocols import ObjectStoreProtocol + from sqlspec.typing import ArrowTable + + +__all__ = ( + "AsyncStoragePipeline", + "PartitionStrategyConfig", + "StagedArtifact", + "StorageBridgeJob", + "StorageCapabilities", + "StorageDestination", + "StorageFormat", + "StorageLoadRequest", + "StorageTelemetry", + "SyncStoragePipeline", + "create_storage_bridge_job", + "get_storage_bridge_diagnostics", + "get_storage_bridge_metrics", + "reset_storage_bridge_metrics", +) + +StorageFormat = Literal["jsonl", "json", "parquet", "arrow-ipc"] +StorageDestination: TypeAlias = str | Path + + +class StorageCapabilities(TypedDict): + """Runtime-evaluated driver storage capabilities.""" + + arrow_export_enabled: bool + arrow_import_enabled: bool + parquet_export_enabled: bool + parquet_import_enabled: bool + requires_staging_for_load: bool + staging_protocols: "list[str]" + partition_strategies: "list[str]" + default_storage_profile: NotRequired[str | None] + + +class PartitionStrategyConfig(TypedDict, total=False): + """Configuration for partition fan-out strategies.""" + + kind: str + partitions: int + rows_per_chunk: int + manifest_path: str + + +class StorageLoadRequest(TypedDict): + """Request describing a staging allocation.""" + + partition_id: str + destination_uri: str + ttl_seconds: int + correlation_id: str + source_uri: NotRequired[str] + + +class StagedArtifact(TypedDict): + """Metadata describing a staged artifact managed by the pipeline.""" + + partition_id: str + uri: str + cleanup_token: str + ttl_seconds: int + expires_at: float + correlation_id: str + + +class StorageTelemetry(TypedDict, total=False): + """Telemetry payload for storage bridge operations.""" + + destination: str + bytes_processed: int + rows_processed: int + partitions_created: int + duration_s: float + format: str + extra: "dict[str, Any]" + backend: str + + +class StorageBridgeJob(NamedTuple): + """Handle representing a storage bridge operation.""" + + job_id: str + status: str + telemetry: StorageTelemetry + + +class _StorageBridgeMetrics: + __slots__ = ("bytes_written", "partitions_created") + + def __init__(self) -> None: + self.bytes_written = 0 + self.partitions_created = 0 + + def record_bytes(self, count: int) -> None: + self.bytes_written += max(count, 0) + + def record_partitions(self, count: int) -> None: + self.partitions_created += max(count, 0) + + def snapshot(self) -> "dict[str, int]": + return { + "storage_bridge.bytes_written": self.bytes_written, + "storage_bridge.partitions_created": self.partitions_created, + } + + def reset(self) -> None: + self.bytes_written = 0 + self.partitions_created = 0 + + +_METRICS = _StorageBridgeMetrics() + + +def get_storage_bridge_metrics() -> "dict[str, int]": + """Return aggregated storage bridge metrics.""" + + return _METRICS.snapshot() + + +def reset_storage_bridge_metrics() -> None: + """Reset aggregated storage bridge metrics.""" + + _METRICS.reset() + + +def create_storage_bridge_job(status: str, telemetry: StorageTelemetry) -> StorageBridgeJob: + """Create a storage bridge job handle with a unique identifier.""" + + return StorageBridgeJob(job_id=str(uuid4()), status=status, telemetry=telemetry) + + +def get_storage_bridge_diagnostics() -> "dict[str, int]": + """Return aggregated storage bridge + serializer cache metrics.""" + + diagnostics = dict(get_storage_bridge_metrics()) + 
serializer_metrics = get_serializer_metrics() + for key, value in serializer_metrics.items(): + diagnostics[f"serializer.{key}"] = value + return diagnostics + + +def _encode_row_payload(rows: "list[Any]", format_hint: StorageFormat) -> bytes: + if format_hint == "json": + data = to_json(rows, as_bytes=True) + if isinstance(data, bytes): + return data + return data.encode() + buffer = bytearray() + for row in rows: + buffer.extend(to_json(row, as_bytes=True)) + buffer.extend(b"\n") + return bytes(buffer) + + +def _encode_arrow_payload(table: "ArrowTable", format_choice: StorageFormat, *, compression: str | None) -> bytes: + pa = import_pyarrow() + sink = pa.BufferOutputStream() + if format_choice == "arrow-ipc": + writer = pa.ipc.new_file(sink, table.schema) + writer.write_table(table) + writer.close() + else: + pq = import_pyarrow_parquet() + pq.write_table(table, sink, compression=compression) + buffer = sink.getvalue() + result_bytes: bytes = buffer.to_pybytes() + return result_bytes + + +def _decode_arrow_payload(payload: bytes, format_choice: StorageFormat) -> "ArrowTable": + pa = import_pyarrow() + if format_choice == "parquet": + pq = import_pyarrow_parquet() + return cast("ArrowTable", pq.read_table(pa.BufferReader(payload))) + if format_choice == "arrow-ipc": + reader = pa.ipc.open_file(pa.BufferReader(payload)) + return cast("ArrowTable", reader.read_all()) + text_payload = payload.decode() + if format_choice == "json": + data = from_json(text_payload) + rows = data if isinstance(data, list) else [data] + return cast("ArrowTable", pa.Table.from_pylist(rows)) + if format_choice == "jsonl": + rows = [from_json(line) for line in text_payload.splitlines() if line.strip()] + return cast("ArrowTable", pa.Table.from_pylist(rows)) + msg = f"Unsupported storage format for Arrow decoding: {format_choice}" + raise ValueError(msg) + + +@mypyc_attr(allow_interpreted_subclasses=True) +class SyncStoragePipeline: + """Pipeline coordinating storage registry operations and telemetry.""" + + __slots__ = ("registry",) + + def __init__(self, *, registry: StorageRegistry | None = None) -> None: + self.registry = registry or storage_registry + + def write_rows( + self, + rows: "list[dict[str, Any]]", + destination: StorageDestination, + *, + format_hint: StorageFormat | None = None, + storage_options: "dict[str, Any] | None" = None, + ) -> StorageTelemetry: + """Write dictionary rows to storage using cached serializers.""" + + serialized = serialize_collection(rows) + format_choice = format_hint or "jsonl" + payload = _encode_row_payload(serialized, format_choice) + return self._write_bytes( + payload, + destination, + rows=len(serialized), + format_label=format_choice, + storage_options=storage_options or {}, + ) + + def write_arrow( + self, + table: "ArrowTable", + destination: StorageDestination, + *, + format_hint: StorageFormat | None = None, + storage_options: "dict[str, Any] | None" = None, + compression: str | None = None, + ) -> StorageTelemetry: + """Write an Arrow table to storage using zero-copy buffers.""" + + format_choice = format_hint or "parquet" + payload = _encode_arrow_payload(table, format_choice, compression=compression) + return self._write_bytes( + payload, + destination, + rows=int(getattr(table, "num_rows", 0)), + format_label=format_choice, + storage_options=storage_options or {}, + ) + + def read_arrow( + self, source: StorageDestination, *, file_format: StorageFormat, storage_options: "dict[str, Any] | None" = None + ) -> "tuple[ArrowTable, StorageTelemetry]": + """Read an 
artifact from storage and decode it into an Arrow table.""" + + backend, path = self._resolve_backend(source, **(storage_options or {})) + backend_name = getattr(backend, "backend_type", "storage") + payload = execute_sync_storage_operation( + partial(backend.read_bytes, path), backend=backend_name, operation="read_bytes", path=path + ) + table = _decode_arrow_payload(payload, file_format) + telemetry: StorageTelemetry = { + "destination": path, + "bytes_processed": len(payload), + "rows_processed": int(getattr(table, "num_rows", 0)), + "format": file_format, + "backend": backend_name, + } + return table, telemetry + + def allocate_staging_artifacts(self, requests: "list[StorageLoadRequest]") -> "list[StagedArtifact]": + """Allocate staging metadata for upcoming loads.""" + + artifacts: list[StagedArtifact] = [] + now = time() + + for request in requests: + ttl = max(request["ttl_seconds"], 0) + cleanup_token = f"{request['correlation_id']}::{request['partition_id']}" + artifacts.append({ + "partition_id": request["partition_id"], + "uri": request["destination_uri"], + "cleanup_token": cleanup_token, + "ttl_seconds": ttl, + "expires_at": now + ttl if ttl else now, + "correlation_id": request["correlation_id"], + }) + if artifacts: + _METRICS.record_partitions(len(artifacts)) + return artifacts + + def cleanup_staging_artifacts(self, artifacts: "list[StagedArtifact]", *, ignore_errors: bool = True) -> None: + """Delete staged artifacts best-effort.""" + + for artifact in artifacts: + backend, path = self._resolve_backend(artifact["uri"]) + try: + execute_sync_storage_operation( + partial(backend.delete, path), + backend=getattr(backend, "backend_type", "storage"), + operation="delete", + path=path, + ) + except Exception: + if not ignore_errors: + raise + + def _write_bytes( + self, + payload: bytes, + destination: StorageDestination, + *, + rows: int, + format_label: str, + storage_options: "dict[str, Any]", + ) -> StorageTelemetry: + backend, path = self._resolve_backend(destination, **storage_options) + backend_name = getattr(backend, "backend_type", "storage") + start = perf_counter() + execute_sync_storage_operation( + partial(backend.write_bytes, path, payload), backend=backend_name, operation="write_bytes", path=path + ) + elapsed = perf_counter() - start + bytes_written = len(payload) + _METRICS.record_bytes(bytes_written) + telemetry: StorageTelemetry = { + "destination": path, + "bytes_processed": bytes_written, + "rows_processed": rows, + "duration_s": elapsed, + "format": format_label, + "backend": backend_name, + } + return telemetry + + def _resolve_backend( + self, destination: StorageDestination, **backend_options: Any + ) -> "tuple[ObjectStoreProtocol, str]": + destination_str = destination.as_posix() if isinstance(destination, Path) else str(destination) + alias_resolution = self._resolve_alias_destination(destination_str, backend_options) + if alias_resolution is not None: + return alias_resolution + backend = self.registry.get(destination_str, **backend_options) + normalized_path = self._normalize_path_for_backend(destination_str) + return backend, normalized_path + + def _resolve_alias_destination( + self, destination: str, backend_options: "dict[str, Any]" + ) -> "tuple[ObjectStoreProtocol, str] | None": + if not destination.startswith("alias://"): + return None + payload = destination.removeprefix("alias://") + alias_name, _, relative_path = payload.partition("/") + alias = alias_name.strip() + if not alias: + msg = "Alias destinations must include a registry alias 
before the path component" + raise ImproperConfigurationError(msg) + path_segment = relative_path.strip() + if not path_segment: + msg = "Alias destinations must include an object path after the alias name" + raise ImproperConfigurationError(msg) + backend = self.registry.get(alias, **backend_options) + return backend, path_segment.lstrip("/") + + def _normalize_path_for_backend(self, destination: str) -> str: + if destination.startswith("file://"): + return destination.removeprefix("file://") + if "://" in destination: + _, remainder = destination.split("://", 1) + return remainder.lstrip("/") + return destination + + +@mypyc_attr(allow_interpreted_subclasses=True) +class AsyncStoragePipeline: + """Async variant of the storage pipeline leveraging async-capable backends when available.""" + + __slots__ = ("registry",) + + def __init__(self, *, registry: StorageRegistry | None = None) -> None: + self.registry = registry or storage_registry + + async def write_rows( + self, + rows: "list[dict[str, Any]]", + destination: StorageDestination, + *, + format_hint: StorageFormat | None = None, + storage_options: "dict[str, Any] | None" = None, + ) -> StorageTelemetry: + serialized = serialize_collection(rows) + format_choice = format_hint or "jsonl" + payload = _encode_row_payload(serialized, format_choice) + return await self._write_bytes_async( + payload, + destination, + rows=len(serialized), + format_label=format_choice, + storage_options=storage_options or {}, + ) + + async def write_arrow( + self, + table: "ArrowTable", + destination: StorageDestination, + *, + format_hint: StorageFormat | None = None, + storage_options: "dict[str, Any] | None" = None, + compression: str | None = None, + ) -> StorageTelemetry: + format_choice = format_hint or "parquet" + payload = _encode_arrow_payload(table, format_choice, compression=compression) + return await self._write_bytes_async( + payload, + destination, + rows=int(getattr(table, "num_rows", 0)), + format_label=format_choice, + storage_options=storage_options or {}, + ) + + async def cleanup_staging_artifacts(self, artifacts: "list[StagedArtifact]", *, ignore_errors: bool = True) -> None: + for artifact in artifacts: + backend, path = self._resolve_backend(artifact["uri"]) + backend_name = getattr(backend, "backend_type", "storage") + delete_async = getattr(backend, "delete_async", None) + if delete_async is not None: + try: + await execute_async_storage_operation( + partial(delete_async, path), backend=backend_name, operation="delete", path=path + ) + except Exception: + if not ignore_errors: + raise + continue + + def _delete_sync( + backend: "ObjectStoreProtocol" = backend, path: str = path, backend_name: str = backend_name + ) -> None: + execute_sync_storage_operation( + partial(backend.delete, path), backend=backend_name, operation="delete", path=path + ) + + try: + await async_(_delete_sync)() + except Exception: + if not ignore_errors: + raise + + async def _write_bytes_async( + self, + payload: bytes, + destination: StorageDestination, + *, + rows: int, + format_label: str, + storage_options: "dict[str, Any]", + ) -> StorageTelemetry: + backend, path = self._resolve_backend(destination, **storage_options) + backend_name = getattr(backend, "backend_type", "storage") + writer = getattr(backend, "write_bytes_async", None) + start = perf_counter() + if writer is not None: + await execute_async_storage_operation( + partial(writer, path, payload), backend=backend_name, operation="write_bytes", path=path + ) + else: + + def _write_sync( + backend: 
"ObjectStoreProtocol" = backend, + path: str = path, + payload: bytes = payload, + backend_name: str = backend_name, + ) -> None: + execute_sync_storage_operation( + partial(backend.write_bytes, path, payload), + backend=backend_name, + operation="write_bytes", + path=path, + ) + + await async_(_write_sync)() + + elapsed = perf_counter() - start + bytes_written = len(payload) + _METRICS.record_bytes(bytes_written) + telemetry: StorageTelemetry = { + "destination": path, + "bytes_processed": bytes_written, + "rows_processed": rows, + "duration_s": elapsed, + "format": format_label, + "backend": backend_name, + } + return telemetry + + async def read_arrow_async( + self, source: StorageDestination, *, file_format: StorageFormat, storage_options: "dict[str, Any] | None" = None + ) -> "tuple[ArrowTable, StorageTelemetry]": + backend, path = self._resolve_backend(source, **(storage_options or {})) + backend_name = getattr(backend, "backend_type", "storage") + reader = getattr(backend, "read_bytes_async", None) + if reader is not None: + payload = await execute_async_storage_operation( + partial(reader, path), backend=backend_name, operation="read_bytes", path=path + ) + else: + + def _read_sync( + backend: "ObjectStoreProtocol" = backend, path: str = path, backend_name: str = backend_name + ) -> bytes: + return execute_sync_storage_operation( + partial(backend.read_bytes, path), backend=backend_name, operation="read_bytes", path=path + ) + + payload = await async_(_read_sync)() + + table = _decode_arrow_payload(payload, file_format) + telemetry: StorageTelemetry = { + "destination": path, + "bytes_processed": len(payload), + "rows_processed": int(getattr(table, "num_rows", 0)), + "format": file_format, + "backend": backend_name, + } + return table, telemetry + + def _resolve_backend( + self, destination: StorageDestination, **backend_options: Any + ) -> "tuple[ObjectStoreProtocol, str]": + destination_str = destination.as_posix() if isinstance(destination, Path) else str(destination) + alias_resolution = self._resolve_alias_destination(destination_str, backend_options) + if alias_resolution is not None: + return alias_resolution + backend = self.registry.get(destination_str, **backend_options) + normalized_path = self._normalize_path_for_backend(destination_str) + return backend, normalized_path + + def _resolve_alias_destination( + self, destination: str, backend_options: "dict[str, Any]" + ) -> "tuple[ObjectStoreProtocol, str] | None": + if not destination.startswith("alias://"): + return None + payload = destination.removeprefix("alias://") + alias_name, _, relative_path = payload.partition("/") + alias = alias_name.strip() + if not alias: + msg = "Alias destinations must include a registry alias before the path component" + raise ImproperConfigurationError(msg) + path_segment = relative_path.strip() + if not path_segment: + msg = "Alias destinations must include an object path after the alias name" + raise ImproperConfigurationError(msg) + backend = self.registry.get(alias, **backend_options) + return backend, path_segment.lstrip("/") + + def _normalize_path_for_backend(self, destination: str) -> str: + if destination.startswith("file://"): + return destination.removeprefix("file://") + if "://" in destination: + _, remainder = destination.split("://", 1) + return remainder.lstrip("/") + return destination diff --git a/tests/integration/test_adapters/_storage_bridge_helpers.py b/tests/integration/test_adapters/_storage_bridge_helpers.py new file mode 100644 index 00000000..7a5fcf77 --- 
/dev/null +++ b/tests/integration/test_adapters/_storage_bridge_helpers.py @@ -0,0 +1,28 @@ +"""Shared helpers for storage bridge integration tests.""" + +from typing import TYPE_CHECKING + +from sqlspec.storage.registry import storage_registry + +if TYPE_CHECKING: # pragma: no cover + from pytest_databases.docker.minio import MinioService + +__all__ = ("register_minio_alias",) + + +def register_minio_alias( + alias: str, minio_service: "MinioService", bucket: str, *, prefix: str = "storage-bridge" +) -> str: + """Register a storage registry alias backed by the pytest-databases MinIO service.""" + + storage_registry.register_alias( + alias, + f"s3://{bucket}/{prefix}", + backend="fsspec", + endpoint_url=f"http://{minio_service.endpoint}", + key=minio_service.access_key, + secret=minio_service.secret_key, + use_ssl=False, + client_kwargs={"endpoint_url": f"http://{minio_service.endpoint}", "verify": False}, + ) + return prefix diff --git a/tests/integration/test_adapters/test_adbc/test_storage_bridge.py b/tests/integration/test_adapters/test_adbc/test_storage_bridge.py new file mode 100644 index 00000000..34255c89 --- /dev/null +++ b/tests/integration/test_adapters/test_adbc/test_storage_bridge.py @@ -0,0 +1,63 @@ +"""Storage bridge integration tests for the ADBC adapter.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from sqlspec.adapters.adbc import AdbcDriver +from sqlspec.storage.registry import storage_registry + +pytestmark = [pytest.mark.xdist_group("storage"), pytest.mark.postgres] + + +def _prepare_tables(session: AdbcDriver, source: str, target: str) -> None: + session.execute_script( + f""" + DROP TABLE IF EXISTS {target}; + DROP TABLE IF EXISTS {source}; + CREATE TABLE {source} ( + id INT PRIMARY KEY, + label TEXT NOT NULL + ); + CREATE TABLE {target} ( + id INT PRIMARY KEY, + label TEXT NOT NULL + ); + """ + ) + + +def _seed_source(session: AdbcDriver, source: str) -> None: + session.execute(f"INSERT INTO {source} (id, label) VALUES (1, 'alpha'), (2, 'beta'), (3, 'gamma')") + + +@pytest.mark.usefixtures("adbc_postgresql_session") +def test_adbc_postgres_storage_bridge_round_trip(tmp_path: Path, adbc_postgresql_session: AdbcDriver) -> None: + source_table = "storage_bridge_adbc_source" + target_table = "storage_bridge_adbc_target" + alias = "adbc_storage_bridge_local" + storage_registry.register_alias(alias, f"file://{tmp_path}", backend="local") + destination = f"alias://{alias}/adbc_storage_bridge.parquet" + + session = adbc_postgresql_session + _prepare_tables(session, source_table, target_table) + _seed_source(session, source_table) + + export_job = session.select_to_storage( + f"SELECT id, label FROM {source_table} ORDER BY id", destination, format_hint="parquet" + ) + assert export_job.telemetry["rows_processed"] == 3 + destination_path = tmp_path / "adbc_storage_bridge.parquet" + assert destination_path.exists() + + load_job = session.load_from_storage(target_table, destination, file_format="parquet", overwrite=True) + assert load_job.telemetry["rows_processed"] == 3 + assert load_job.telemetry["destination"] == target_table + + result = session.execute(f"SELECT id, label FROM {target_table} ORDER BY id") + assert [(row["id"], row["label"]) for row in result] == [(1, "alpha"), (2, "beta"), (3, "gamma")] + + session.execute_script(f"DROP TABLE IF EXISTS {target_table}; DROP TABLE IF EXISTS {source_table};") + storage_registry.clear() diff --git a/tests/integration/test_adapters/test_aiosqlite/test_storage_bridge.py 
b/tests/integration/test_adapters/test_aiosqlite/test_storage_bridge.py new file mode 100644 index 00000000..b909f752 --- /dev/null +++ b/tests/integration/test_adapters/test_aiosqlite/test_storage_bridge.py @@ -0,0 +1,46 @@ +"""Storage bridge integration tests for AioSQLite adapter.""" + +from __future__ import annotations + +from pathlib import Path + +import pyarrow as pa +import pyarrow.parquet as pq +import pytest + +from sqlspec.adapters.aiosqlite import AiosqliteDriver + +pytestmark = [pytest.mark.asyncio, pytest.mark.xdist_group("sqlite")] + + +async def test_aiosqlite_load_from_arrow(aiosqlite_session: AiosqliteDriver) -> None: + await aiosqlite_session.execute("DROP TABLE IF EXISTS storage_bridge_aiosqlite") + await aiosqlite_session.execute("CREATE TABLE storage_bridge_aiosqlite (id INTEGER PRIMARY KEY, label TEXT)") + + arrow_table = pa.table({"id": [1, 2], "label": ["north", "south"]}) + + job = await aiosqlite_session.load_from_arrow("storage_bridge_aiosqlite", arrow_table, overwrite=True) + + assert job.telemetry["rows_processed"] == arrow_table.num_rows + + result = await aiosqlite_session.execute("SELECT id, label FROM storage_bridge_aiosqlite ORDER BY id") + assert result.data == [{"id": 1, "label": "north"}, {"id": 2, "label": "south"}] + + +async def test_aiosqlite_load_from_storage(aiosqlite_session: AiosqliteDriver, tmp_path: Path) -> None: + await aiosqlite_session.execute("DROP TABLE IF EXISTS storage_bridge_aiosqlite") + await aiosqlite_session.execute("CREATE TABLE storage_bridge_aiosqlite (id INTEGER PRIMARY KEY, label TEXT)") + + arrow_table = pa.table({"id": [3, 4], "label": ["east", "west"]}) + destination = tmp_path / "aiosqlite-bridge.parquet" + pq.write_table(arrow_table, destination) + + job = await aiosqlite_session.load_from_storage( + "storage_bridge_aiosqlite", str(destination), file_format="parquet", overwrite=True + ) + + assert job.telemetry["extra"]["source"]["destination"].endswith("aiosqlite-bridge.parquet") + assert job.telemetry["extra"]["source"]["backend"] + + result = await aiosqlite_session.execute("SELECT id, label FROM storage_bridge_aiosqlite ORDER BY id") + assert result.data == [{"id": 3, "label": "east"}, {"id": 4, "label": "west"}] diff --git a/tests/integration/test_adapters/test_asyncmy/test_storage_bridge.py b/tests/integration/test_adapters/test_asyncmy/test_storage_bridge.py new file mode 100644 index 00000000..ca7f0941 --- /dev/null +++ b/tests/integration/test_adapters/test_asyncmy/test_storage_bridge.py @@ -0,0 +1,63 @@ +"""Storage bridge integration tests for AsyncMy adapter.""" + +from __future__ import annotations + +from pathlib import Path + +import pyarrow as pa +import pyarrow.parquet as pq +import pytest + +from sqlspec.adapters.asyncmy import AsyncmyDriver + +pytestmark = [pytest.mark.asyncio, pytest.mark.xdist_group("mysql")] + + +async def _fetch_rows(asyncmy_driver: AsyncmyDriver, table: str) -> list[dict[str, object]]: + rows = await asyncmy_driver.select(f"SELECT id, name FROM {table} ORDER BY id") + assert isinstance(rows, list) + return rows + + +async def test_asyncmy_load_from_arrow(asyncmy_driver: AsyncmyDriver) -> None: + table_name = "storage_bridge_users" + await asyncmy_driver.execute(f"DROP TABLE IF EXISTS {table_name}") + await asyncmy_driver.execute(f"CREATE TABLE {table_name} (id INT PRIMARY KEY, name VARCHAR(64))") + + arrow_table = pa.table({"id": [1, 2], "name": ["alpha", "beta"]}) + + job = await asyncmy_driver.load_from_arrow(table_name, arrow_table, overwrite=True) + + assert 
job.telemetry["rows_processed"] == arrow_table.num_rows + assert job.telemetry["destination"] == table_name + + rows = await _fetch_rows(asyncmy_driver, table_name) + assert rows == [{"id": 1, "name": "alpha"}, {"id": 2, "name": "beta"}] + + await asyncmy_driver.execute(f"DROP TABLE IF EXISTS {table_name}") + + +async def test_asyncmy_load_from_storage(tmp_path: Path, asyncmy_driver: AsyncmyDriver) -> None: + await asyncmy_driver.execute("DROP TABLE IF EXISTS storage_bridge_scores") + await asyncmy_driver.execute("CREATE TABLE storage_bridge_scores (id INT PRIMARY KEY, score DECIMAL(5,2))") + + arrow_table = pa.table({"id": [5, 6], "score": [12.5, 99.1]}) + destination = tmp_path / "scores.parquet" + pq.write_table(arrow_table, destination) + + job = await asyncmy_driver.load_from_storage( + "storage_bridge_scores", str(destination), file_format="parquet", overwrite=True + ) + + assert job.telemetry["destination"] == "storage_bridge_scores" + assert job.telemetry["extra"]["source"]["destination"].endswith("scores.parquet") + assert job.telemetry["extra"]["source"]["backend"] + + rows = await asyncmy_driver.select("SELECT id, score FROM storage_bridge_scores ORDER BY id") + assert len(rows) == 2 + assert rows[0]["id"] == 5 + assert float(rows[0]["score"]) == pytest.approx(12.5) + assert rows[1]["id"] == 6 + assert float(rows[1]["score"]) == pytest.approx(99.1) + + await asyncmy_driver.execute("DROP TABLE IF EXISTS storage_bridge_scores") diff --git a/tests/integration/test_adapters/test_asyncpg/test_storage_bridge.py b/tests/integration/test_adapters/test_asyncpg/test_storage_bridge.py new file mode 100644 index 00000000..20d4469a --- /dev/null +++ b/tests/integration/test_adapters/test_asyncpg/test_storage_bridge.py @@ -0,0 +1,69 @@ +"""Storage bridge integration tests for AsyncPG using MinIO.""" + +from typing import TYPE_CHECKING + +import pytest + +from sqlspec.adapters.asyncpg import AsyncpgDriver +from sqlspec.storage.registry import storage_registry +from sqlspec.typing import FSSPEC_INSTALLED, PYARROW_INSTALLED +from tests.integration.test_adapters._storage_bridge_helpers import register_minio_alias + +if TYPE_CHECKING: # pragma: no cover + from minio import Minio + from pytest_databases.docker.minio import MinioService + +pytestmark = [ + pytest.mark.asyncpg, + pytest.mark.xdist_group("storage"), + pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed"), + pytest.mark.skipif(not PYARROW_INSTALLED, reason="pyarrow not installed"), +] + + +@pytest.mark.anyio +async def test_asyncpg_storage_bridge_with_minio( + asyncpg_async_driver: AsyncpgDriver, + minio_service: "MinioService", + minio_client: "Minio", + minio_default_bucket_name: str, +) -> None: + alias = "storage_bridge_asyncpg" + destination_path = "alias://storage_bridge_asyncpg/asyncpg/export.parquet" + source_table = "storage_bridge_asyncpg_source" + target_table = "storage_bridge_asyncpg_target" + + storage_registry.clear() + try: + prefix = register_minio_alias(alias, minio_service, minio_default_bucket_name) + + await asyncpg_async_driver.execute(f"DROP TABLE IF EXISTS {source_table} CASCADE") + await asyncpg_async_driver.execute(f"DROP TABLE IF EXISTS {target_table} CASCADE") + await asyncpg_async_driver.execute(f"CREATE TABLE {source_table} (id INT PRIMARY KEY, label TEXT NOT NULL)") + await asyncpg_async_driver.execute( + f"INSERT INTO {source_table} (id, label) VALUES (1, 'north'), (2, 'south'), (3, 'east')" + ) + + export_job = await asyncpg_async_driver.select_to_storage( + f"SELECT id, label FROM 
{source_table} ORDER BY id", destination_path, format_hint="parquet" + ) + assert export_job.telemetry["rows_processed"] == 3 + + await asyncpg_async_driver.execute(f"CREATE TABLE {target_table} (id INT PRIMARY KEY, label TEXT NOT NULL)") + load_job = await asyncpg_async_driver.load_from_storage( + target_table, destination_path, file_format="parquet", overwrite=True + ) + assert load_job.telemetry["rows_processed"] == 3 + + result = await asyncpg_async_driver.execute(f"SELECT id, label FROM {target_table} ORDER BY id") + rows = [(row["id"], row["label"]) for row in result] + assert rows == [(1, "north"), (2, "south"), (3, "east")] + + object_name = f"{prefix}/asyncpg/export.parquet" + stat = minio_client.stat_object(minio_default_bucket_name, object_name) + object_size = stat.size if stat.size is not None else 0 + assert object_size > 0 + finally: + storage_registry.clear() + await asyncpg_async_driver.execute(f"DROP TABLE IF EXISTS {source_table} CASCADE") + await asyncpg_async_driver.execute(f"DROP TABLE IF EXISTS {target_table} CASCADE") diff --git a/tests/integration/test_adapters/test_duckdb/test_storage_bridge.py b/tests/integration/test_adapters/test_duckdb/test_storage_bridge.py new file mode 100644 index 00000000..35513e91 --- /dev/null +++ b/tests/integration/test_adapters/test_duckdb/test_storage_bridge.py @@ -0,0 +1,69 @@ +"""Storage bridge integration tests for DuckDB using MinIO.""" + +from typing import TYPE_CHECKING + +import pytest + +from sqlspec.adapters.duckdb import DuckDBDriver +from sqlspec.storage.registry import storage_registry +from sqlspec.typing import FSSPEC_INSTALLED, PYARROW_INSTALLED +from tests.integration.test_adapters._storage_bridge_helpers import register_minio_alias + +if TYPE_CHECKING: # pragma: no cover + from minio import Minio + from pytest_databases.docker.minio import MinioService + +pytestmark = [ + pytest.mark.xdist_group("storage"), + pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed"), + pytest.mark.skipif(not PYARROW_INSTALLED, reason="pyarrow not installed"), +] + + +def test_duckdb_storage_bridge_with_minio( + duckdb_basic_session: DuckDBDriver, + minio_service: "MinioService", + minio_client: "Minio", + minio_default_bucket_name: str, +) -> None: + alias = "storage_bridge_duckdb" + destination_path = "alias://storage_bridge_duckdb/duckdb/export.parquet" + + storage_registry.clear() + try: + prefix = register_minio_alias(alias, minio_service, minio_default_bucket_name) + + duckdb_basic_session.execute("DROP TABLE IF EXISTS storage_bridge_duckdb_source") + duckdb_basic_session.execute("DROP TABLE IF EXISTS storage_bridge_duckdb_target") + duckdb_basic_session.execute( + "CREATE TABLE storage_bridge_duckdb_source (id INTEGER PRIMARY KEY, label TEXT NOT NULL)" + ) + duckdb_basic_session.execute( + "INSERT INTO storage_bridge_duckdb_source VALUES (1, 'alpha'), (2, 'beta'), (3, 'gamma')" + ) + + export_job = duckdb_basic_session.select_to_storage( + "SELECT id, label FROM storage_bridge_duckdb_source ORDER BY id", destination_path, format_hint="parquet" + ) + assert export_job.telemetry["rows_processed"] == 3 + + duckdb_basic_session.execute( + "CREATE TABLE storage_bridge_duckdb_target (id INTEGER PRIMARY KEY, label TEXT NOT NULL)" + ) + load_job = duckdb_basic_session.load_from_storage( + "storage_bridge_duckdb_target", destination_path, file_format="parquet", overwrite=True + ) + assert load_job.telemetry["rows_processed"] == 3 + + result = duckdb_basic_session.execute("SELECT id, label FROM storage_bridge_duckdb_target 
ORDER BY id") + rows = [(row["id"], row["label"]) for row in result.get_data()] + assert rows == [(1, "alpha"), (2, "beta"), (3, "gamma")] + + object_name = f"{prefix}/duckdb/export.parquet" + stat = minio_client.stat_object(minio_default_bucket_name, object_name) + object_size = stat.size if stat.size is not None else 0 + assert object_size > 0 + finally: + storage_registry.clear() + duckdb_basic_session.execute("DROP TABLE IF EXISTS storage_bridge_duckdb_source") + duckdb_basic_session.execute("DROP TABLE IF EXISTS storage_bridge_duckdb_target") diff --git a/tests/integration/test_adapters/test_psqlpy/test_storage_bridge.py b/tests/integration/test_adapters/test_psqlpy/test_storage_bridge.py new file mode 100644 index 00000000..c5c4dba7 --- /dev/null +++ b/tests/integration/test_adapters/test_psqlpy/test_storage_bridge.py @@ -0,0 +1,65 @@ +"""Storage bridge integration tests for PSQLPy driver.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from sqlspec.adapters.psqlpy import PsqlpyDriver +from sqlspec.storage.registry import storage_registry +from sqlspec.typing import FSSPEC_INSTALLED, PYARROW_INSTALLED +from tests.integration.test_adapters._storage_bridge_helpers import register_minio_alias + +if TYPE_CHECKING: # pragma: no cover + from minio import Minio + from pytest_databases.docker.minio import MinioService + +pytestmark = [ + pytest.mark.xdist_group("storage"), + pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed"), + pytest.mark.skipif(not PYARROW_INSTALLED, reason="pyarrow not installed"), +] + + +@pytest.mark.anyio +async def test_psqlpy_storage_bridge_with_minio( + psqlpy_driver: PsqlpyDriver, minio_service: MinioService, minio_client: Minio, minio_default_bucket_name: str +) -> None: + alias = "storage_bridge_psqlpy" + destination = f"alias://{alias}/psqlpy/export.parquet" + source_table = "storage_bridge_psqlpy_source" + target_table = "storage_bridge_psqlpy_target" + + storage_registry.clear() + try: + prefix = register_minio_alias(alias, minio_service, minio_default_bucket_name) + + await psqlpy_driver.execute(f"DROP TABLE IF EXISTS {source_table}") + await psqlpy_driver.execute(f"DROP TABLE IF EXISTS {target_table}") + await psqlpy_driver.execute(f"CREATE TABLE {source_table} (id INT PRIMARY KEY, label TEXT NOT NULL)") + for idx, label in enumerate(["delta", "omega", "zeta"], start=1): + await psqlpy_driver.execute(f"INSERT INTO {source_table} (id, label) VALUES (?, ?)", (idx, label)) + + export_job = await psqlpy_driver.select_to_storage( + f"SELECT id, label FROM {source_table} ORDER BY id", destination, format_hint="parquet" + ) + assert export_job.telemetry["rows_processed"] == 3 + + await psqlpy_driver.execute(f"CREATE TABLE {target_table} (id INT PRIMARY KEY, label TEXT NOT NULL)") + load_job = await psqlpy_driver.load_from_storage( + target_table, destination, file_format="parquet", overwrite=True + ) + assert load_job.telemetry["rows_processed"] == 3 + + rows = await psqlpy_driver.select(f"SELECT id, label FROM {target_table} ORDER BY id") + assert rows == [{"id": 1, "label": "delta"}, {"id": 2, "label": "omega"}, {"id": 3, "label": "zeta"}] + + object_name = f"{prefix}/psqlpy/export.parquet" + stat = minio_client.stat_object(minio_default_bucket_name, object_name) + object_size = stat.size if stat.size is not None else 0 + assert object_size > 0 + finally: + storage_registry.clear() + await psqlpy_driver.execute(f"DROP TABLE IF EXISTS {source_table}") + await psqlpy_driver.execute(f"DROP TABLE 
IF EXISTS {target_table}") diff --git a/tests/integration/test_adapters/test_psycopg/test_storage_bridge.py b/tests/integration/test_adapters/test_psycopg/test_storage_bridge.py new file mode 100644 index 00000000..935f8f7a --- /dev/null +++ b/tests/integration/test_adapters/test_psycopg/test_storage_bridge.py @@ -0,0 +1,153 @@ +"""Storage bridge integration tests for psycopg drivers.""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator, Generator +from typing import TYPE_CHECKING + +import pytest + +from sqlspec.adapters.psycopg import PsycopgAsyncConfig, PsycopgAsyncDriver, PsycopgSyncConfig, PsycopgSyncDriver +from sqlspec.core import SQLResult +from sqlspec.storage.registry import storage_registry +from sqlspec.typing import FSSPEC_INSTALLED, PYARROW_INSTALLED +from tests.integration.test_adapters._storage_bridge_helpers import register_minio_alias + +if TYPE_CHECKING: # pragma: no cover + from minio import Minio + from pytest_databases.docker.minio import MinioService + +pytestmark = [ + pytest.mark.xdist_group("storage"), + pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed"), + pytest.mark.skipif(not PYARROW_INSTALLED, reason="pyarrow not installed"), +] + + +@pytest.fixture +def psycopg_sync_session(psycopg_sync_config: PsycopgSyncConfig) -> Generator[PsycopgSyncDriver, None, None]: + with psycopg_sync_config.provide_session() as session: + yield session + + +@pytest.fixture +async def psycopg_async_session(psycopg_async_config: PsycopgAsyncConfig) -> AsyncGenerator[PsycopgAsyncDriver, None]: + async with psycopg_async_config.provide_session() as session: + yield session + + +def test_psycopg_sync_storage_bridge_with_minio( + psycopg_sync_config: PsycopgSyncConfig, + minio_service: MinioService, + minio_client: Minio, + minio_default_bucket_name: str, +) -> None: + alias = "storage_bridge_psycopg_sync" + destination = f"alias://{alias}/psycopg_sync/export.parquet" + source_table = "storage_bridge_psycopg_sync_source" + target_table = "storage_bridge_psycopg_sync_target" + + storage_registry.clear() + try: + prefix = register_minio_alias(alias, minio_service, minio_default_bucket_name) + + with psycopg_sync_config.provide_session() as session: + session.execute_script(f"DROP TABLE IF EXISTS {source_table} CASCADE") + session.execute_script(f"DROP TABLE IF EXISTS {target_table} CASCADE") + session.execute_script(f"CREATE TABLE {source_table} (id INT PRIMARY KEY, label TEXT NOT NULL)") + session.commit() + session.execute( + f"INSERT INTO {source_table} (id, label) VALUES (%s, %s), (%s, %s), (%s, %s)", + 1, + "alpha", + 2, + "beta", + 3, + "gamma", + ) + session.commit() + + export_job = session.select_to_storage( + f"SELECT id, label FROM {source_table} WHERE label IN ($1, $2, $3) ORDER BY id", + destination, + "alpha", + "beta", + "gamma", + format_hint="parquet", + ) + assert export_job.telemetry["rows_processed"] == 3 + + session.execute_script(f"CREATE TABLE {target_table} (id INT PRIMARY KEY, label TEXT NOT NULL)") + session.commit() + load_job = session.load_from_storage(target_table, destination, file_format="parquet", overwrite=True) + assert load_job.telemetry["rows_processed"] == 3 + + result = session.execute(f"SELECT id, label FROM {target_table} ORDER BY id") + assert isinstance(result, SQLResult) + assert result.data == [{"id": 1, "label": "alpha"}, {"id": 2, "label": "beta"}, {"id": 3, "label": "gamma"}] + + object_name = f"{prefix}/psycopg_sync/export.parquet" + stat = 
minio_client.stat_object(minio_default_bucket_name, object_name) + object_size = stat.size if stat.size is not None else 0 + assert object_size > 0 + finally: + storage_registry.clear() + with psycopg_sync_config.provide_session() as cleanup: + cleanup.execute_script(f"DROP TABLE IF EXISTS {source_table} CASCADE") + cleanup.execute_script(f"DROP TABLE IF EXISTS {target_table} CASCADE") + cleanup.commit() + + +@pytest.mark.anyio +async def test_psycopg_async_storage_bridge_with_minio( + psycopg_async_session: PsycopgAsyncDriver, + minio_service: MinioService, + minio_client: Minio, + minio_default_bucket_name: str, +) -> None: + alias = "storage_bridge_psycopg_async" + destination = f"alias://{alias}/psycopg_async/export.parquet" + source_table = "storage_bridge_psycopg_async_source" + target_table = "storage_bridge_psycopg_async_target" + + storage_registry.clear() + try: + prefix = register_minio_alias(alias, minio_service, minio_default_bucket_name) + + await psycopg_async_session.execute_script(f"DROP TABLE IF EXISTS {source_table} CASCADE") + await psycopg_async_session.execute_script(f"DROP TABLE IF EXISTS {target_table} CASCADE") + await psycopg_async_session.execute_script( + f"CREATE TABLE {source_table} (id INT PRIMARY KEY, label TEXT NOT NULL)" + ) + for idx, label in enumerate(["north", "south", "east"], start=1): + await psycopg_async_session.execute(f"INSERT INTO {source_table} (id, label) VALUES (%s, %s)", idx, label) + + export_job = await psycopg_async_session.select_to_storage( + f"SELECT id, label FROM {source_table} WHERE label IN ($1, $2, $3) ORDER BY id", + destination, + "north", + "south", + "east", + format_hint="parquet", + ) + assert export_job.telemetry["rows_processed"] == 3 + + await psycopg_async_session.execute_script( + f"CREATE TABLE {target_table} (id INT PRIMARY KEY, label TEXT NOT NULL)" + ) + load_job = await psycopg_async_session.load_from_storage( + target_table, destination, file_format="parquet", overwrite=True + ) + assert load_job.telemetry["rows_processed"] == 3 + + rows = await psycopg_async_session.select(f"SELECT id, label FROM {target_table} ORDER BY id") + assert rows == [{"id": 1, "label": "north"}, {"id": 2, "label": "south"}, {"id": 3, "label": "east"}] + + object_name = f"{prefix}/psycopg_async/export.parquet" + stat = minio_client.stat_object(minio_default_bucket_name, object_name) + object_size = stat.size if stat.size is not None else 0 + assert object_size > 0 + finally: + storage_registry.clear() + await psycopg_async_session.execute_script(f"DROP TABLE IF EXISTS {source_table} CASCADE") + await psycopg_async_session.execute_script(f"DROP TABLE IF EXISTS {target_table} CASCADE") diff --git a/tests/integration/test_adapters/test_sqlite/test_storage_bridge.py b/tests/integration/test_adapters/test_sqlite/test_storage_bridge.py new file mode 100644 index 00000000..6c9b708d --- /dev/null +++ b/tests/integration/test_adapters/test_sqlite/test_storage_bridge.py @@ -0,0 +1,47 @@ +"""Storage bridge integration tests for SQLite adapter.""" + +from __future__ import annotations + +from pathlib import Path + +import pyarrow as pa +import pyarrow.parquet as pq +import pytest + +from sqlspec.adapters.sqlite import SqliteDriver + +pytestmark = pytest.mark.xdist_group("sqlite") + + +def test_sqlite_load_from_arrow(sqlite_session: SqliteDriver) -> None: + sqlite_session.execute("DROP TABLE IF EXISTS storage_bridge_sqlite") + sqlite_session.execute("CREATE TABLE storage_bridge_sqlite (id INTEGER PRIMARY KEY, label TEXT)") + + arrow_table = 
pa.table({"id": [1, 2], "label": ["alpha", "beta"]}) + + job = sqlite_session.load_from_arrow("storage_bridge_sqlite", arrow_table, overwrite=True) + + assert job.telemetry["destination"] == "storage_bridge_sqlite" + assert job.telemetry["rows_processed"] == arrow_table.num_rows + + result = sqlite_session.execute("SELECT id, label FROM storage_bridge_sqlite ORDER BY id") + assert result.data == [{"id": 1, "label": "alpha"}, {"id": 2, "label": "beta"}] + + +def test_sqlite_load_from_storage(sqlite_session: SqliteDriver, tmp_path: Path) -> None: + sqlite_session.execute("DROP TABLE IF EXISTS storage_bridge_sqlite") + sqlite_session.execute("CREATE TABLE storage_bridge_sqlite (id INTEGER PRIMARY KEY, label TEXT)") + + arrow_table = pa.table({"id": [10, 11], "label": ["gamma", "delta"]}) + destination = tmp_path / "sqlite-bridge.parquet" + pq.write_table(arrow_table, destination) + + job = sqlite_session.load_from_storage( + "storage_bridge_sqlite", str(destination), file_format="parquet", overwrite=True + ) + + assert job.telemetry["extra"]["source"]["destination"].endswith("sqlite-bridge.parquet") + assert job.telemetry["extra"]["source"]["backend"] + + result = sqlite_session.execute("SELECT id, label FROM storage_bridge_sqlite ORDER BY id") + assert result.data == [{"id": 10, "label": "gamma"}, {"id": 11, "label": "delta"}] diff --git a/tests/unit/test_adapters/test_psycopg/test_config.py b/tests/unit/test_adapters/test_psycopg/test_config.py index 8ee13ec7..188e54bf 100644 --- a/tests/unit/test_adapters/test_psycopg/test_config.py +++ b/tests/unit/test_adapters/test_psycopg/test_config.py @@ -1,7 +1,8 @@ """Psycopg configuration tests covering statement config builders.""" from sqlspec.adapters.psycopg.config import PsycopgAsyncConfig, PsycopgSyncConfig -from sqlspec.adapters.psycopg.driver import build_psycopg_statement_config +from sqlspec.adapters.psycopg.driver import build_psycopg_statement_config, psycopg_statement_config +from sqlspec.core import SQL def test_build_psycopg_statement_config_custom_serializer() -> None: @@ -38,3 +39,20 @@ def serializer(_: object) -> str: parameter_config = config.statement_config.parameter_config assert parameter_config.json_serializer is serializer + + +def test_psycopg_numeric_placeholders_convert_to_pyformat() -> None: + """Numeric placeholders should be rewritten for psycopg execution.""" + + statement = SQL( + "SELECT * FROM bridge_validation WHERE label IN ($1, $2, $3)", + "alpha", + "beta", + "gamma", + statement_config=psycopg_statement_config, + ) + compiled_sql, parameters = statement.compile() + + assert "$1" not in compiled_sql + assert compiled_sql.count("%s") == 3 + assert parameters == ["alpha", "beta", "gamma"] diff --git a/tests/unit/test_config/test_storage_capabilities.py b/tests/unit/test_config/test_storage_capabilities.py new file mode 100644 index 00000000..835e35e8 --- /dev/null +++ b/tests/unit/test_config/test_storage_capabilities.py @@ -0,0 +1,101 @@ +from contextlib import AbstractContextManager, contextmanager +from typing import Any + +from sqlspec.config import NoPoolSyncConfig +from sqlspec.driver import SyncDriverAdapterBase +from sqlspec.driver._sync import SyncDataDictionaryBase + + +class _DummyDriver(SyncDriverAdapterBase): + __slots__ = () + + @property + def data_dictionary(self) -> SyncDataDictionaryBase: # type: ignore[override] + raise NotImplementedError + + def with_cursor(self, connection: Any) -> AbstractContextManager[Any]: # type: ignore[override] + @contextmanager + def _cursor_ctx(): + yield 
object() + + return _cursor_ctx() + + def handle_database_exceptions(self) -> AbstractContextManager[None]: # type: ignore[override] + @contextmanager + def _handler_ctx(): + yield None + + return _handler_ctx() + + def begin(self) -> None: # type: ignore[override] + raise NotImplementedError + + def rollback(self) -> None: # type: ignore[override] + raise NotImplementedError + + def commit(self) -> None: # type: ignore[override] + raise NotImplementedError + + def _try_special_handling(self, cursor: Any, statement: Any): # type: ignore[override] + return None + + def _execute_script(self, cursor: Any, statement: Any): # type: ignore[override] + raise NotImplementedError + + def _execute_many(self, cursor: Any, statement: Any): # type: ignore[override] + raise NotImplementedError + + def _execute_statement(self, cursor: Any, statement: Any): # type: ignore[override] + raise NotImplementedError + + +class _CapabilityConfig(NoPoolSyncConfig[Any, "_DummyDriver"]): + driver_type = _DummyDriver + connection_type = object + supports_native_arrow_export = True + supports_native_arrow_import = True + supports_native_parquet_export = False + supports_native_parquet_import = False + requires_staging_for_load = True + staging_protocols = ("s3://",) + storage_partition_strategies = ("fixed", "rows_per_chunk") + default_storage_profile = "local-temp" + + def create_connection(self) -> object: + return object() + + @contextmanager + def provide_connection(self, *args: Any, **kwargs: Any): # type: ignore[override] + yield object() + + @contextmanager + def provide_session(self, *args: Any, **kwargs: Any): # type: ignore[override] + yield object() + + +def test_storage_capabilities_snapshot(monkeypatch): + monkeypatch.setattr(_CapabilityConfig, "_dependency_available", staticmethod(lambda checker: True)) + config = _CapabilityConfig() + + capabilities = config.storage_capabilities() + assert capabilities["arrow_export_enabled"] is True + assert capabilities["arrow_import_enabled"] is True + assert capabilities["parquet_export_enabled"] is False + assert capabilities["requires_staging_for_load"] is True + assert capabilities["partition_strategies"] == ["fixed", "rows_per_chunk"] + assert capabilities["default_storage_profile"] == "local-temp" + + capabilities["arrow_export_enabled"] = False + assert config.storage_capabilities()["arrow_export_enabled"] is True + + monkeypatch.setattr(_CapabilityConfig, "supports_native_arrow_export", False) + config.reset_storage_capabilities_cache() + assert config.storage_capabilities()["arrow_export_enabled"] is False + + +def test_driver_features_seed_capabilities(monkeypatch): + monkeypatch.setattr(_CapabilityConfig, "_dependency_available", staticmethod(lambda checker: False)) + config = _CapabilityConfig() + assert "storage_capabilities" in config.driver_features + snapshot = config.driver_features["storage_capabilities"] + assert isinstance(snapshot, dict) diff --git a/tests/unit/test_loader/test_fixtures_directory_loading.py b/tests/unit/test_loader/test_fixtures_directory_loading.py index d97d3522..adc5278b 100644 --- a/tests/unit/test_loader/test_fixtures_directory_loading.py +++ b/tests/unit/test_loader/test_fixtures_directory_loading.py @@ -31,6 +31,8 @@ def print(self, *args: Any, **kwargs: Any) -> None: pytestmark = pytest.mark.xdist_group("loader") +MAX_LARGE_QUERY_LOOKUP_SECONDS = 0.75 + @pytest.fixture def fixtures_path() -> Path: @@ -636,6 +638,9 @@ def test_large_query_handling() -> None: loader.get_sql("large_database_analysis") elapsed = 
time.perf_counter() - start_time - assert elapsed < 0.1, f"Large query retrieval too slow: {elapsed:.3f}s for 100 calls" + assert elapsed < MAX_LARGE_QUERY_LOOKUP_SECONDS, ( + f"Large query retrieval too slow: {elapsed:.3f}s for 100 calls " + f"(threshold {MAX_LARGE_QUERY_LOOKUP_SECONDS:.2f}s)" + ) console.print(f"[green]✓[/green] Large query ({len(sql.sql)} chars) handled efficiently") console.print(f" • Performance: {elapsed * 1000:.1f}ms for 100 calls ({elapsed * 10.0:.1f}ms per call)") diff --git a/tests/unit/test_storage_bridge.py b/tests/unit/test_storage_bridge.py new file mode 100644 index 00000000..10e8a38b --- /dev/null +++ b/tests/unit/test_storage_bridge.py @@ -0,0 +1,365 @@ +"""Unit tests for storage bridge ingestion helpers.""" + +import sqlite3 +from pathlib import Path +from typing import Any, cast + +import aiosqlite +import duckdb +import pyarrow as pa +import pytest + +from sqlspec.adapters.aiosqlite import AiosqliteDriver, aiosqlite_statement_config +from sqlspec.adapters.asyncmy import AsyncmyConnection, AsyncmyDriver, asyncmy_statement_config +from sqlspec.adapters.asyncpg import AsyncpgConnection, AsyncpgDriver, asyncpg_statement_config +from sqlspec.adapters.duckdb import DuckDBDriver, duckdb_statement_config +from sqlspec.adapters.psqlpy import PsqlpyConnection, PsqlpyDriver, psqlpy_statement_config +from sqlspec.adapters.sqlite import SqliteDriver, sqlite_statement_config +from sqlspec.storage import SyncStoragePipeline, get_storage_bridge_diagnostics, reset_storage_bridge_metrics +from sqlspec.storage.pipeline import StorageDestination +from sqlspec.storage.registry import storage_registry +from sqlspec.utils.serializers import reset_serializer_cache, serialize_collection + +CAPABILITIES = { + "arrow_export_enabled": True, + "arrow_import_enabled": True, + "parquet_export_enabled": True, + "parquet_import_enabled": True, + "requires_staging_for_load": False, + "staging_protocols": [], + "partition_strategies": ["fixed"], +} + + +class DummyAsyncpgConnection: + def __init__(self) -> None: + self.calls: list[tuple[str, list[tuple[object, ...]], list[str]]] = [] + + async def copy_records_to_table(self, table: str, *, records: list[tuple[object, ...]], columns: list[str]) -> None: + self.calls.append((table, records, columns)) + + +class DummyPsqlpyConnection: + def __init__(self) -> None: + self.copy_calls: list[dict[str, Any]] = [] + self.statements: list[str] = [] + + async def binary_copy_to_table( + self, + source: list[tuple[object, ...]], + table_name: str, + *, + columns: list[str] | None = None, + schema_name: str | None = None, + ) -> None: + self.copy_calls.append({ + "table": table_name, + "schema": schema_name, + "columns": columns or [], + "records": source, + }) + + async def execute(self, sql: str, params: "list[Any] | None" = None) -> None: + _ = params + self.statements.append(sql) + + +class DummyAsyncmyCursorImpl: + def __init__(self, operations: "list[tuple[str, Any, Any | None]]") -> None: + self.operations = operations + + async def executemany(self, sql: str, params: Any) -> None: + self.operations.append(("executemany", sql, params)) + + async def execute(self, sql: str, params: Any | None = None) -> None: + self.operations.append(("execute", sql, params)) + + async def close(self) -> None: + return None + + +class DummyAsyncmyConnection: + def __init__(self) -> None: + self.operations: list[tuple[str, Any, Any | None]] = [] + + def cursor(self) -> DummyAsyncmyCursorImpl: + return DummyAsyncmyCursorImpl(self.operations) + + 
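+# A minimal sketch, assuming the private payload helpers keep the signatures
+# defined in sqlspec/storage/pipeline.py: a jsonl payload produced by the row
+# encoder should decode back into an Arrow table with the same row content.
+def test_jsonl_payload_round_trips_through_arrow_decoder() -> None:
+    from sqlspec.storage.pipeline import _decode_arrow_payload, _encode_row_payload
+
+    payload = _encode_row_payload([{"id": 1, "label": "alpha"}], "jsonl")
+    table = _decode_arrow_payload(payload, "jsonl")
+
+    assert table.num_rows == 1
+    assert table.to_pylist()[0] == {"id": 1, "label": "alpha"}
+
+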
+@pytest.mark.asyncio +async def test_asyncpg_load_from_storage(monkeypatch: pytest.MonkeyPatch) -> None: + arrow_table = pa.table({"id": [1, 2], "name": ["alpha", "beta"]}) + + async def _fake_read(self, *_: object, **__: object) -> tuple[pa.Table, dict[str, object]]: + return arrow_table, {"destination": "file://tmp/part-0.parquet", "bytes_processed": 128} + + driver = AsyncpgDriver( + connection=cast(AsyncpgConnection, DummyAsyncpgConnection()), + statement_config=asyncpg_statement_config, + driver_features={"storage_capabilities": CAPABILITIES}, + ) + monkeypatch.setattr(AsyncpgDriver, "_read_arrow_from_storage_async", _fake_read) + + job = await driver.load_from_storage("public.ingest_target", "file://tmp/part-0.parquet", file_format="parquet") + + assert driver.connection.calls[0][0] == "public.ingest_target" + assert driver.connection.calls[0][2] == ["id", "name"] + assert job.telemetry["rows_processed"] == arrow_table.num_rows + assert job.telemetry["destination"] == "public.ingest_target" + + +def test_duckdb_load_from_storage(monkeypatch: pytest.MonkeyPatch) -> None: + arrow_table = pa.table({"id": [10, 11], "label": ["east", "west"]}) + + def _fake_read(self, *_: object, **__: object) -> tuple[pa.Table, dict[str, object]]: + return arrow_table, {"destination": "file://tmp/part-1.parquet", "bytes_processed": 256} + + connection = duckdb.connect(database=":memory:") + connection.execute("CREATE TABLE ingest_target (id INTEGER, label TEXT)") + + driver = DuckDBDriver( + connection=connection, + statement_config=duckdb_statement_config, + driver_features={"storage_capabilities": CAPABILITIES}, + ) + + monkeypatch.setattr(DuckDBDriver, "_read_arrow_from_storage_sync", _fake_read) + + job = driver.load_from_storage("ingest_target", "file://tmp/part-1.parquet", file_format="parquet", overwrite=True) + + rows = connection.execute("SELECT id, label FROM ingest_target ORDER BY id").fetchall() + assert rows == [(10, "east"), (11, "west")] + assert job.telemetry["rows_processed"] == arrow_table.num_rows + assert job.telemetry["destination"] == "ingest_target" + + +@pytest.mark.asyncio +async def test_psqlpy_load_from_arrow_overwrite() -> None: + arrow_table = pa.table({"id": [7, 8], "name": ["east", "west"]}) + dummy_connection = DummyPsqlpyConnection() + driver = PsqlpyDriver( + connection=cast(PsqlpyConnection, dummy_connection), + statement_config=psqlpy_statement_config, + driver_features={"storage_capabilities": CAPABILITIES}, + ) + + job = await driver.load_from_arrow("analytics.ingest_target", arrow_table, overwrite=True) + + assert dummy_connection.statements == ['TRUNCATE TABLE "analytics"."ingest_target"'] + assert dummy_connection.copy_calls[0]["table"] == "ingest_target" + assert dummy_connection.copy_calls[0]["schema"] == "analytics" + payload = dummy_connection.copy_calls[0]["records"] + if isinstance(payload, bytes): + assert payload == b"7\teast\n8\twest\n" + else: + assert payload == [(7, "east"), (8, "west")] + assert job.telemetry["destination"] == "analytics.ingest_target" + assert job.telemetry["rows_processed"] == arrow_table.num_rows + + +@pytest.mark.asyncio +async def test_psqlpy_load_from_storage_merges_telemetry(monkeypatch: pytest.MonkeyPatch) -> None: + arrow_table = pa.table({"id": [1, 2], "name": ["north", "south"]}) + dummy_connection = DummyPsqlpyConnection() + driver = PsqlpyDriver( + connection=cast(PsqlpyConnection, dummy_connection), + statement_config=psqlpy_statement_config, + driver_features={"storage_capabilities": CAPABILITIES}, + ) + + async def 
_fake_read(self, *_: object, **__: object) -> tuple[pa.Table, dict[str, object]]: + return arrow_table, {"destination": "s3://bucket/part-2.parquet", "bytes_processed": 512} + + monkeypatch.setattr(PsqlpyDriver, "_read_arrow_from_storage_async", _fake_read) + + job = await driver.load_from_storage("public.delta_load", "s3://bucket/part-2.parquet", file_format="parquet") + + assert dummy_connection.copy_calls[0]["table"] == "delta_load" + assert dummy_connection.copy_calls[0]["columns"] == ["id", "name"] + assert job.telemetry["destination"] == "public.delta_load" + assert job.telemetry["extra"]["source"]["destination"] == "s3://bucket/part-2.parquet" + + +@pytest.mark.asyncio +async def test_aiosqlite_load_from_arrow_overwrite() -> None: + connection = await aiosqlite.connect(":memory:") + try: + await connection.execute("CREATE TABLE ingest (id INTEGER, name TEXT)") + await connection.execute("INSERT INTO ingest (id, name) VALUES (99, 'stale')") + await connection.commit() + + driver = AiosqliteDriver( + connection=connection, + statement_config=aiosqlite_statement_config, + driver_features={"storage_capabilities": CAPABILITIES}, + ) + arrow_table = pa.table({"id": [1, 2], "name": ["alpha", "beta"]}) + + job = await driver.load_from_arrow("ingest", arrow_table, overwrite=True) + + async with connection.execute("SELECT id, name FROM ingest ORDER BY id") as cursor: + rows = await cursor.fetchall() + assert rows == [(1, "alpha"), (2, "beta")] # type: ignore[comparison-overlap] + assert job.telemetry["destination"] == "ingest" + assert job.telemetry["rows_processed"] == arrow_table.num_rows + finally: + await connection.close() + + +@pytest.mark.asyncio +async def test_aiosqlite_load_from_storage_includes_source(monkeypatch: pytest.MonkeyPatch) -> None: + connection = await aiosqlite.connect(":memory:") + try: + await connection.execute("CREATE TABLE raw_data (id INTEGER, label TEXT)") + await connection.commit() + + driver = AiosqliteDriver( + connection=connection, + statement_config=aiosqlite_statement_config, + driver_features={"storage_capabilities": CAPABILITIES}, + ) + arrow_table = pa.table({"id": [5], "label": ["gamma"]}) + + async def _fake_read(self, *_: object, **__: object) -> tuple[pa.Table, dict[str, object]]: + return arrow_table, {"destination": "file:///tmp/chunk.parquet", "bytes_processed": 64} + + monkeypatch.setattr(AiosqliteDriver, "_read_arrow_from_storage_async", _fake_read) + + job = await driver.load_from_storage("raw_data", "file:///tmp/chunk.parquet", file_format="parquet") + + async with connection.execute("SELECT id, label FROM raw_data") as cursor: + rows = await cursor.fetchall() + assert rows == [(5, "gamma")] # type: ignore[comparison-overlap] + assert job.telemetry["extra"]["source"]["destination"] == "file:///tmp/chunk.parquet" + finally: + await connection.close() + + +def test_sqlite_load_from_arrow_overwrite() -> None: + connection = sqlite3.connect(":memory:") + try: + connection.execute("CREATE TABLE staging (id INTEGER, description TEXT)") + connection.execute("INSERT INTO staging (id, description) VALUES (42, 'legacy')") + + driver = SqliteDriver( + connection=connection, + statement_config=sqlite_statement_config, + driver_features={"storage_capabilities": CAPABILITIES}, + ) + arrow_table = pa.table({"id": [10, 11], "description": ["north", "south"]}) + + job = driver.load_from_arrow("staging", arrow_table, overwrite=True) + + rows = connection.execute("SELECT id, description FROM staging ORDER BY id").fetchall() + normalized_rows = [tuple(row) for 
row in rows] + assert normalized_rows == [(10, "north"), (11, "south")] + assert job.telemetry["rows_processed"] == arrow_table.num_rows + finally: + connection.close() + + +def test_sqlite_load_from_storage_merges_source(monkeypatch: pytest.MonkeyPatch) -> None: + connection = sqlite3.connect(":memory:") + try: + connection.execute("CREATE TABLE metrics (val INTEGER)") + + driver = SqliteDriver( + connection=connection, + statement_config=sqlite_statement_config, + driver_features={"storage_capabilities": CAPABILITIES}, + ) + arrow_table = pa.table({"val": [99]}) + + def _fake_read(self, *_: object, **__: object) -> tuple[pa.Table, dict[str, object]]: + return arrow_table, {"destination": "s3://bucket/segment.parquet", "bytes_processed": 32} + + monkeypatch.setattr(SqliteDriver, "_read_arrow_from_storage_sync", _fake_read) + + job = driver.load_from_storage("metrics", "s3://bucket/segment.parquet", file_format="parquet") + + rows = connection.execute("SELECT val FROM metrics").fetchall() + normalized_rows = [tuple(row) for row in rows] + assert normalized_rows == [(99,)] + assert job.telemetry["extra"]["source"]["destination"] == "s3://bucket/segment.parquet" + finally: + connection.close() + + +@pytest.mark.asyncio +async def test_asyncmy_load_from_arrow_overwrite() -> None: + connection = DummyAsyncmyConnection() + driver = AsyncmyDriver( + connection=cast(AsyncmyConnection, connection), + statement_config=asyncmy_statement_config, + driver_features={"storage_capabilities": CAPABILITIES}, + ) + arrow_table = pa.table({"id": [3], "score": [9.5]}) + + job = await driver.load_from_arrow("analytics.scores", arrow_table, overwrite=True) + + assert connection.operations[0][1].startswith("TRUNCATE TABLE `analytics`.`scores`") + assert connection.operations[1][0] == "executemany" + assert job.telemetry["destination"] == "analytics.scores" + + +@pytest.mark.asyncio +async def test_asyncmy_load_from_storage_merges_source(monkeypatch: pytest.MonkeyPatch) -> None: + connection = DummyAsyncmyConnection() + driver = AsyncmyDriver( + connection=cast(AsyncmyConnection, connection), + statement_config=asyncmy_statement_config, + driver_features={"storage_capabilities": CAPABILITIES}, + ) + arrow_table = pa.table({"id": [11], "score": [8.2]}) + + async def _fake_read(self, *_: object, **__: object) -> tuple[pa.Table, dict[str, object]]: + return arrow_table, {"destination": "s3://bucket/segment.parquet", "bytes_processed": 48, "backend": "fsspec"} + + monkeypatch.setattr(AsyncmyDriver, "_read_arrow_from_storage_async", _fake_read) + + job = await driver.load_from_storage("analytics.scores", "s3://bucket/segment.parquet", file_format="parquet") + + assert job.telemetry["extra"]["source"]["backend"] == "fsspec" + + +def test_sync_pipeline_write_rows_includes_backend(monkeypatch: pytest.MonkeyPatch) -> None: + pipeline = SyncStoragePipeline() + + class _Backend: + backend_type = "test-backend" + + def __init__(self) -> None: + self.payloads: list[tuple[str, bytes]] = [] + + def write_bytes(self, path: str, payload: bytes) -> None: + self.payloads.append((path, payload)) + + backend = _Backend() + + def _fake_resolve(self: SyncStoragePipeline, destination: "StorageDestination", **_: Any) -> tuple[_Backend, str]: + return backend, "objects/data.jsonl" + + monkeypatch.setattr(SyncStoragePipeline, "_resolve_backend", _fake_resolve) + + telemetry = pipeline.write_rows([{"id": 1}], "alias://data") + assert telemetry["backend"] == "test-backend" + + +def test_sync_pipeline_supports_alias_destinations(tmp_path: 
"Path") -> None: + storage_registry.clear() + alias_name = "storage_bridge_unit_tests" + storage_registry.register_alias(alias_name, f"file://{tmp_path}", backend="local") + pipeline = SyncStoragePipeline() + + telemetry = pipeline.write_rows([{"id": 1}], f"alias://{alias_name}/payload.jsonl") + + assert telemetry["destination"].endswith("payload.jsonl") + storage_registry.clear() + + +def test_storage_bridge_diagnostics_include_serializer_metrics() -> None: + reset_storage_bridge_metrics() + reset_serializer_cache() + serialize_collection([{"id": 1}]) + diagnostics = get_storage_bridge_diagnostics() + assert "serializer.size" in diagnostics