Skip to content

Commit d949d71

Browse files
committed
chore: align async support with upstream/main after rebase
- Port TLS/mTLS and experimental host support to AsyncClient - Port enable_interceptors_in_tests to AsyncInstance.database - Regenerate synchronous code via CrossSync - Fix noxfile.py for pytest-asyncio compatibility and test isolation - Add comprehensive asynchronous system tests
1 parent 9192b4f commit d949d71

File tree

16 files changed

+542
-124
lines changed

16 files changed

+542
-124
lines changed

google/cloud/spanner_v1/_async/_helpers.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,62 @@ async def _retry(
5151
before_next_retry(retries, delay)
5252
await asyncio.sleep(delay)
5353
retries += 1
54+
55+
def _create_experimental_host_transport(
56+
transport_factory,
57+
experimental_host,
58+
use_plain_text,
59+
ca_certificate,
60+
client_certificate,
61+
client_key,
62+
interceptors=None,
63+
):
64+
"""Creates an experimental host transport for Spanner in async mode.
65+
66+
Args:
67+
transport_factory (type): The transport class to instantiate (e.g.
68+
`SpannerGrpcAsyncIOTransport`).
69+
experimental_host (str): The endpoint for the experimental host.
70+
use_plain_text (bool): Whether to use a plain text (insecure) connection.
71+
ca_certificate (str): Path to the CA certificate file for TLS.
72+
client_certificate (str): Path to the client certificate file for mTLS.
73+
client_key (str): Path to the client key file for mTLS.
74+
interceptors (list): Optional list of interceptors to add to the channel.
75+
76+
Returns:
77+
object: An instance of the transport class created by `transport_factory`.
78+
79+
Raises:
80+
ValueError: If TLS/mTLS configuration is invalid.
81+
"""
82+
import grpc.aio
83+
from google.auth.credentials import AnonymousCredentials
84+
85+
channel = None
86+
if use_plain_text:
87+
channel = grpc.aio.insecure_channel(target=experimental_host, interceptors=interceptors)
88+
elif ca_certificate:
89+
with open(ca_certificate, "rb") as f:
90+
ca_cert = f.read()
91+
if client_certificate and client_key:
92+
with open(client_certificate, "rb") as f:
93+
client_cert = f.read()
94+
with open(client_key, "rb") as f:
95+
private_key = f.read()
96+
ssl_creds = grpc.ssl_channel_credentials(
97+
root_certificates=ca_cert,
98+
private_key=private_key,
99+
certificate_chain=client_cert,
100+
)
101+
elif client_certificate or client_key:
102+
raise ValueError(
103+
"Both client_certificate and client_key must be provided for mTLS connection"
104+
)
105+
else:
106+
ssl_creds = grpc.ssl_channel_credentials(root_certificates=ca_cert)
107+
channel = grpc.aio.secure_channel(experimental_host, ssl_creds, interceptors=interceptors)
108+
else:
109+
raise ValueError(
110+
"TLS/mTLS connection requires ca_certificate to be set for experimental_host"
111+
)
112+
return transport_factory(channel=channel, credentials=AnonymousCredentials())

google/cloud/spanner_v1/_async/client.py

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,10 @@ def __init__(
270270
default_transaction_options: Optional[DefaultTransactionOptions] = None,
271271
experimental_host=None,
272272
disable_builtin_metrics=False,
273+
use_plain_text=False,
274+
ca_certificate=None,
275+
client_certificate=None,
276+
client_key=None,
273277
):
274278
self._emulator_host = _get_spanner_emulator_host()
275279
self._experimental_host = experimental_host
@@ -284,6 +288,12 @@ def __init__(
284288
if self._emulator_host:
285289
credentials = AnonymousCredentials()
286290
elif self._experimental_host:
291+
# For all experimental host endpoints project is default
292+
project = "default"
293+
self._use_plain_text = use_plain_text
294+
self._ca_certificate = ca_certificate
295+
self._client_certificate = client_certificate
296+
self._client_key = client_key
287297
credentials = AnonymousCredentials()
288298
elif isinstance(credentials, AnonymousCredentials):
289299
self._emulator_host = self._client_options.api_endpoint
@@ -382,11 +392,31 @@ def instance_admin_api(self):
382392
transport=transport,
383393
)
384394
elif self._experimental_host:
395+
from google.cloud.spanner_v1._helpers import (
396+
_create_experimental_host_transport as _create_experimental_host_transport_sync,
397+
)
398+
from google.cloud.spanner_v1._async._helpers import (
399+
_create_experimental_host_transport as _create_experimental_host_transport_async,
400+
)
401+
385402
if CrossSync.is_async:
386-
channel = grpc.aio.insecure_channel(self._experimental_host)
403+
transport = _create_experimental_host_transport_async(
404+
InstanceAdminGrpcTransport,
405+
self._experimental_host,
406+
self._use_plain_text,
407+
self._ca_certificate,
408+
self._client_certificate,
409+
self._client_key,
410+
)
387411
else:
388-
channel = grpc.insecure_channel(self._experimental_host)
389-
transport = InstanceAdminGrpcTransport(channel=channel)
412+
transport = _create_experimental_host_transport_sync(
413+
InstanceAdminGrpcTransport,
414+
self._experimental_host,
415+
self._use_plain_text,
416+
self._ca_certificate,
417+
self._client_certificate,
418+
self._client_key,
419+
)
390420
self._instance_admin_api = InstanceAdminClient(
391421
client_info=self._client_info,
392422
client_options=self._client_options,
@@ -416,11 +446,31 @@ def database_admin_api(self):
416446
transport=transport,
417447
)
418448
elif self._experimental_host:
449+
from google.cloud.spanner_v1._helpers import (
450+
_create_experimental_host_transport as _create_experimental_host_transport_sync,
451+
)
452+
from google.cloud.spanner_v1._async._helpers import (
453+
_create_experimental_host_transport as _create_experimental_host_transport_async,
454+
)
455+
419456
if CrossSync.is_async:
420-
channel = grpc.aio.insecure_channel(self._experimental_host)
457+
transport = _create_experimental_host_transport_async(
458+
DatabaseAdminGrpcTransport,
459+
self._experimental_host,
460+
self._use_plain_text,
461+
self._ca_certificate,
462+
self._client_certificate,
463+
self._client_key,
464+
)
421465
else:
422-
channel = grpc.insecure_channel(self._experimental_host)
423-
transport = DatabaseAdminGrpcTransport(channel=channel)
466+
transport = _create_experimental_host_transport_sync(
467+
DatabaseAdminGrpcTransport,
468+
self._experimental_host,
469+
self._use_plain_text,
470+
self._ca_certificate,
471+
self._client_certificate,
472+
self._client_key,
473+
)
424474
self._database_admin_api = DatabaseAdminClient(
425475
client_info=self._client_info,
426476
client_options=self._client_options,

google/cloud/spanner_v1/_async/database.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -472,13 +472,31 @@ def spanner_api(self):
472472
)
473473
return self._spanner_api
474474
if self._instance.experimental_host is not None:
475+
from google.cloud.spanner_v1._helpers import (
476+
_create_experimental_host_transport as _create_experimental_host_transport_sync,
477+
)
478+
from google.cloud.spanner_v1._async._helpers import (
479+
_create_experimental_host_transport as _create_experimental_host_transport_async,
480+
)
481+
475482
if CrossSync.is_async:
476-
channel = grpc.aio.insecure_channel(
477-
self._instance.experimental_host
483+
transport = _create_experimental_host_transport_async(
484+
SpannerGrpcTransport,
485+
self._instance.experimental_host,
486+
self._instance._client._use_plain_text,
487+
self._instance._client._ca_certificate,
488+
self._instance._client._client_certificate,
489+
self._instance._client._client_key,
478490
)
479491
else:
480-
channel = grpc.insecure_channel(self._instance.experimental_host)
481-
transport = SpannerGrpcTransport(channel=channel)
492+
transport = _create_experimental_host_transport_sync(
493+
SpannerGrpcTransport,
494+
self._instance.experimental_host,
495+
self._instance._client._use_plain_text,
496+
self._instance._client._ca_certificate,
497+
self._instance._client._client_certificate,
498+
self._instance._client._client_key,
499+
)
482500
self._spanner_api = SpannerClient(
483501
client_info=client_info,
484502
transport=transport,

google/cloud/spanner_v1/_async/database_sessions_manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
"""Manage sessions for a database."""
16+
__CROSS_SYNC_OUTPUT__ = "google.cloud.spanner_v1.database_sessions_manager"
1617

1718
from datetime import timedelta
1819
from enum import Enum

google/cloud/spanner_v1/_async/testing/__init__.py

Whitespace-only changes.
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
# Copyright 2024 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from google.api_core import grpc_helpers
16+
from google.api_core import grpc_helpers_async
17+
import google.auth.credentials
18+
import grpc
19+
20+
from google.cloud.aio._cross_sync import CrossSync
21+
from google.cloud.spanner_admin_database_v1 import DatabaseDialect
22+
from google.cloud.spanner_v1._helpers import _create_experimental_host_transport
23+
from google.cloud.spanner_v1._async.database import Database
24+
from google.cloud.spanner_v1.database import SPANNER_DATA_SCOPE
25+
from google.cloud.spanner_v1.services.spanner.transports import (
26+
SpannerGrpcTransport,
27+
SpannerTransport,
28+
)
29+
30+
if CrossSync.is_async:
31+
from google.cloud.spanner_v1.services.spanner.async_client import (
32+
SpannerAsyncClient as SpannerClient,
33+
)
34+
from google.cloud.spanner_v1._async.testing.interceptors import (
35+
MethodAbortAsyncInterceptor as MethodAbortInterceptor,
36+
MethodCountAsyncInterceptor as MethodCountInterceptor,
37+
XGoogRequestIDHeaderAsyncInterceptor as XGoogRequestIDHeaderInterceptor,
38+
)
39+
else:
40+
from google.cloud.spanner_v1 import SpannerClient
41+
from google.cloud.spanner_v1.testing.interceptors import (
42+
MethodAbortInterceptor,
43+
MethodCountInterceptor,
44+
XGoogRequestIDHeaderInterceptor,
45+
)
46+
47+
__CROSS_SYNC_OUTPUT__ = "google.cloud.spanner_v1.testing.database_test"
48+
49+
class TestDatabase(Database):
50+
"""Representation of a Cloud Spanner Database. This class is only used for
51+
system testing as there is no support for interceptors in grpc client
52+
currently, and we don't want to make changes in the Database class for
53+
testing purpose as this is a hack to use interceptors in tests."""
54+
55+
_interceptors = []
56+
57+
def __init__(
58+
self,
59+
database_id,
60+
instance,
61+
ddl_statements=(),
62+
pool=None,
63+
logger=None,
64+
encryption_config=None,
65+
database_dialect=DatabaseDialect.DATABASE_DIALECT_UNSPECIFIED,
66+
database_role=None,
67+
enable_drop_protection=False,
68+
):
69+
super().__init__(
70+
database_id,
71+
instance,
72+
ddl_statements,
73+
pool,
74+
logger,
75+
encryption_config,
76+
database_dialect,
77+
database_role,
78+
enable_drop_protection,
79+
)
80+
81+
self._method_count_interceptor = MethodCountInterceptor()
82+
self._method_abort_interceptor = MethodAbortInterceptor()
83+
self._interceptors = [
84+
self._method_count_interceptor,
85+
self._method_abort_interceptor,
86+
]
87+
88+
@property
89+
def spanner_api(self):
90+
"""Helper for session-related API calls."""
91+
if self._spanner_api is None:
92+
client = self._instance._client
93+
client_info = client._client_info
94+
client_options = client._client_options
95+
if self._instance.emulator_host is not None:
96+
if CrossSync.is_async:
97+
self._x_goog_request_id_interceptor = XGoogRequestIDHeaderInterceptor()
98+
self._interceptors.append(self._x_goog_request_id_interceptor)
99+
channel = grpc.aio.insecure_channel(
100+
self._instance.emulator_host,
101+
interceptors=self._interceptors
102+
)
103+
else:
104+
channel = grpc.insecure_channel(self._instance.emulator_host)
105+
self._x_goog_request_id_interceptor = XGoogRequestIDHeaderInterceptor()
106+
self._interceptors.append(self._x_goog_request_id_interceptor)
107+
channel = grpc.intercept_channel(channel, *self._interceptors)
108+
109+
transport = SpannerGrpcTransport(channel=channel)
110+
self._spanner_api = SpannerClient(
111+
client_info=client_info,
112+
transport=transport,
113+
)
114+
return self._spanner_api
115+
if self._instance.experimental_host is not None:
116+
self._x_goog_request_id_interceptor = XGoogRequestIDHeaderInterceptor()
117+
self._interceptors.append(self._x_goog_request_id_interceptor)
118+
119+
from google.cloud.spanner_v1._helpers import (
120+
_create_experimental_host_transport as _create_experimental_host_transport_sync,
121+
)
122+
from google.cloud.spanner_v1._async._helpers import (
123+
_create_experimental_host_transport as _create_experimental_host_transport_async,
124+
)
125+
126+
if CrossSync.is_async:
127+
transport = _create_experimental_host_transport_async(
128+
SpannerGrpcTransport,
129+
self._instance.experimental_host,
130+
client._use_plain_text,
131+
client._ca_certificate,
132+
client._client_certificate,
133+
client._client_key,
134+
self._interceptors,
135+
)
136+
else:
137+
transport = _create_experimental_host_transport_sync(
138+
SpannerGrpcTransport,
139+
self._instance.experimental_host,
140+
client._use_plain_text,
141+
client._ca_certificate,
142+
client._client_certificate,
143+
client._client_key,
144+
self._interceptors,
145+
)
146+
self._spanner_api = SpannerClient(
147+
client_info=client_info,
148+
transport=transport,
149+
client_options=client_options,
150+
)
151+
return self._spanner_api
152+
credentials = client.credentials
153+
if isinstance(credentials, google.auth.credentials.Scoped):
154+
credentials = credentials.with_scopes((SPANNER_DATA_SCOPE,))
155+
self._spanner_api = self._create_spanner_client_for_tests(
156+
client_options,
157+
credentials,
158+
)
159+
return self._spanner_api
160+
161+
def _create_spanner_client_for_tests(self, client_options, credentials):
162+
(
163+
api_endpoint,
164+
client_cert_source_func,
165+
) = SpannerClient.get_mtls_endpoint_and_cert_source(client_options)
166+
167+
if CrossSync.is_async:
168+
channel = grpc_helpers_async.create_channel(
169+
api_endpoint,
170+
credentials=credentials,
171+
credentials_file=client_options.credentials_file,
172+
quota_project_id=client_options.quota_project_id,
173+
default_scopes=SpannerTransport.AUTH_SCOPES,
174+
scopes=client_options.scopes,
175+
default_host=SpannerTransport.DEFAULT_HOST,
176+
interceptors=self._interceptors,
177+
)
178+
else:
179+
channel = grpc_helpers.create_channel(
180+
api_endpoint,
181+
credentials=credentials,
182+
credentials_file=client_options.credentials_file,
183+
quota_project_id=client_options.quota_project_id,
184+
default_scopes=SpannerTransport.AUTH_SCOPES,
185+
scopes=client_options.scopes,
186+
default_host=SpannerTransport.DEFAULT_HOST,
187+
)
188+
channel = grpc.intercept_channel(channel, *self._interceptors)
189+
190+
transport = SpannerGrpcTransport(channel=channel)
191+
return SpannerClient(
192+
client_options=client_options,
193+
transport=transport,
194+
)
195+
196+
def reset(self):
197+
if hasattr(self, "_x_goog_request_id_interceptor") and self._x_goog_request_id_interceptor:
198+
self._x_goog_request_id_interceptor.reset()

0 commit comments

Comments
 (0)