diff --git a/bigframes/bigquery/_operations/io.py b/bigframes/bigquery/_operations/io.py index daf28e6aed..6effbdb257 100644 --- a/bigframes/bigquery/_operations/io.py +++ b/bigframes/bigquery/_operations/io.py @@ -19,8 +19,8 @@ import pandas as pd from bigframes.bigquery._operations.table import _get_table_metadata +import bigframes.core.compile.sqlglot.sql as sql import bigframes.core.logging.log_adapter as log_adapter -import bigframes.core.sql.io import bigframes.session @@ -73,7 +73,7 @@ def load_data( """ import bigframes.pandas as bpd - sql = bigframes.core.sql.io.load_data_ddl( + load_data_expr = sql.load_data( table_name=table_name, write_disposition=write_disposition, columns=columns, @@ -84,11 +84,12 @@ def load_data( with_partition_columns=with_partition_columns, connection_name=connection_name, ) + sql_text = sql.to_sql(load_data_expr) if session is None: - bpd.read_gbq_query(sql) + bpd.read_gbq_query(sql_text) session = bpd.get_global_session() else: - session.read_gbq_query(sql) + session.read_gbq_query(sql_text) return _get_table_metadata(bqclient=session.bqclient, table_name=table_name) diff --git a/bigframes/core/compile/sqlglot/sql/__init__.py b/bigframes/core/compile/sqlglot/sql/__init__.py index 6d2dbd65a6..17c78ba379 100644 --- a/bigframes/core/compile/sqlglot/sql/__init__.py +++ b/bigframes/core/compile/sqlglot/sql/__init__.py @@ -22,6 +22,7 @@ table, to_sql, ) +from bigframes.core.compile.sqlglot.sql.ddl import load_data from bigframes.core.compile.sqlglot.sql.dml import insert, replace __all__ = [ @@ -33,6 +34,8 @@ "literal", "table", "to_sql", + # From ddl.py + "load_data", # From dml.py "insert", "replace", diff --git a/bigframes/core/compile/sqlglot/sql/ddl.py b/bigframes/core/compile/sqlglot/sql/ddl.py new file mode 100644 index 0000000000..911c63781b --- /dev/null +++ b/bigframes/core/compile/sqlglot/sql/ddl.py @@ -0,0 +1,164 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you 
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""SQLGlot expression builders and BigQuery rendering for DDL statements.

Currently covers the ``LOAD DATA`` statement. The module registers a custom
generator transform on the vendored BigQuery dialect at import time.
"""

from __future__ import annotations

from typing import Mapping, Optional, Union

import bigframes_vendored.sqlglot as sg
import bigframes_vendored.sqlglot.expressions as sge

from bigframes.core.compile.sqlglot.sql import base


def _loaddata_sql(self: sg.Generator, expression: sge.LoadData) -> str:
    """Render an ``sge.LoadData`` node as a BigQuery ``LOAD DATA`` statement.

    BigQuery's grammar is ``LOAD DATA {INTO | OVERWRITE} table ...`` — the two
    keywords are mutually exclusive alternatives, never combined.
    """
    out = ["LOAD DATA"]

    # INTO and OVERWRITE are alternatives in BigQuery; emitting
    # "OVERWRITE INTO" would be a syntax error, so pick exactly one.
    out.append("OVERWRITE" if expression.args.get("overwrite") else "INTO")
    out.append(self.sql(expression, "this").strip())

    # "inpath" is only a dummy to satisfy sqlglot's required LoadData arg;
    # BigQuery uses FROM FILES instead, so it is intentionally not rendered.

    columns = self.sql(expression, "columns").strip()
    if columns:
        out.append(columns)

    partition_by = self.sql(expression, "partition_by").strip()
    if partition_by:
        out.append(partition_by)

    cluster_by = self.sql(expression, "cluster_by").strip()
    if cluster_by:
        out.append(cluster_by)

    options = self.sql(expression, "options").strip()
    if options:
        out.append(options)

    from_files = self.sql(expression, "from_files").strip()
    if from_files:
        out.append(f"FROM FILES {from_files}")

    with_partition_columns = self.sql(expression, "with_partition_columns").strip()
    if with_partition_columns:
        out.append(f"WITH PARTITION COLUMNS {with_partition_columns}")

    connection = self.sql(expression, "connection").strip()
    if connection:
        out.append(f"WITH CONNECTION {connection}")

    return " ".join(out)


# Register the transform for the BigQuery generator at import time so any
# to_sql() call in this package renders LoadData with the function above.
sg.dialects.bigquery.BigQuery.Generator.TRANSFORMS[sge.LoadData] = _loaddata_sql


def load_data(
    table_name: str,
    *,
    write_disposition: str = "INTO",
    columns: Optional[Mapping[str, str]] = None,
    partition_by: Optional[list[str]] = None,
    cluster_by: Optional[list[str]] = None,
    table_options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
    from_files_options: Mapping[str, Union[str, int, float, bool, list]],
    with_partition_columns: Optional[Mapping[str, str]] = None,
    connection_name: Optional[str] = None,
) -> sge.LoadData:
    """Build a ``LOAD DATA`` DDL statement as a sqlglot expression.

    Args:
        table_name: Destination table, e.g. ``"proj.dataset.table"``.
        write_disposition: ``"INTO"`` (append/create) or ``"OVERWRITE"``.
        columns: Optional mapping of column name to BigQuery type string.
        partition_by: Optional list of partitioning column expressions.
        cluster_by: Optional list of clustering columns.
        table_options: Optional table ``OPTIONS(...)`` key/value pairs.
        from_files_options: Required ``FROM FILES(...)`` key/value pairs
            (e.g. ``format``, ``uris``).
        with_partition_columns: Optional hive-partition column name/type
            mapping for ``WITH PARTITION COLUMNS (...)``.
        connection_name: Optional connection for ``WITH CONNECTION``.

    Returns:
        An ``sge.LoadData`` node; render it with this package's ``to_sql``.
    """
    # A Table wrapping a plain identifier; quoting is applied by the dialect.
    table_expr = sge.Table(this=base.identifier(table_name))

    sge_columns = (
        sge.Schema(
            this=None,
            expressions=[
                sge.ColumnDef(
                    this=base.identifier(name),
                    kind=sge.DataType.build(typ, dialect="bigquery"),
                )
                for name, typ in columns.items()
            ],
        )
        if columns
        else None
    )

    # Single-column partitioning renders bare; multiple columns as a tuple.
    sge_partition_by = (
        sge.PartitionedByProperty(
            this=base.identifier(partition_by[0])
            if len(partition_by) == 1
            else sge.Tuple(expressions=[base.identifier(col) for col in partition_by])
        )
        if partition_by
        else None
    )

    sge_cluster_by = (
        sge.Cluster(expressions=[base.identifier(col) for col in cluster_by])
        if cluster_by
        else None
    )

    sge_table_options = (
        sge.Properties(
            expressions=[
                sge.Property(this=base.identifier(k), value=base.literal(v))
                for k, v in table_options.items()
            ]
        )
        if table_options
        else None
    )

    sge_from_files = sge.Tuple(
        expressions=[
            sge.Property(this=base.identifier(k), value=base.literal(v))
            for k, v in from_files_options.items()
        ]
    )

    sge_with_partition_columns = (
        sge.Schema(
            this=None,
            expressions=[
                sge.ColumnDef(
                    this=base.identifier(name),
                    kind=sge.DataType.build(typ, dialect="bigquery"),
                )
                for name, typ in with_partition_columns.items()
            ],
        )
        if with_partition_columns
        else None
    )

    sge_connection = base.identifier(connection_name) if connection_name else None

    return sge.LoadData(
        this=table_expr,
        overwrite=(write_disposition == "OVERWRITE"),
        inpath=sge.convert("fake"),  # satisfy sqlglot's required inpath arg
        columns=sge_columns,
        partition_by=sge_partition_by,
        cluster_by=sge_cluster_by,
        options=sge_table_options,
        from_files=sge_from_files,
        with_partition_columns=sge_with_partition_columns,
        connection=sge_connection,
    )
Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import Mapping, Optional, Union - - -def load_data_ddl( - table_name: str, - *, - write_disposition: str = "INTO", - columns: Optional[Mapping[str, str]] = None, - partition_by: Optional[list[str]] = None, - cluster_by: Optional[list[str]] = None, - table_options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None, - from_files_options: Mapping[str, Union[str, int, float, bool, list]], - with_partition_columns: Optional[Mapping[str, str]] = None, - connection_name: Optional[str] = None, -) -> str: - """Generates the LOAD DATA DDL statement.""" - statement = ["LOAD DATA"] - statement.append(write_disposition) - statement.append(table_name) - - if columns: - column_defs = ", ".join([f"{name} {typ}" for name, typ in columns.items()]) - statement.append(f"({column_defs})") - - if partition_by: - statement.append(f"PARTITION BY {', '.join(partition_by)}") - - if cluster_by: - statement.append(f"CLUSTER BY {', '.join(cluster_by)}") - - if table_options: - opts = [] - for key, value in table_options.items(): - if isinstance(value, str): - value_sql = repr(value) - opts.append(f"{key} = {value_sql}") - elif isinstance(value, bool): - opts.append(f"{key} = {str(value).upper()}") - elif isinstance(value, list): - list_str = ", ".join([repr(v) for v in value]) - opts.append(f"{key} = [{list_str}]") - else: - opts.append(f"{key} = {value}") - options_str = ", ".join(opts) 
- statement.append(f"OPTIONS ({options_str})") - - opts = [] - for key, value in from_files_options.items(): - if isinstance(value, str): - value_sql = repr(value) - opts.append(f"{key} = {value_sql}") - elif isinstance(value, bool): - opts.append(f"{key} = {str(value).upper()}") - elif isinstance(value, list): - list_str = ", ".join([repr(v) for v in value]) - opts.append(f"{key} = [{list_str}]") - else: - opts.append(f"{key} = {value}") - options_str = ", ".join(opts) - statement.append(f"FROM FILES ({options_str})") - - if with_partition_columns: - part_defs = ", ".join( - [f"{name} {typ}" for name, typ in with_partition_columns.items()] - ) - statement.append(f"WITH PARTITION COLUMNS ({part_defs})") - - if connection_name: - statement.append(f"WITH CONNECTION `{connection_name}`") - - return " ".join(statement) diff --git a/tests/unit/bigquery/_operations/test_io.py b/tests/unit/bigquery/_operations/test_io.py index 97b38f8649..b5dc9544aa 100644 --- a/tests/unit/bigquery/_operations/test_io.py +++ b/tests/unit/bigquery/_operations/test_io.py @@ -17,7 +17,6 @@ import pytest import bigframes.bigquery._operations.io -import bigframes.core.sql.io import bigframes.session @@ -36,6 +35,6 @@ def test_load_data(get_table_metadata_mock, mock_session): ) mock_session.read_gbq_query.assert_called_once() generated_sql = mock_session.read_gbq_query.call_args[0][0] - expected = "LOAD DATA INTO my-project.my_dataset.my_table (col1 INT64, col2 STRING) FROM FILES (format = 'CSV', uris = ['gs://bucket/path*'])" + expected = "LOAD DATA INTO `my-project.my_dataset.my_table` (\n `col1` INT64,\n `col2` STRING\n) FROM FILES (format='CSV', uris=['gs://bucket/path*'])" assert generated_sql == expected get_table_metadata_mock.assert_called_once() diff --git a/tests/unit/core/compile/sqlglot/sql/snapshots/test_ddl/test_load_data_all_options/out.sql b/tests/unit/core/compile/sqlglot/sql/snapshots/test_ddl/test_load_data_all_options/out.sql new file mode 100644 index 0000000000..781019a068 
--- /dev/null +++ b/tests/unit/core/compile/sqlglot/sql/snapshots/test_ddl/test_load_data_all_options/out.sql @@ -0,0 +1,10 @@ +LOAD DATA OVERWRITE `my-project.my_dataset.my_table` ( + `col1` INT64, + `col2` STRING +) PARTITION BY `date_col` CLUSTER BY + `cluster_col` OPTIONS ( + description='my table' +) FROM FILES (format='CSV', uris=['gs://bucket/path*']) WITH PARTITION COLUMNS ( + `part1` DATE, + `part2` STRING +) WITH CONNECTION `my-connection` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/sql/snapshots/test_ddl/test_load_data_minimal/out.sql b/tests/unit/core/compile/sqlglot/sql/snapshots/test_ddl/test_load_data_minimal/out.sql new file mode 100644 index 0000000000..c5f6600325 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/sql/snapshots/test_ddl/test_load_data_minimal/out.sql @@ -0,0 +1 @@ +LOAD DATA INTO `my-project.my_dataset.my_table` FROM FILES (format='CSV', uris=['gs://bucket/path*']) \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_dml/test_insert_from_select/out.sql b/tests/unit/core/compile/sqlglot/sql/snapshots/test_dml/test_insert_from_select/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/snapshots/test_dml/test_insert_from_select/out.sql rename to tests/unit/core/compile/sqlglot/sql/snapshots/test_dml/test_insert_from_select/out.sql diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_dml/test_insert_from_table/out.sql b/tests/unit/core/compile/sqlglot/sql/snapshots/test_dml/test_insert_from_table/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/snapshots/test_dml/test_insert_from_table/out.sql rename to tests/unit/core/compile/sqlglot/sql/snapshots/test_dml/test_insert_from_table/out.sql diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_dml/test_replace_from_select/out.sql b/tests/unit/core/compile/sqlglot/sql/snapshots/test_dml/test_replace_from_select/out.sql similarity index 100% rename from 
tests/unit/core/compile/sqlglot/snapshots/test_dml/test_replace_from_select/out.sql rename to tests/unit/core/compile/sqlglot/sql/snapshots/test_dml/test_replace_from_select/out.sql diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_dml/test_replace_from_table/out.sql b/tests/unit/core/compile/sqlglot/sql/snapshots/test_dml/test_replace_from_table/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/snapshots/test_dml/test_replace_from_table/out.sql rename to tests/unit/core/compile/sqlglot/sql/snapshots/test_dml/test_replace_from_table/out.sql diff --git a/tests/unit/core/compile/sqlglot/sql/test_ddl.py b/tests/unit/core/compile/sqlglot/sql/test_ddl.py new file mode 100644 index 0000000000..14d3708883 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/sql/test_ddl.py @@ -0,0 +1,42 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +import bigframes.core.compile.sqlglot.sql as sql + +pytest.importorskip("pytest_snapshot") + + +def test_load_data_minimal(snapshot): + expr = sql.load_data( + "my-project.my_dataset.my_table", + from_files_options={"format": "CSV", "uris": ["gs://bucket/path*"]}, + ) + snapshot.assert_match(sql.to_sql(expr), "out.sql") + + +def test_load_data_all_options(snapshot): + expr = sql.load_data( + "my-project.my_dataset.my_table", + write_disposition="OVERWRITE", + columns={"col1": "INT64", "col2": "STRING"}, + partition_by=["date_col"], + cluster_by=["cluster_col"], + table_options={"description": "my table"}, + from_files_options={"format": "CSV", "uris": ["gs://bucket/path*"]}, + with_partition_columns={"part1": "DATE", "part2": "STRING"}, + connection_name="my-connection", + ) + snapshot.assert_match(sql.to_sql(expr), "out.sql") diff --git a/tests/unit/core/compile/sqlglot/test_dml.py b/tests/unit/core/compile/sqlglot/sql/test_dml.py similarity index 100% rename from tests/unit/core/compile/sqlglot/test_dml.py rename to tests/unit/core/compile/sqlglot/sql/test_dml.py diff --git a/tests/unit/core/sql/test_io.py b/tests/unit/core/sql/test_io.py deleted file mode 100644 index 23e5f796e3..0000000000 --- a/tests/unit/core/sql/test_io.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import bigframes.core.sql.io - - -def test_load_data_ddl(): - sql = bigframes.core.sql.io.load_data_ddl( - "my-project.my_dataset.my_table", - columns={"col1": "INT64", "col2": "STRING"}, - from_files_options={"format": "CSV", "uris": ["gs://bucket/path*"]}, - ) - expected = "LOAD DATA INTO my-project.my_dataset.my_table (col1 INT64, col2 STRING) FROM FILES (format = 'CSV', uris = ['gs://bucket/path*'])" - assert sql == expected - - -def test_load_data_ddl_overwrite(): - sql = bigframes.core.sql.io.load_data_ddl( - "my-project.my_dataset.my_table", - write_disposition="OVERWRITE", - columns={"col1": "INT64", "col2": "STRING"}, - from_files_options={"format": "CSV", "uris": ["gs://bucket/path*"]}, - ) - expected = "LOAD DATA OVERWRITE my-project.my_dataset.my_table (col1 INT64, col2 STRING) FROM FILES (format = 'CSV', uris = ['gs://bucket/path*'])" - assert sql == expected - - -def test_load_data_ddl_with_partition_columns(): - sql = bigframes.core.sql.io.load_data_ddl( - "my-project.my_dataset.my_table", - columns={"col1": "INT64", "col2": "STRING"}, - with_partition_columns={"part1": "DATE", "part2": "STRING"}, - from_files_options={"format": "CSV", "uris": ["gs://bucket/path*"]}, - ) - expected = "LOAD DATA INTO my-project.my_dataset.my_table (col1 INT64, col2 STRING) FROM FILES (format = 'CSV', uris = ['gs://bucket/path*']) WITH PARTITION COLUMNS (part1 DATE, part2 STRING)" - assert sql == expected - - -def test_load_data_ddl_connection(): - sql = bigframes.core.sql.io.load_data_ddl( - "my-project.my_dataset.my_table", - columns={"col1": "INT64", "col2": "STRING"}, - connection_name="my-connection", - from_files_options={"format": "CSV", "uris": ["gs://bucket/path*"]}, - ) - expected = "LOAD DATA INTO my-project.my_dataset.my_table (col1 INT64, col2 STRING) FROM FILES (format = 'CSV', uris = ['gs://bucket/path*']) WITH CONNECTION `my-connection`" - assert sql == expected - - -def test_load_data_ddl_partition_by(): - sql = bigframes.core.sql.io.load_data_ddl( 
- "my-project.my_dataset.my_table", - columns={"col1": "INT64", "col2": "STRING"}, - partition_by=["date_col"], - from_files_options={"format": "CSV", "uris": ["gs://bucket/path*"]}, - ) - expected = "LOAD DATA INTO my-project.my_dataset.my_table (col1 INT64, col2 STRING) PARTITION BY date_col FROM FILES (format = 'CSV', uris = ['gs://bucket/path*'])" - assert sql == expected - - -def test_load_data_ddl_cluster_by(): - sql = bigframes.core.sql.io.load_data_ddl( - "my-project.my_dataset.my_table", - columns={"col1": "INT64", "col2": "STRING"}, - cluster_by=["cluster_col"], - from_files_options={"format": "CSV", "uris": ["gs://bucket/path*"]}, - ) - expected = "LOAD DATA INTO my-project.my_dataset.my_table (col1 INT64, col2 STRING) CLUSTER BY cluster_col FROM FILES (format = 'CSV', uris = ['gs://bucket/path*'])" - assert sql == expected - - -def test_load_data_ddl_table_options(): - sql = bigframes.core.sql.io.load_data_ddl( - "my-project.my_dataset.my_table", - columns={"col1": "INT64", "col2": "STRING"}, - table_options={"description": "my table"}, - from_files_options={"format": "CSV", "uris": ["gs://bucket/path*"]}, - ) - expected = "LOAD DATA INTO my-project.my_dataset.my_table (col1 INT64, col2 STRING) OPTIONS (description = 'my table') FROM FILES (format = 'CSV', uris = ['gs://bucket/path*'])" - assert sql == expected diff --git a/third_party/bigframes_vendored/sqlglot/dialects/dialect.py b/third_party/bigframes_vendored/sqlglot/dialects/dialect.py index 8dbb5c3f1c..3341f3fa57 100644 --- a/third_party/bigframes_vendored/sqlglot/dialects/dialect.py +++ b/third_party/bigframes_vendored/sqlglot/dialects/dialect.py @@ -166,7 +166,7 @@ def _try_load(cls, key: str | Dialects) -> None: # files. Custom user dialects need to be imported at the top-level package, in # order for them to be registered as soon as possible. 
if key in DIALECT_MODULE_NAMES: - importlib.import_module(f"sqlglot.dialects.{key}") + importlib.import_module(f"bigframes_vendored.sqlglot.dialects.{key}") @classmethod def __getitem__(cls, key: str) -> t.Type[Dialect]: