diff --git a/testbench/common.py b/testbench/common.py index ba955e93..0a0767d7 100644 --- a/testbench/common.py +++ b/testbench/common.py @@ -1126,6 +1126,22 @@ def get_expect_redirect_token(db, context): ) +def get_return_read_redirect_token(db, context): + return _get_grpc_instruction_match( + db, context, "storage.objects.get", retry_return_redirection_token + ) + + +def get_expect_read_redirect_token(db, context): + return _get_grpc_instruction_match( + db, + context, + "storage.objects.get", + retry_expect_redirection_token, + dequeue=False, + ) + + def handle_gzip_request(request): """ Handle gzip compressed JSON payloads when Content-Encoding: gzip is present on metadata requests. diff --git a/testbench/grpc_server.py b/testbench/grpc_server.py index 5c3e7fee..194ddcd6 100644 --- a/testbench/grpc_server.py +++ b/testbench/grpc_server.py @@ -479,7 +479,7 @@ def precondition(_, live_version, ctx): bucket = self.db.get_bucket(request.destination.bucket, context).metadata metadata = storage_pb2.Object() metadata.MergeFrom(request.destination) - (blob, _,) = gcs.object.Object.init( + (blob, _) = gcs.object.Object.init( request, metadata, composed_media, bucket, True, context ) self.db.insert_object( @@ -586,6 +586,30 @@ def ReadObject(self, request, context): @retry_test(method="storage.objects.get") def BidiReadObject(self, request_iterator, context): + def abort_with_redirect_error(routing_token, handle=None): + err = storage_pb2.BidiReadObjectRedirectedError() + if handle is not None: + err.read_handle.handle = handle + err.routing_token = routing_token + detail = any_pb2.Any() + detail.Pack(err) + status_proto = status_pb2.Status( + code=grpc.StatusCode.ABORTED.value[0], + message=grpc.StatusCode.ABORTED.value[1], + details=[detail], + ) + context.abort_with_status(rpc_status.to_status(status_proto)) + + # Check for expected redirect token in the request. + expected_token = testbench.common.get_expect_read_redirect_token( + self.db, context + ) + if expected_token: + params = testbench.common.get_context_request_params(context) + if params and f"routing_token={expected_token}" in params: + test_id = testbench.common.get_retry_test_id_from_context(context) + self.db.dequeue_next_instruction(test_id, "storage.objects.get") + # handle first message try: first_message = next(request_iterator) @@ -593,6 +617,18 @@ def BidiReadObject(self, request_iterator, context): # ok if no messages arrive from the client. return + # Routing-only redirect. + token_only = testbench.common.get_return_read_redirect_token(self.db, context) + if token_only: + abort_with_redirect_error(token_only) + + # Redirect with handle. + token_w_handle = testbench.common.get_return_read_handle_and_redirect_token( + self.db, context + ) + if token_w_handle: + abort_with_redirect_error(token_w_handle, handle=b"an-opaque-handle") + obj_spec = first_message.read_object_spec blob = self.db.get_object( obj_spec.bucket, @@ -614,9 +650,6 @@ def BidiReadObject(self, request_iterator, context): broken_stream_after_bytes = testbench.common.get_broken_stream_after_bytes( next_instruction ) - return_redirect_token = ( - testbench.common.get_return_read_handle_and_redirect_token(self.db, context) - ) # first_response is protected by GIL first_response = True @@ -630,20 +663,6 @@ def response(resp): # We ignore the read_mask for this test server return resp - if return_redirect_token and len(return_redirect_token): - detail = any_pb2.Any() - detail.Pack( - storage_pb2.BidiReadObjectRedirectedError( - routing_token=return_redirect_token - ) - ) - status_proto = status_pb2.Status( - code=grpc.StatusCode.ABORTED.value[0], - message=grpc.StatusCode.ABORTED.value[1], - details=[detail], - ) - context.abort_with_status(rpc_status.to_status(status_proto)) - if not first_message.read_ranges: # always emit a response to the first request. yield response(storage_pb2.BidiReadObjectResponse()) diff --git a/tests/test_testbench_retry.py b/tests/test_testbench_retry.py index 4e80d315..f98c3e0a 100644 --- a/tests/test_testbench_retry.py +++ b/tests/test_testbench_retry.py @@ -1442,6 +1442,125 @@ def test_grpc_bidiread_open_redirect(self): # Verify the early break occurred during first message only. self.assertEqual(len(responses), 0) + def test_grpc_bidiread_create_routing_only_redirect(self): + # Setup a routing-only redirect instruction + response = self.rest_client.post( + "/retry_test", + data=json.dumps( + { + "instructions": { + "storage.objects.get": [ + "redirect-send-token-sometoken", + "redirect-expect-token-sometoken", + ], + }, + "transport": "GRPC", + } + ), + ) + self.assertEqual(response.status_code, 200) + test_id = json.loads(response.data)["id"] + + context = unittest.mock.Mock() + context.abort_with_status = unittest.mock.MagicMock() + context.abort_with_status.side_effect = RpcError() + context.invocation_metadata = unittest.mock.Mock( + return_value=( + ("x-retry-test-id", test_id), + ( + "x-goog-request-params", + "bucket=projects/_/buckets/bucket-name", + ), + ) + ) + + r1 = storage_pb2.BidiReadObjectRequest( + read_object_spec=storage_pb2.BidiReadObjectSpec( + bucket="projects/_/buckets/bucket-name", + object="object-name", + ), + ) + + with self.assertRaises(RpcError): + list(self.grpc.BidiReadObject(iter([r1]), context=context)) + + context.abort_with_status.assert_called() + status = context.abort_with_status.call_args.args[0] + self.assertEqual(status.code, StatusCode.ABORTED) + + redirect_error = storage_pb2.BidiReadObjectRedirectedError() + self._unpack_details(status, redirect_error) + + self.assertFalse(redirect_error.HasField("read_handle")) + self.assertEqual(redirect_error.routing_token, "sometoken") + + def test_grpc_bidiread_redirect_expect_token_match(self): + token = "".join(random.choice(string.ascii_lowercase) for _ in range(5)) + response = self.rest_client.post( + "/retry_test", + data=json.dumps( + { + "instructions": { + "storage.objects.get": [ + f"redirect-expect-token-{token}", + ], + }, + "transport": "GRPC", + } + ), + ) + self.assertEqual(response.status_code, 200) + test_id = json.loads(response.data)["id"] + + r1 = storage_pb2.BidiReadObjectRequest( + read_object_spec=storage_pb2.BidiReadObjectSpec( + bucket="projects/_/buckets/bucket-name", + object="object-name", + ), + ) + + context = unittest.mock.Mock() + context.abort_with_status = unittest.mock.MagicMock() + context.abort_with_status.side_effect = RpcError() + + # With an incorrect routing token, the instruction should still be present + context.invocation_metadata = unittest.mock.Mock( + return_value=( + ("x-retry-test-id", test_id), + ( + "x-goog-request-params", + "bucket=projects/_/buckets/bucket-name&routing_token=incorrect_token", + ), + ) + ) + # We expect a failure here because no object exists yet, but the instruction + # check happens first. + try: + list(self.grpc.BidiReadObject(iter([r1]), context=context)) + except Exception: + pass + self.assertIsNotNone( + rest_server.db.peek_next_instruction(test_id, "storage.objects.get") + ) + + # With the correct routing token, the instruction should be consumed + context.invocation_metadata = unittest.mock.Mock( + return_value=( + ("x-retry-test-id", test_id), + ( + "x-goog-request-params", + f"bucket=projects/_/buckets/bucket-name&routing_token={token}", + ), + ) + ) + try: + list(self.grpc.BidiReadObject(iter([r1]), context=context)) + except Exception: + pass + self.assertIsNone( + rest_server.db.peek_next_instruction(test_id, "storage.objects.get") + ) + class _StatusAsCall: """_StatusAsCall wraps a status and pretends it is a client-side call"""