diff --git a/bigframes/core/tree_properties.py b/bigframes/core/tree_properties.py index baf4b12566..5f713450f7 100644 --- a/bigframes/core/tree_properties.py +++ b/bigframes/core/tree_properties.py @@ -15,10 +15,13 @@ import functools import itertools -from typing import Callable, Dict, Optional, Sequence +from typing import Callable, Dict, Optional, Sequence, TYPE_CHECKING import bigframes.core.nodes as nodes +if TYPE_CHECKING: + import bigframes.session.execution_cache as execution_cache + def is_trivially_executable(node: nodes.BigFrameNode) -> bool: if local_only(node): @@ -65,7 +68,7 @@ def select_cache_target( root: nodes.BigFrameNode, min_complexity: float, max_complexity: float, - cache: dict[nodes.BigFrameNode, nodes.BigFrameNode], + cache: execution_cache.ExecutionCache, heuristic: Callable[[int, int], float], ) -> Optional[nodes.BigFrameNode]: """Take tree, and return candidate nodes with (# of occurences, post-caching planning complexity). @@ -75,7 +78,7 @@ def select_cache_target( @functools.cache def _with_caching(subtree: nodes.BigFrameNode) -> nodes.BigFrameNode: - return nodes.top_down(subtree, lambda x: cache.get(x, x)) + return cache.subsitute_cached_subplans(subtree) def _combine_counts( left: Dict[nodes.BigFrameNode, int], right: Dict[nodes.BigFrameNode, int] @@ -106,6 +109,7 @@ def _node_counts_inner( if len(node_counts) == 0: raise ValueError("node counts should be non-zero") + # for each considered node, calculate heuristic value, and return node with max value return max( node_counts.keys(), key=lambda node: heuristic( diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index 7ea6e99954..0a2f2db189 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -265,15 +265,20 @@ def __init__( metrics=self._metrics, publisher=self._publisher, ) + + labels = {} + if not self._strictly_ordered: + labels["bigframes-mode"] = "unordered" + self._executor: executor.Executor = 
bq_caching_executor.BigQueryCachingExecutor( bqclient=self._clients_provider.bqclient, bqstoragereadclient=self._clients_provider.bqstoragereadclient, loader=self._loader, storage_manager=self._temp_storage_manager, - strictly_ordered=self._strictly_ordered, metrics=self._metrics, enable_polars_execution=context.enable_polars_execution, publisher=self._publisher, + labels=labels, ) def __del__(self): diff --git a/bigframes/session/bq_caching_executor.py b/bigframes/session/bq_caching_executor.py index fbcdfd33f5..cf275154ce 100644 --- a/bigframes/session/bq_caching_executor.py +++ b/bigframes/session/bq_caching_executor.py @@ -17,7 +17,6 @@ import math import threading from typing import Literal, Mapping, Optional, Sequence, Tuple -import weakref import google.api_core.exceptions from google.cloud import bigquery @@ -47,6 +46,7 @@ semi_executor, ) import bigframes.session._io.bigquery as bq_io +import bigframes.session.execution_cache as execution_cache import bigframes.session.execution_spec as ex_spec import bigframes.session.metrics import bigframes.session.planner @@ -59,58 +59,6 @@ _MAX_CLUSTER_COLUMNS = 4 MAX_SMALL_RESULT_BYTES = 10 * 1024 * 1024 * 1024 # 10G -SourceIdMapping = Mapping[str, str] - - -class ExecutionCache: - def __init__(self): - # current assumption is only 1 cache of a given node - # in future, might have multiple caches, with different layout, localities - self._cached_executions: weakref.WeakKeyDictionary[ - nodes.BigFrameNode, nodes.CachedTableNode - ] = weakref.WeakKeyDictionary() - self._uploaded_local_data: weakref.WeakKeyDictionary[ - local_data.ManagedArrowTable, - tuple[bq_data.BigqueryDataSource, SourceIdMapping], - ] = weakref.WeakKeyDictionary() - - @property - def mapping(self) -> Mapping[nodes.BigFrameNode, nodes.BigFrameNode]: - return self._cached_executions - - def cache_results_table( - self, - original_root: nodes.BigFrameNode, - data: bq_data.BigqueryDataSource, - ): - # Assumption: GBQ cached table uses field name as bq 
column name - scan_list = nodes.ScanList( - tuple( - nodes.ScanItem(field.id, field.id.sql) for field in original_root.fields - ) - ) - cached_replacement = nodes.CachedTableNode( - source=data, - scan_list=scan_list, - table_session=original_root.session, - original_node=original_root, - ) - assert original_root.schema == cached_replacement.schema - self._cached_executions[original_root] = cached_replacement - - def cache_remote_replacement( - self, - local_data: local_data.ManagedArrowTable, - bq_data: bq_data.BigqueryDataSource, - ): - # bq table has one extra column for offsets, those are implicit for local data - assert len(local_data.schema.items) + 1 == len(bq_data.table.physical_schema) - mapping = { - local_data.schema.items[i].column: bq_data.table.physical_schema[i].name - for i in range(len(local_data.schema)) - } - self._uploaded_local_data[local_data] = (bq_data, mapping) - class BigQueryCachingExecutor(executor.Executor): """Computes BigFrames values using BigQuery Engine. @@ -128,20 +76,20 @@ def __init__( bqstoragereadclient: google.cloud.bigquery_storage_v1.BigQueryReadClient, loader: loader.GbqDataLoader, *, - strictly_ordered: bool = True, metrics: Optional[bigframes.session.metrics.ExecutionMetrics] = None, enable_polars_execution: bool = False, publisher: bigframes.core.events.Publisher, + labels: Mapping[str, str] = {}, ): self.bqclient = bqclient self.storage_manager = storage_manager - self.strictly_ordered: bool = strictly_ordered - self.cache: ExecutionCache = ExecutionCache() + self.cache: execution_cache.ExecutionCache = execution_cache.ExecutionCache() self.metrics = metrics self.loader = loader self.bqstoragereadclient = bqstoragereadclient self._enable_polars_execution = enable_polars_execution self._publisher = publisher + self._labels = labels # TODO(tswast): Send events from semi-executors, too. 
self._semi_executors: Sequence[semi_executor.SemiExecutor] = ( @@ -410,8 +358,8 @@ def _run_execute_query( bigframes.options.compute.maximum_bytes_billed ) - if not self.strictly_ordered: - job_config.labels["bigframes-mode"] = "unordered" + if self._labels: + job_config.labels.update(self._labels) try: # Trick the type checker into thinking we got a literal. @@ -450,9 +398,6 @@ def _run_execute_query( else: raise - def replace_cached_subtrees(self, node: nodes.BigFrameNode) -> nodes.BigFrameNode: - return nodes.top_down(node, lambda x: self.cache.mapping.get(x, x)) - def _is_trivially_executable(self, array_value: bigframes.core.ArrayValue): """ Can the block be evaluated very cheaply? @@ -482,7 +427,7 @@ def prepare_plan( ): self._simplify_with_caching(plan) - plan = self.replace_cached_subtrees(plan) + plan = self.cache.subsitute_cached_subplans(plan) plan = rewrite.column_pruning(plan) plan = plan.top_down(rewrite.fold_row_counts) @@ -527,7 +472,7 @@ def _cache_with_session_awareness( self._cache_with_cluster_cols( bigframes.core.ArrayValue(target), cluster_cols_sql_names ) - elif self.strictly_ordered: + elif not target.order_ambiguous: self._cache_with_offsets(bigframes.core.ArrayValue(target)) else: self._cache_with_cluster_cols(bigframes.core.ArrayValue(target), []) @@ -552,7 +497,7 @@ def _cache_most_complex_subtree(self, node: nodes.BigFrameNode) -> bool: node, min_complexity=(QUERY_COMPLEXITY_LIMIT / 500), max_complexity=QUERY_COMPLEXITY_LIMIT, - cache=dict(self.cache.mapping), + cache=self.cache, # Heuristic: subtree_compleixty * (copies of subtree)^2 heuristic=lambda complexity, count: math.log(complexity) + 2 * math.log(count), @@ -581,32 +526,37 @@ def _substitute_large_local_sources(self, original_root: nodes.BigFrameNode): def map_local_scans(node: nodes.BigFrameNode): if not isinstance(node, nodes.ReadLocalNode): return node - if node.local_data_source not in self.cache._uploaded_local_data: - return node - bq_source, source_mapping = 
self.cache._uploaded_local_data[ + uploaded_local_data = self.cache.get_uploaded_local_data( node.local_data_source - ] - scan_list = node.scan_list.remap_source_ids(source_mapping) + ) + if uploaded_local_data is None: + return node + + scan_list = node.scan_list.remap_source_ids( + uploaded_local_data.source_mapping + ) # offsets_col isn't part of ReadTableNode, so emulate by adding to end of scan_list if node.offsets_col is not None: # Offsets are always implicitly the final column of uploaded data # See: Loader.load_data scan_list = scan_list.append( - bq_source.table.physical_schema[-1].name, + uploaded_local_data.bq_source.table.physical_schema[-1].name, bigframes.dtypes.INT_DTYPE, node.offsets_col, ) - return nodes.ReadTableNode(bq_source, scan_list, node.session) + return nodes.ReadTableNode( + uploaded_local_data.bq_source, scan_list, node.session + ) return original_root.bottom_up(map_local_scans) def _upload_local_data(self, local_table: local_data.ManagedArrowTable): - if local_table in self.cache._uploaded_local_data: + if self.cache.get_uploaded_local_data(local_table) is not None: return # Lock prevents concurrent repeated work, but slows things down. # Might be better as a queue and a worker thread with self._upload_lock: - if local_table not in self.cache._uploaded_local_data: + if self.cache.get_uploaded_local_data(local_table) is None: uploaded = self.loader.load_data_or_write_data( local_table, bigframes.core.guid.generate_guid() ) diff --git a/bigframes/session/execution_cache.py b/bigframes/session/execution_cache.py new file mode 100644 index 0000000000..782a1c5c4e --- /dev/null +++ b/bigframes/session/execution_cache.py @@ -0,0 +1,88 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import dataclasses
from typing import Mapping, Optional
import weakref

from bigframes.core import bq_data, local_data, nodes

SourceIdMapping = Mapping[str, str]


@dataclasses.dataclass(frozen=True)
class UploadedLocalData:
    """An uploaded copy of a local table plus the column-name translation for it."""

    # The BigQuery data source holding the uploaded copy.
    bq_source: bq_data.BigqueryDataSource
    # Maps each local column name to the matching column name in ``bq_source``.
    source_mapping: SourceIdMapping


class ExecutionCache:
    """Holds two independent weak-keyed caches.

    * plan cache: plan node -> BigQuery data source holding that node's
      materialized result.
    * upload cache: local arrow table -> its uploaded BigQuery copy.

    Both are ``weakref.WeakKeyDictionary`` instances, so an entry goes away
    once its key object is no longer referenced elsewhere.
    """

    def __init__(self):
        # effectively two separate caches that don't interact
        self._cached_executions: weakref.WeakKeyDictionary[
            nodes.BigFrameNode, bq_data.BigqueryDataSource
        ] = weakref.WeakKeyDictionary()
        # This upload cache is entirely independent of the plan cache.
        self._uploaded_local_data: weakref.WeakKeyDictionary[
            local_data.ManagedArrowTable,
            UploadedLocalData,
        ] = weakref.WeakKeyDictionary()

    def substitute_cached_subplans(
        self, root: nodes.BigFrameNode
    ) -> nodes.BigFrameNode:
        """Return ``root`` with every cached subplan swapped for a cached-table read.

        Args:
            root: The plan to rewrite. It is traversed with ``nodes.top_down``.

        Returns:
            The rewritten plan. Nodes without a plan-cache entry are returned
            unchanged by the per-node callback.
        """

        def replace_if_cached(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
            # Single lookup instead of ``in`` followed by ``[]``: entries of a
            # WeakKeyDictionary can disappear between the two steps, which
            # would turn a cache miss into a KeyError.
            cached_source = self._cached_executions.get(node)
            if cached_source is None:
                return node
            # Assumption: GBQ cached table uses field name as bq column name
            scan_list = nodes.ScanList(
                tuple(nodes.ScanItem(field.id, field.id.sql) for field in node.fields)
            )
            cached_replacement = nodes.CachedTableNode(
                source=cached_source,
                scan_list=scan_list,
                table_session=node.session,
                original_node=node,
            )
            # The replacement must be schema-compatible with what it replaces.
            assert node.schema == cached_replacement.schema
            return cached_replacement

        return nodes.top_down(root, replace_if_cached)

    # Original (misspelled) name, kept so existing callers continue to work.
    subsitute_cached_subplans = substitute_cached_subplans

    def cache_results_table(
        self,
        original_root: nodes.BigFrameNode,
        data: bq_data.BigqueryDataSource,
    ):
        """Record ``data`` as the materialized result of ``original_root``."""
        self._cached_executions[original_root] = data

    ## Local data upload caching
    def cache_remote_replacement(
        self,
        local_data: local_data.ManagedArrowTable,
        bq_data: bq_data.BigqueryDataSource,
    ):
        """Record ``bq_data`` as the uploaded copy of ``local_data``.

        Local columns are paired positionally with the leading columns of the
        uploaded table's physical schema to build the source-id mapping.
        """
        # bq table has one extra column for offsets, those are implicit for local data
        assert len(local_data.schema.items) + 1 == len(bq_data.table.physical_schema)
        # zip stops at the shorter sequence (the local items), so the trailing
        # offsets column of the uploaded table is left out of the mapping.
        mapping = {
            local_item.column: remote_field.name
            for local_item, remote_field in zip(
                local_data.schema.items, bq_data.table.physical_schema
            )
        }
        self._uploaded_local_data[local_data] = UploadedLocalData(bq_data, mapping)

    def get_uploaded_local_data(
        self, local_data: local_data.ManagedArrowTable
    ) -> Optional[UploadedLocalData]:
        """Return the upload record for ``local_data``, or ``None`` if absent."""
        return self._uploaded_local_data.get(local_data)