diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 6df17303f..77a642659 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -3,10 +3,9 @@ # # For syntax help see: # https://help.github.com/en/github/creating-cloning-and-archiving-repositories/about-code-owners#codeowners-syntax -# Note: This file is autogenerated. To make changes to the codeowner team, please update .repo-metadata.json. # @googleapis/yoshi-python @googleapis/gcs-sdk-team are the default owners for changes in this repo -* @googleapis/yoshi-python @googleapis/gcs-sdk-team +* @googleapis/yoshi-python @googleapis/gcs-sdk-team @googleapis/gcs-fs # @googleapis/python-samples-reviewers @googleapis/gcs-sdk-team are the default owners for samples changes /samples/ @googleapis/python-samples-reviewers @googleapis/gcs-sdk-team diff --git a/.github/sync-repo-settings.yaml b/.github/sync-repo-settings.yaml index 0d304cfe2..19c1d0ba4 100644 --- a/.github/sync-repo-settings.yaml +++ b/.github/sync-repo-settings.yaml @@ -9,12 +9,6 @@ branchProtectionRules: requiredStatusCheckContexts: - 'Kokoro' - 'cla/google' - - 'Kokoro system-3.12' + - 'Kokoro system-3.14' + - 'Kokoro system-3.9' - 'OwlBot Post Processor' -- pattern: python2 - requiresCodeOwnerReviews: true - requiresStrictStatusChecks: true - requiredStatusCheckContexts: - - 'Kokoro' - - 'cla/google' - - 'Kokoro system-2.7' diff --git a/.kokoro/presubmit/prerelease-deps.cfg b/.kokoro/presubmit/prerelease-deps.cfg new file mode 100644 index 000000000..3595fb43f --- /dev/null +++ b/.kokoro/presubmit/prerelease-deps.cfg @@ -0,0 +1,7 @@ +# Format: //devtools/kokoro/config/proto/build.proto + +# Only run this nox session. +env_vars: { + key: "NOX_SESSION" + value: "prerelease_deps" +} diff --git a/.kokoro/presubmit/presubmit.cfg b/.kokoro/presubmit/presubmit.cfg index b158096f0..5423df92a 100644 --- a/.kokoro/presubmit/presubmit.cfg +++ b/.kokoro/presubmit/presubmit.cfg @@ -1,6 +1,7 @@ # Format: //devtools/kokoro/config/proto/build.proto -# Disable system tests. +# Disable system tests in this presubmit because they are run in separate +# presubmit jobs, whose configs are in system-3.xx.cfg files. env_vars: { key: "RUN_SYSTEM_TESTS" value: "false" diff --git a/.kokoro/presubmit/system-3.12.cfg b/.kokoro/presubmit/system-3.14.cfg similarity index 91% rename from .kokoro/presubmit/system-3.12.cfg rename to .kokoro/presubmit/system-3.14.cfg index d4cca031b..fcc70a922 100644 --- a/.kokoro/presubmit/system-3.12.cfg +++ b/.kokoro/presubmit/system-3.14.cfg @@ -3,7 +3,7 @@ # Only run this nox session. env_vars: { key: "NOX_SESSION" - value: "system-3.12" + value: "system-3.14" } # Credentials needed to test universe domain. diff --git a/.kokoro/presubmit/system-3.9.cfg b/.kokoro/presubmit/system-3.9.cfg new file mode 100644 index 000000000..d21467d02 --- /dev/null +++ b/.kokoro/presubmit/system-3.9.cfg @@ -0,0 +1,13 @@ +# Format: //devtools/kokoro/config/proto/build.proto + +# Only run this nox session. +env_vars: { + key: "NOX_SESSION" + value: "system-3.9" +} + +# Credentials needed to test universe domain. 
+env_vars: { + key: "SECRET_MANAGER_KEYS" + value: "client-library-test-universe-domain-credential" +} \ No newline at end of file diff --git a/.librarian/generator-input/.repo-metadata.json b/.librarian/generator-input/.repo-metadata.json index f644429bc..bd870f959 100644 --- a/.librarian/generator-input/.repo-metadata.json +++ b/.librarian/generator-input/.repo-metadata.json @@ -12,7 +12,7 @@ "api_id": "storage.googleapis.com", "requires_billing": true, "default_version": "v2", - "codeowner_team": "@googleapis/gcs-sdk-team", + "codeowner_team": "@googleapis/yoshi-python @googleapis/gcs-sdk-team @googleapis/gcs-fs", "api_shortname": "storage", "api_description": "is a durable and highly available object storage service. Google Cloud Storage is almost infinitely scalable and guarantees consistency: when a write succeeds, the latest copy of the object will be returned to any GET, globally." } diff --git a/.librarian/generator-input/noxfile.py b/.librarian/generator-input/noxfile.py index 16cf97b01..ca527decd 100644 --- a/.librarian/generator-input/noxfile.py +++ b/.librarian/generator-input/noxfile.py @@ -26,9 +26,9 @@ BLACK_VERSION = "black==23.7.0" BLACK_PATHS = ["docs", "google", "tests", "noxfile.py", "setup.py"] -DEFAULT_PYTHON_VERSION = "3.12" -SYSTEM_TEST_PYTHON_VERSIONS = ["3.12"] -UNIT_TEST_PYTHON_VERSIONS = ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] +DEFAULT_PYTHON_VERSION = "3.14" +SYSTEM_TEST_PYTHON_VERSIONS = ["3.9", "3.14"] +UNIT_TEST_PYTHON_VERSIONS = ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] CONFORMANCE_TEST_PYTHON_VERSIONS = ["3.12"] CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute() @@ -51,6 +51,7 @@ "unit-3.11", "unit-3.12", "unit-3.13", + "unit-3.14", # cover must be last to avoid error `No data to report` "cover", ] diff --git a/.librarian/generator-input/setup.py b/.librarian/generator-input/setup.py index 2c4504749..89971aa33 100644 --- a/.librarian/generator-input/setup.py +++ b/.librarian/generator-input/setup.py @@ -94,6 +94,7 @@ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Operating System :: OS Independent", "Topic :: Internet", ], diff --git a/.repo-metadata.json b/.repo-metadata.json index f644429bc..bd870f959 100644 --- a/.repo-metadata.json +++ b/.repo-metadata.json @@ -12,7 +12,7 @@ "api_id": "storage.googleapis.com", "requires_billing": true, "default_version": "v2", - "codeowner_team": "@googleapis/gcs-sdk-team", + "codeowner_team": "@googleapis/yoshi-python @googleapis/gcs-sdk-team @googleapis/gcs-fs", "api_shortname": "storage", "api_description": "is a durable and highly available object storage service. Google Cloud Storage is almost infinitely scalable and guarantees consistency: when a write succeeds, the latest copy of the object will be returned to any GET, globally." 
} diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ee1c7beb..da1f2149b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,20 @@ [1]: https://pypi.org/project/google-cloud-storage/#history +## [3.7.0](https://github.com/googleapis/python-storage/compare/v3.6.0...v3.7.0) (2025-12-09) + + +### Features + +* Auto enable mTLS when supported certificates are detected ([#1637](https://github.com/googleapis/python-storage/issues/1637)) ([4e91c54](https://github.com/googleapis/python-storage/commit/4e91c541363f0e583bf9dd1b81a95ff2cb618bac)) +* Send entire object checksum in the final api call of resumable upload ([#1654](https://github.com/googleapis/python-storage/issues/1654)) ([ddce7e5](https://github.com/googleapis/python-storage/commit/ddce7e53a13e6c0487221bb14e88161da7ed9e08)) +* Support urllib3 >= 2.6.0 ([#1658](https://github.com/googleapis/python-storage/issues/1658)) ([57405e9](https://github.com/googleapis/python-storage/commit/57405e956a7ca579b20582bf6435cec42743c478)) + + +### Bug Fixes + +* Fix for [move_blob](https://github.com/googleapis/python-storage/blob/57405e956a7ca579b20582bf6435cec42743c478/google/cloud/storage/bucket.py#L2256) failure when the new blob name contains characters that need to be url encoded ([#1605](https://github.com/googleapis/python-storage/issues/1605)) ([ec470a2](https://github.com/googleapis/python-storage/commit/ec470a270e189e137c7229cc359367d5a897cdb9)) + ## [3.6.0](https://github.com/googleapis/python-storage/compare/v3.5.0...v3.6.0) (2025-11-17) diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 316d8b266..1c1817212 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -22,7 +22,7 @@ In order to add a feature: documentation. - The feature must work fully on the following CPython versions: - 3.7, 3.8, 3.9, 3.10, 3.11, 3.12 and 3.13 on both UNIX and Windows. + 3.7, 3.8, 3.9, 3.10, 3.11, 3.12, 3.13 and 3.14 on both UNIX and Windows. - The feature must not add unnecessary dependencies (where "unnecessary" is of course subjective, but new dependencies should @@ -69,8 +69,7 @@ We use `nox `__ to instrument our tests. - To test your changes, run unit tests with ``nox``:: - $ nox -s unit-2.7 - $ nox -s unit-3.7 + $ nox -s unit-3.9 $ ... .. note:: @@ -133,14 +132,11 @@ Running System Tests - To run system tests, you can execute:: - $ nox -s system-3.8 - $ nox -s system-2.7 + $ nox -s system-3.14 .. note:: - System tests are only configured to run under Python 2.7 and - Python 3.8. For expediency, we do not run them in older versions - of Python 3. + System tests are configured to run under Python 3.9 and 3.14 in ``noxfile.py``. This alone will not run the tests. You'll need to change some local auth settings and change some configuration in your project to @@ -202,25 +198,27 @@ Supported Python Versions We support: -- `Python 3.5`_ -- `Python 3.6`_ - `Python 3.7`_ - `Python 3.8`_ +- `Python 3.9`_ +- `Python 3.10`_ +- `Python 3.11`_ +- `Python 3.12`_ +- `Python 3.13`_ +- `Python 3.14`_ -.. _Python 3.5: https://docs.python.org/3.5/ -.. _Python 3.6: https://docs.python.org/3.6/ .. _Python 3.7: https://docs.python.org/3.7/ .. _Python 3.8: https://docs.python.org/3.8/ - +.. _Python 3.9: https://docs.python.org/3.9/ +.. _Python 3.10: https://docs.python.org/3.10/ +.. _Python 3.11: https://docs.python.org/3.11/ +.. _Python 3.12: https://docs.python.org/3.12/ +.. _Python 3.13: https://docs.python.org/3.13/ +.. _Python 3.14: https://docs.python.org/3.14/ Supported versions can be found in our ``noxfile.py`` `config`_. ..
_config: https://github.com/googleapis/python-storage/blob/main/noxfile.py - -Python 2.7 support is deprecated. All code changes should maintain Python 2.7 compatibility until January 1, 2020. - -We also explicitly decided to support Python 3 beginning with version -3.5. Reasons for this include: +We also explicitly decided to support Python 3 beginning with version 3.7. Reasons for this include: - Encouraging use of newest versions of Python 3 - Taking the lead of `prominent`_ open-source `projects`_ diff --git a/cloudbuild/run_zonal_tests.sh b/cloudbuild/run_zonal_tests.sh new file mode 100644 index 000000000..283ed6826 --- /dev/null +++ b/cloudbuild/run_zonal_tests.sh @@ -0,0 +1,28 @@ + +set -euxo pipefail +echo '--- Installing git and cloning repository on VM ---' +sudo apt-get update && sudo apt-get install -y git python3-pip python3-venv + +# Clone the repository and checkout the specific commit from the build trigger. +git clone https://github.com/googleapis/python-storage.git +cd python-storage +git fetch origin "refs/pull/${_PR_NUMBER}/head" +git checkout ${COMMIT_SHA} + + +echo '--- Installing Python and dependencies on VM ---' +python3 -m venv env +source env/bin/activate + +echo 'Install testing libraries explicitly, as they are not in setup.py' +pip install --upgrade pip +pip install pytest pytest-timeout pytest-subtests pytest-asyncio +pip install google-cloud-testutils google-cloud-kms +pip install -e . + +echo '--- Setting up environment variables on VM ---' +export ZONAL_BUCKET=${_ZONAL_BUCKET} +export RUN_ZONAL_SYSTEM_TESTS=True +CURRENT_ULIMIT=$(ulimit -n) +echo "--- Running Zonal tests on VM with ulimit set to $CURRENT_ULIMIT ---" +pytest -vv -s --log-format='%(asctime)s %(levelname)s %(message)s' --log-date-format='%H:%M:%S' tests/system/test_zonal.py diff --git a/cloudbuild/zb-system-tests-cloudbuild.yaml b/cloudbuild/zb-system-tests-cloudbuild.yaml new file mode 100644 index 000000000..562eae175 --- /dev/null +++ b/cloudbuild/zb-system-tests-cloudbuild.yaml @@ -0,0 +1,101 @@ +substitutions: + _REGION: "us-central1" + _ZONE: "us-central1-a" + _SHORT_BUILD_ID: ${BUILD_ID:0:8} + _VM_NAME: "py-sdk-sys-test-${_SHORT_BUILD_ID}" + _ULIMIT: "10000" # 10k, for gRPC bidi streams + + + +steps: + # Step 0: Generate a persistent SSH key for this build run. + # This prevents gcloud from adding a new key to the OS Login profile on every ssh/scp command. + - name: "gcr.io/google.com/cloudsdktool/cloud-sdk" + id: "generate-ssh-key" + entrypoint: "bash" + args: + - "-c" + - | + mkdir -p /workspace/.ssh + # Generate the SSH key + ssh-keygen -t rsa -f /workspace/.ssh/google_compute_engine -N '' -C gcb + # Save the public key content to a file for the cleanup step + cat /workspace/.ssh/google_compute_engine.pub > /workspace/gcb_ssh_key.pub + waitFor: ["-"] + + # Step 1: Create a GCE VM to run the tests. + # The VM is created in the same zone as the buckets to test rapid storage features. + # It's given Cloud Storage scopes to allow it to access GCS.
+ - name: "gcr.io/google.com/cloudsdktool/cloud-sdk" + id: "create-vm" + entrypoint: "gcloud" + args: + - "compute" + - "instances" + - "create" + - "${_VM_NAME}" + - "--project=${PROJECT_ID}" + - "--zone=${_ZONE}" + - "--machine-type=e2-medium" + - "--image-family=debian-13" + - "--image-project=debian-cloud" + - "--service-account=${_ZONAL_VM_SERVICE_ACCOUNT}" + - "--scopes=https://www.googleapis.com/auth/devstorage.full_control,https://www.googleapis.com/auth/devstorage.read_only,https://www.googleapis.com/auth/devstorage.read_write" + - "--metadata=enable-oslogin=TRUE" + waitFor: ["-"] + + # Step 2: Run the integration tests inside the newly created VM and cleanup. + # This step uses 'gcloud compute ssh' to execute a remote script. + # The VM is deleted after tests are run, regardless of success. + - name: "gcr.io/google.com/cloudsdktool/cloud-sdk" + id: "run-tests-and-delete-vm" + entrypoint: "bash" + args: + - "-c" + - | + set -e + # Wait for the VM to be fully initialized and SSH to be ready. + for i in {1..10}; do + if gcloud compute ssh ${_VM_NAME} --zone=${_ZONE} --internal-ip --ssh-key-file=/workspace/.ssh/google_compute_engine --command="echo VM is ready"; then + break + fi + echo "Waiting for VM to become available... (attempt $i/10)" + sleep 15 + done + # copy the script to the VM + gcloud compute scp cloudbuild/run_zonal_tests.sh ${_VM_NAME}:~ --zone=${_ZONE} --internal-ip --ssh-key-file=/workspace/.ssh/google_compute_engine + + # Execute the script on the VM via SSH. + # Capture the exit code to ensure cleanup happens before the build fails. + set +e + gcloud compute ssh ${_VM_NAME} --zone=${_ZONE} --internal-ip --ssh-key-file=/workspace/.ssh/google_compute_engine --command="ulimit -n {_ULIMIT}; COMMIT_SHA=${COMMIT_SHA} _ZONAL_BUCKET=${_ZONAL_BUCKET} _PR_NUMBER=${_PR_NUMBER} bash run_zonal_tests.sh" + EXIT_CODE=$? + set -e + + echo "--- Deleting GCE VM ---" + gcloud compute instances delete "${_VM_NAME}" --zone=${_ZONE} --quiet + + # Exit with the original exit code from the test script. + exit $$EXIT_CODE + waitFor: + - "create-vm" + - "generate-ssh-key" + + - name: "gcr.io/google.com/cloudsdktool/cloud-sdk" + id: "cleanup-ssh-key" + entrypoint: "bash" + args: + - "-c" + - | + echo "--- Removing SSH key from OS Login profile to prevent accumulation ---" + gcloud compute os-login ssh-keys remove \ + --key-file=/workspace/gcb_ssh_key.pub || true + waitFor: + - "run-tests-and-delete-vm" + +timeout: "3600s" # 60 minutes + +options: + logging: CLOUD_LOGGING_ONLY + pool: + name: "projects/${PROJECT_ID}/locations/us-central1/workerPools/cloud-build-worker-pool" diff --git a/google/cloud/_storage_v2/services/storage/client.py b/google/cloud/_storage_v2/services/storage/client.py index 16c76a01f..cdccf3fab 100644 --- a/google/cloud/_storage_v2/services/storage/client.py +++ b/google/cloud/_storage_v2/services/storage/client.py @@ -184,6 +184,34 @@ def _get_default_mtls_endpoint(api_endpoint): _DEFAULT_ENDPOINT_TEMPLATE = "storage.{UNIVERSE_DOMAIN}" _DEFAULT_UNIVERSE = "googleapis.com" + @staticmethod + def _use_client_cert_effective(): + """Returns whether client certificate should be used for mTLS if the + google-auth version supports should_use_client_cert automatic mTLS enablement. + + Alternatively, read from the GOOGLE_API_USE_CLIENT_CERTIFICATE env var. 
+ + Returns: + bool: Whether a client certificate should be used for mTLS. + Raises: + ValueError: If the installed google-auth version lacks should_use_client_cert and + GOOGLE_API_USE_CLIENT_CERTIFICATE is set to a value other than `true` or `false`. + """ + # check if google-auth version supports should_use_client_cert for automatic mTLS enablement + if hasattr(mtls, "should_use_client_cert"): + return mtls.should_use_client_cert() + else: + # if unsupported, fall back to reading from the env var + use_client_cert_str = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + if use_client_cert_str not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be" + " either `true` or `false`" + ) + return use_client_cert_str == "true" + @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): """Creates an instance of this client using the provided credentials @@ -390,12 +418,8 @@ def get_mtls_endpoint_and_cert_source( ) if client_options is None: client_options = client_options_lib.ClientOptions() - use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_client_cert = StorageClient._use_client_cert_effective() use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_client_cert not in ("true", "false"): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) if use_mtls_endpoint not in ("auto", "never", "always"): raise MutualTLSChannelError( "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" @@ -403,7 +427,7 @@ def get_mtls_endpoint_and_cert_source( # Figure out the client cert source to use. client_cert_source = None - if use_client_cert == "true": + if use_client_cert: if client_options.client_cert_source: client_cert_source = client_options.client_cert_source elif mtls.has_default_client_cert_source(): @@ -435,20 +459,14 @@ def _read_environment_variables(): google.auth.exceptions.MutualTLSChannelError: If GOOGLE_API_USE_MTLS_ENDPOINT is not any of ["auto", "never", "always"]. """ - use_client_cert = os.getenv( - "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" - ).lower() + use_client_cert = StorageClient._use_client_cert_effective() use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto").lower() universe_domain_env = os.getenv("GOOGLE_CLOUD_UNIVERSE_DOMAIN") - if use_client_cert not in ("true", "false"): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) if use_mtls_endpoint not in ("auto", "never", "always"): raise MutualTLSChannelError( "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" ) - return use_client_cert == "true", use_mtls_endpoint, universe_domain_env + return use_client_cert, use_mtls_endpoint, universe_domain_env @staticmethod def _get_client_cert_source(provided_cert_source, use_cert_flag): diff --git a/google/cloud/storage/_experimental/asyncio/_utils.py b/google/cloud/storage/_experimental/asyncio/_utils.py new file mode 100644 index 000000000..32d83a586 --- /dev/null +++ b/google/cloud/storage/_experimental/asyncio/_utils.py @@ -0,0 +1,35 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import google_crc32c + +from google.api_core import exceptions + + +def raise_if_no_fast_crc32c(): + """Check that the C-accelerated version of google-crc32c is available. + + If not, raise an error to prevent silent performance degradation. + + :raises google.api_core.exceptions.FailedPrecondition: If the C extension + is not available. + """ + if google_crc32c.implementation != "c": + raise exceptions.FailedPrecondition( + "The google-crc32c package is not installed with C support. " + "The C extension is required for faster data integrity checks. " + "For more information, see https://github.com/googleapis/python-crc32c." + ) diff --git a/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py b/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py index d34c844d5..b4f40b423 100644 --- a/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py +++ b/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py @@ -21,7 +21,13 @@ if you want to use these Rapid Storage APIs. """ +from io import BufferedReader from typing import Optional, Union + +from google_crc32c import Checksum +from google.api_core import exceptions + +from ._utils import raise_if_no_fast_crc32c from google.cloud import _storage_v2 from google.cloud.storage._experimental.asyncio.async_grpc_client import ( AsyncGrpcClient, ) @@ -32,7 +38,7 @@ _MAX_CHUNK_SIZE_BYTES = 2 * 1024 * 1024 # 2 MiB -_MAX_BUFFER_SIZE_BYTES = 16 * 1024 * 1024 # 16 MiB +_DEFAULT_FLUSH_INTERVAL_BYTES = 16 * 1024 * 1024 # 16 MiB class AsyncAppendableObjectWriter: @@ -45,6 +51,7 @@ def __init__( object_name: str, generation=None, write_handle=None, + writer_options: Optional[dict] = None, ): """ Class for appending data to a GCS Appendable Object. @@ -100,6 +107,7 @@ def __init__( :param write_handle: (Optional) An existing handle for writing the object. If provided, opening the bidi-gRPC connection will be faster. """ + raise_if_no_fast_crc32c() self.client = client self.bucket_name = bucket_name self.object_name = object_name @@ -114,8 +122,27 @@ def __init__( write_handle=self.write_handle, ) self._is_stream_open: bool = False + # `offset` is the latest size of the object, without staleness. self.offset: Optional[int] = None + # `persisted_size` is the total_bytes persisted in the GCS server. + # Please note: `offset` and `persisted_size` are the same when the stream + # is opened. self.persisted_size: Optional[int] = None + if writer_options is None: + writer_options = {} + self.flush_interval = writer_options.get( + "FLUSH_INTERVAL_BYTES", _DEFAULT_FLUSH_INTERVAL_BYTES + ) + # TODO: add test case for this.
+ if self.flush_interval < _MAX_CHUNK_SIZE_BYTES: + raise exceptions.OutOfRange( + f"flush_interval must be >= {_MAX_CHUNK_SIZE_BYTES}, but got {self.flush_interval}" + ) + if self.flush_interval % _MAX_CHUNK_SIZE_BYTES != 0: + raise exceptions.OutOfRange( + f"flush_interval must be a multiple of {_MAX_CHUNK_SIZE_BYTES}, but got {self.flush_interval}" + ) + self.bytes_appended_since_last_flush = 0 async def state_lookup(self) -> int: """Returns the persisted_size @@ -152,17 +179,17 @@ async def open(self) -> None: if self.generation is None: self.generation = self.write_obj_stream.generation_number self.write_handle = self.write_obj_stream.write_handle - - # Update self.persisted_size - _ = await self.state_lookup() + self.persisted_size = self.write_obj_stream.persisted_size async def append(self, data: bytes) -> None: """Appends data to the Appendable object. - This method sends the provided data to the GCS server in chunks. It - maintains an internal threshold `_MAX_BUFFER_SIZE_BYTES` and will - automatically flush the data to make it visible to readers when that - threshold has reached. + Calling `self.append` appends bytes at the end of the current size, + i.e. at `self.offset` bytes relative to the beginning of the object. + + This method sends the provided `data` to the GCS server in chunks, + and persists data in GCS every `self.flush_interval` bytes by + calling `self.simple_flush`. :type data: bytes :param data: The bytes to append to the object. @@ -184,23 +211,24 @@ async def append(self, data: bytes) -> None: self.offset = self.persisted_size start_idx = 0 - bytes_to_flush = 0 while start_idx < total_bytes: end_idx = min(start_idx + _MAX_CHUNK_SIZE_BYTES, total_bytes) + data_chunk = data[start_idx:end_idx] await self.write_obj_stream.send( _storage_v2.BidiWriteObjectRequest( write_offset=self.offset, checksummed_data=_storage_v2.ChecksummedData( - content=data[start_idx:end_idx] + content=data_chunk, + crc32c=int.from_bytes(Checksum(data_chunk).digest(), "big"), ), ) ) chunk_size = end_idx - start_idx self.offset += chunk_size - bytes_to_flush += chunk_size - if bytes_to_flush >= _MAX_BUFFER_SIZE_BYTES: + self.bytes_appended_since_last_flush += chunk_size + if self.bytes_appended_since_last_flush >= self.flush_interval: await self.simple_flush() - bytes_to_flush = 0 + self.bytes_appended_since_last_flush = 0 start_idx = end_idx async def simple_flush(self) -> None: @@ -267,7 +295,8 @@ async def close(self, finalize_on_close=False) -> Union[int, _storage_v2.Object] await self.finalize() else: await self.flush() - await self.write_obj_stream.close() + + await self.write_obj_stream.close() self._is_stream_open = False self.offset = None @@ -311,6 +340,16 @@ async def append_from_stream(self, stream_obj): """ raise NotImplementedError("append_from_stream is not implemented yet.") - async def append_from_file(self, file_path: str): - """Create a file object from `file_path` and call append_from_stream(file_obj)""" - raise NotImplementedError("append_from_file is not implemented yet.") + async def append_from_file( + self, file_obj: BufferedReader, block_size: int = _DEFAULT_FLUSH_INTERVAL_BYTES + ): + """ + Appends data to an Appendable Object from a file handle opened + for reading in binary mode. + + :type file_obj: io.BufferedReader + :param file_obj: A file handle opened in binary mode for reading.
+ + """ + while block := file_obj.read(block_size): + await self.append(block) diff --git a/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py b/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py index 16dce5025..8f16294d8 100644 --- a/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py +++ b/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py @@ -14,54 +14,80 @@ from __future__ import annotations import asyncio -import google_crc32c +import logging from google.api_core import exceptions -from google_crc32c import Checksum +from google.api_core.retry_async import AsyncRetry +from google.cloud.storage._experimental.asyncio.retry._helpers import _handle_redirect +from google.rpc import status_pb2 -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Any, Dict +from ._utils import raise_if_no_fast_crc32c from google.cloud.storage._experimental.asyncio.async_read_object_stream import ( _AsyncReadObjectStream, ) from google.cloud.storage._experimental.asyncio.async_grpc_client import ( AsyncGrpcClient, ) +from google.cloud.storage._experimental.asyncio.retry.bidi_stream_retry_manager import ( + _BidiStreamRetryManager, +) +from google.cloud.storage._experimental.asyncio.retry.reads_resumption_strategy import ( + _ReadResumptionStrategy, + _DownloadState, +) from io import BytesIO from google.cloud import _storage_v2 -from google.cloud.storage.exceptions import DataCorruption from google.cloud.storage._helpers import generate_random_56_bit_integer _MAX_READ_RANGES_PER_BIDI_READ_REQUEST = 100 +_BIDI_READ_REDIRECTED_TYPE_URL = ( + "type.googleapis.com/google.storage.v2.BidiReadObjectRedirectedError" +) +logger = logging.getLogger(__name__) + + +def _is_read_retryable(exc): + """Predicate to determine if a read operation should be retried.""" + if isinstance( + exc, + ( + exceptions.InternalServerError, + exceptions.ServiceUnavailable, + exceptions.DeadlineExceeded, + exceptions.TooManyRequests, + ), + ): + return True + + if not isinstance(exc, exceptions.Aborted) or not exc.errors: + return False + + try: + grpc_error = exc.errors[0] + trailers = grpc_error.trailing_metadata() + if not trailers: + return False + + status_details_bin = next( + (v for k, v in trailers if k == "grpc-status-details-bin"), None + ) -class Result: - """An instance of this class will be populated and retured for each - `read_range` provided to ``download_ranges`` method. - - """ - - def __init__(self, bytes_requested: int): - # only while instantiation, should not be edited later. - # hence there's no setter, only getter is provided. 
- self._bytes_requested: int = bytes_requested - self._bytes_written: int = 0 - - @property - def bytes_requested(self) -> int: - return self._bytes_requested - - @property - def bytes_written(self) -> int: - return self._bytes_written - - @bytes_written.setter - def bytes_written(self, value: int): - self._bytes_written = value + if not status_details_bin: + return False - def __repr__(self): - return f"bytes_requested: {self._bytes_requested}, bytes_written: {self._bytes_written}" + status_proto = status_pb2.Status() + status_proto.ParseFromString(status_details_bin) + return any( + detail.type_url == _BIDI_READ_REDIRECTED_TYPE_URL + for detail in status_proto.details + ) + except Exception as e: + logger.error(f"Error parsing status_details_bin: {e}") + return False class AsyncMultiRangeDownloader: @@ -104,6 +130,8 @@ async def create_mrd( object_name: str, generation_number: Optional[int] = None, read_handle: Optional[bytes] = None, + retry_policy: Optional[AsyncRetry] = None, + metadata: Optional[List[Tuple[str, str]]] = None, ) -> AsyncMultiRangeDownloader: """Initializes a MultiRangeDownloader and opens the underlying bidi-gRPC object for reading. @@ -125,11 +153,17 @@ async def create_mrd( :param read_handle: (Optional) An existing handle for reading the object. If provided, opening the bidi-gRPC connection will be faster. + :type retry_policy: :class:`~google.api_core.retry_async.AsyncRetry` + :param retry_policy: (Optional) The retry policy to use for the ``open`` operation. + + :type metadata: List[Tuple[str, str]] + :param metadata: (Optional) The metadata to be sent with the ``open`` request. + :rtype: :class:`~google.cloud.storage._experimental.asyncio.async_multi_range_downloader.AsyncMultiRangeDownloader` :returns: An initialized AsyncMultiRangeDownloader instance for reading. """ mrd = cls(client, bucket_name, object_name, generation_number, read_handle) - await mrd.open() + await mrd.open(retry_policy=retry_policy, metadata=metadata) return mrd def __init__( @@ -160,14 +194,7 @@ def __init__( :param read_handle: (Optional) An existing read handle. """ - # Verify that the fast, C-accelerated version of crc32c is available. - # If not, raise an error to prevent silent performance degradation. - if google_crc32c.implementation != "c": - raise exceptions.NotFound( - "The google-crc32c package is not installed with C support. " - "Bidi reads require the C extension for data integrity checks." - "For more information, see https://github.com/googleapis/python-crc32c." - ) + raise_if_no_fast_crc32c() self.client = client self.bucket_name = bucket_name @@ -176,24 +203,65 @@ def __init__( self.read_handle = read_handle self.read_obj_str: Optional[_AsyncReadObjectStream] = None self._is_stream_open: bool = False - + self._routing_token: Optional[str] = None self._read_id_to_writable_buffer_dict = {} self._read_id_to_download_ranges_id = {} self._download_ranges_id_to_pending_read_ids = {} + self.persisted_size: Optional[int] = None # updated after opening the stream - async def open(self) -> None: - """Opens the bidi-gRPC connection to read from the object. - - This method initializes and opens an `_AsyncReadObjectStream` (bidi-gRPC stream) to - for downloading ranges of data from GCS ``Object``. 
+ def _on_open_error(self, exc): + """Extracts routing token and read handle on redirect error during open.""" + routing_token, read_handle = _handle_redirect(exc) + if routing_token: + self._routing_token = routing_token + if read_handle: + self.read_handle = read_handle - "Opening" constitutes fetching object metadata such as generation number - and read handle and sets them as attributes if not already set. - """ + async def open( + self, + retry_policy: Optional[AsyncRetry] = None, + metadata: Optional[List[Tuple[str, str]]] = None, + ) -> None: + """Opens the bidi-gRPC connection to read from the object.""" if self._is_stream_open: raise ValueError("Underlying bidi-gRPC stream is already open") - if self.read_obj_str is None: + if retry_policy is None: + retry_policy = AsyncRetry( + predicate=_is_read_retryable, on_error=self._on_open_error + ) + else: + original_on_error = retry_policy._on_error + + def combined_on_error(exc): + self._on_open_error(exc) + if original_on_error: + original_on_error(exc) + + retry_policy = AsyncRetry( + predicate=_is_read_retryable, + initial=retry_policy._initial, + maximum=retry_policy._maximum, + multiplier=retry_policy._multiplier, + deadline=retry_policy._deadline, + on_error=combined_on_error, + ) + + async def _do_open(): + current_metadata = list(metadata) if metadata else [] + + # Cleanup stream from previous failed attempt, if any. + if self.read_obj_str: + if self.read_obj_str.is_stream_open: + try: + await self.read_obj_str.close() + except exceptions.GoogleAPICallError as e: + logger.warning( + f"Failed to close existing stream during resumption: {e}" + ) + self.read_obj_str = None + self._is_stream_open = False + + self.read_obj_str = _AsyncReadObjectStream( client=self.client, bucket_name=self.bucket_name, @@ -201,22 +269,42 @@ async def open(self) -> None: generation_number=self.generation_number, read_handle=self.read_handle, ) - await self.read_obj_str.open() - self._is_stream_open = True - if self.generation_number is None: - self.generation_number = self.read_obj_str.generation_number - self.read_handle = self.read_obj_str.read_handle - return + + if self._routing_token: + current_metadata.append( + ("x-goog-request-params", f"routing_token={self._routing_token}") + ) + self._routing_token = None + + await self.read_obj_str.open( + metadata=current_metadata if current_metadata else None + ) + + if self.read_obj_str.generation_number: + self.generation_number = self.read_obj_str.generation_number + if self.read_obj_str.read_handle: + self.read_handle = self.read_obj_str.read_handle + if self.read_obj_str.persisted_size is not None: + self.persisted_size = self.read_obj_str.persisted_size + + self._is_stream_open = True + + await retry_policy(_do_open)() async def download_ranges( - self, read_ranges: List[Tuple[int, int, BytesIO]], lock: asyncio.Lock = None + self, + read_ranges: List[Tuple[int, int, BytesIO]], + lock: asyncio.Lock = None, + retry_policy: Optional[AsyncRetry] = None, + metadata: Optional[List[Tuple[str, str]]] = None, ) -> None: """Downloads multiple byte ranges from the object into the buffers - provided by user. + provided by the user, with automatic retries. :type read_ranges: List[Tuple[int, int, "BytesIO"]] :param read_ranges: A list of tuples, where each tuple represents a - byte range (start_byte, bytes_to_read, writeable_buffer). Buffer has + combination of a byte range and a writeable buffer, in the format + (`start_byte`, `bytes_to_read`, `writeable_buffer`).
Buffer has to be provided by the user, and user has to make sure appropriate memory is available in the application to avoid out-of-memory crash. @@ -246,6 +334,8 @@ async def download_ranges( ``` + :type retry_policy: :class:`~google.api_core.retry_async.AsyncRetry` + :param retry_policy: (Optional) The retry policy to use for the operation. :raises ValueError: if the underlying bidi-GRPC stream is not open. :raises ValueError: if the length of read_ranges is more than 1000. @@ -264,72 +354,122 @@ async def download_ranges( if lock is None: lock = asyncio.Lock() - _func_id = generate_random_56_bit_integer() - read_ids_in_current_func = set() - for i in range(0, len(read_ranges), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST): - read_ranges_segment = read_ranges[ - i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST - ] + if retry_policy is None: + retry_policy = AsyncRetry(predicate=_is_read_retryable) + + # Initialize Global State for Retry Strategy + download_states = {} + for read_range in read_ranges: + read_id = generate_random_56_bit_integer() + download_states[read_id] = _DownloadState( + initial_offset=read_range[0], + initial_length=read_range[1], + user_buffer=read_range[2], + ) - read_ranges_for_bidi_req = [] - for j, read_range in enumerate(read_ranges_segment): - read_id = generate_random_56_bit_integer() - read_ids_in_current_func.add(read_id) - self._read_id_to_download_ranges_id[read_id] = _func_id - self._read_id_to_writable_buffer_dict[read_id] = read_range[2] - bytes_requested = read_range[1] - read_ranges_for_bidi_req.append( - _storage_v2.ReadRange( - read_offset=read_range[0], - read_length=bytes_requested, - read_id=read_id, + initial_state = { + "download_states": download_states, + "read_handle": self.read_handle, + "routing_token": None, + } + + # Track attempts to manage stream reuse + attempt_count = 0 + + def send_ranges_and_get_bytes( + requests: List[_storage_v2.ReadRange], + state: Dict[str, Any], + metadata: Optional[List[Tuple[str, str]]] = None, + ): + async def generator(): + nonlocal attempt_count + attempt_count += 1 + + if attempt_count > 1: + logger.info( + f"Resuming download (attempt {attempt_count - 1}) for {len(requests)} ranges." ) - ) - async with lock: - await self.read_obj_str.send( - _storage_v2.BidiReadObjectRequest( - read_ranges=read_ranges_for_bidi_req - ) - ) - self._download_ranges_id_to_pending_read_ids[ - _func_id - ] = read_ids_in_current_func - - while len(self._download_ranges_id_to_pending_read_ids[_func_id]) > 0: - async with lock: - response = await self.read_obj_str.recv() - - if response is None: - raise Exception("None response received, something went wrong.") - - for object_data_range in response.object_data_ranges: - if object_data_range.read_range is None: - raise Exception("Invalid response, read_range is None") - checksummed_data = object_data_range.checksummed_data - data = checksummed_data.content - server_checksum = checksummed_data.crc32c - - client_crc32c = Checksum(data).digest() - client_checksum = int.from_bytes(client_crc32c, "big") - - if server_checksum != client_checksum: - raise DataCorruption( - response, - f"Checksum mismatch for read_id {object_data_range.read_range.read_id}. " - f"Server sent {server_checksum}, client calculated {client_checksum}.", + async with lock: + current_handle = state.get("read_handle") + current_token = state.get("routing_token") + + # We reopen if it's a redirect (token exists) OR if this is a retry + # (not first attempt). 
This prevents trying to send data on a dead + # stream from a previous failed attempt. + should_reopen = ( + (attempt_count > 1) + or (current_token is not None) + or (metadata is not None) ) - read_id = object_data_range.read_range.read_id - buffer = self._read_id_to_writable_buffer_dict[read_id] - buffer.write(data) + if should_reopen: + if current_token: + logger.info( + f"Re-opening stream with routing token: {current_token}" + ) + # Close existing stream if any + if self.read_obj_str and self.read_obj_str.is_stream_open: + await self.read_obj_str.close() + + # Re-initialize stream + self.read_obj_str = _AsyncReadObjectStream( + client=self.client, + bucket_name=self.bucket_name, + object_name=self.object_name, + generation_number=self.generation_number, + read_handle=current_handle, + ) + + # Inject routing_token into metadata if present + current_metadata = list(metadata) if metadata else [] + if current_token: + current_metadata.append( + ( + "x-goog-request-params", + f"routing_token={current_token}", + ) + ) + + await self.read_obj_str.open( + metadata=current_metadata if current_metadata else None + ) + self._is_stream_open = True + + pending_read_ids = {r.read_id for r in requests} + + # Send Requests + for i in range( + 0, len(requests), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST + ): + batch = requests[i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST] + await self.read_obj_str.send( + _storage_v2.BidiReadObjectRequest(read_ranges=batch) + ) + + while pending_read_ids: + response = await self.read_obj_str.recv() + if response is None: + break + if response.object_data_ranges: + for data_range in response.object_data_ranges: + if data_range.range_end: + pending_read_ids.discard( + data_range.read_range.read_id + ) + yield response + + return generator() + + strategy = _ReadResumptionStrategy() + retry_manager = _BidiStreamRetryManager( + strategy, lambda r, s: send_ranges_and_get_bytes(r, s, metadata=metadata) + ) - if object_data_range.range_end: - tmp_dn_ranges_id = self._read_id_to_download_ranges_id[read_id] - self._download_ranges_id_to_pending_read_ids[ - tmp_dn_ranges_id - ].remove(read_id) - del self._read_id_to_download_ranges_id[read_id] + await retry_manager.execute(initial_state, retry_policy) + + if initial_state.get("read_handle"): + self.read_handle = initial_state["read_handle"] async def close(self): """ @@ -337,7 +477,10 @@ async def close(self): """ if not self._is_stream_open: raise ValueError("Underlying bidi-gRPC stream is not open") - await self.read_obj_str.close() + + if self.read_obj_str: + await self.read_obj_str.close() + self.read_obj_str = None self._is_stream_open = False @property diff --git a/google/cloud/storage/_experimental/asyncio/async_read_object_stream.py b/google/cloud/storage/_experimental/asyncio/async_read_object_stream.py index ddaaf9a54..7adcdd1c9 100644 --- a/google/cloud/storage/_experimental/asyncio/async_read_object_stream.py +++ b/google/cloud/storage/_experimental/asyncio/async_read_object_stream.py @@ -22,7 +22,7 @@ """ -from typing import Optional +from typing import List, Optional, Tuple from google.cloud import _storage_v2 from google.cloud.storage._experimental.asyncio.async_grpc_client import AsyncGrpcClient from google.cloud.storage._experimental.asyncio.async_abstract_object_stream import ( @@ -84,32 +84,66 @@ def __init__( self.rpc = self.client._client._transport._wrapped_methods[ self.client._client._transport.bidi_read_object ] - self.first_bidi_read_req = _storage_v2.BidiReadObjectRequest( - 
read_object_spec=_storage_v2.BidiReadObjectSpec( - bucket=self._full_bucket_name, object=object_name - ), - ) self.metadata = (("x-goog-request-params", f"bucket={self._full_bucket_name}"),) self.socket_like_rpc: Optional[AsyncBidiRpc] = None self._is_stream_open: bool = False + self.persisted_size: Optional[int] = None - async def open(self) -> None: + async def open(self, metadata: Optional[List[Tuple[str, str]]] = None) -> None: """Opens the bidi-gRPC connection to read from the object. This method sends an initial request to start the stream and receives the first response containing metadata and a read handle. + + Args: + metadata (Optional[List[Tuple[str, str]]]): Additional metadata + to send with the initial stream request, e.g., for routing tokens. """ if self._is_stream_open: raise ValueError("Stream is already open") + + read_handle = self.read_handle if self.read_handle else None + + read_object_spec = _storage_v2.BidiReadObjectSpec( + bucket=self._full_bucket_name, + object=self.object_name, + generation=self.generation_number if self.generation_number else None, + read_handle=read_handle, + ) + self.first_bidi_read_req = _storage_v2.BidiReadObjectRequest( + read_object_spec=read_object_spec + ) + + # Build the x-goog-request-params header + request_params = [f"bucket={self._full_bucket_name}"] + other_metadata = [] + if metadata: + for key, value in metadata: + if key == "x-goog-request-params": + request_params.append(value) + else: + other_metadata.append((key, value)) + + current_metadata = other_metadata + current_metadata.append(("x-goog-request-params", ",".join(request_params))) + self.socket_like_rpc = AsyncBidiRpc( - self.rpc, initial_request=self.first_bidi_read_req, metadata=self.metadata + self.rpc, + initial_request=self.first_bidi_read_req, + metadata=current_metadata, ) await self.socket_like_rpc.open() # this is actually 1 send response = await self.socket_like_rpc.recv() - if self.generation_number is None: - self.generation_number = response.metadata.generation + # Metadata is populated only in the first response of the bidi stream, + # and only when the stream is opened without a `read_handle`. + if hasattr(response, "metadata") and response.metadata: + if self.generation_number is None: + self.generation_number = response.metadata.generation + # update persisted size + self.persisted_size = response.metadata.size - self.read_handle = response.read_handle + if response and response.read_handle: + self.read_handle = response.read_handle self._is_stream_open = True diff --git a/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py b/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py index 6d1fd5b31..183a8eeb1 100644 --- a/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py +++ b/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py @@ -117,7 +117,6 @@ async def open(self) -> None: object=self.object_name, generation=self.generation_number, ), - state_lookup=True, ) self.socket_like_rpc = AsyncBidiRpc( @@ -136,11 +135,17 @@ async def open(self) -> None: raise ValueError( "Failed to obtain object generation after opening the stream" ) - self.generation_number = response.resource.generation if not response.write_handle: raise ValueError("Failed to obtain write_handle after opening the stream") + if not response.resource.size: + # Appending to a 0-byte appendable object.
+ self.persisted_size = 0 + else: + self.persisted_size = response.resource.size + + self.generation_number = response.resource.generation self.write_handle = response.write_handle async def close(self) -> None: diff --git a/google/cloud/storage/_experimental/asyncio/retry/_helpers.py b/google/cloud/storage/_experimental/asyncio/retry/_helpers.py new file mode 100644 index 000000000..627bf5944 --- /dev/null +++ b/google/cloud/storage/_experimental/asyncio/retry/_helpers.py @@ -0,0 +1,83 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from typing import Tuple, Optional + +from google.api_core import exceptions +from google.cloud._storage_v2.types import BidiReadObjectRedirectedError +from google.rpc import status_pb2 + +_BIDI_READ_REDIRECTED_TYPE_URL = ( + "type.googleapis.com/google.storage.v2.BidiReadObjectRedirectedError" +) + + +def _handle_redirect( + exc: Exception, +) -> Tuple[Optional[str], Optional[bytes]]: + """ + Extracts routing token and read handle from a gRPC error. + + :type exc: Exception + :param exc: The exception to parse. + + :rtype: Tuple[Optional[str], Optional[bytes]] + :returns: A tuple of (routing_token, read_handle). + """ + routing_token = None + read_handle = None + + grpc_error = None + if isinstance(exc, exceptions.Aborted) and exc.errors: + grpc_error = exc.errors[0] + + if grpc_error: + if isinstance(grpc_error, BidiReadObjectRedirectedError): + routing_token = grpc_error.routing_token + if grpc_error.read_handle: + read_handle = grpc_error.read_handle + return routing_token, read_handle + + if hasattr(grpc_error, "trailing_metadata"): + trailers = grpc_error.trailing_metadata() + if not trailers: + return None, None + + status_details_bin = None + for key, value in trailers: + if key == "grpc-status-details-bin": + status_details_bin = value + break + + if status_details_bin: + status_proto = status_pb2.Status() + try: + status_proto.ParseFromString(status_details_bin) + for detail in status_proto.details: + if detail.type_url == _BIDI_READ_REDIRECTED_TYPE_URL: + redirect_proto = BidiReadObjectRedirectedError.deserialize( + detail.value + ) + if redirect_proto.routing_token: + routing_token = redirect_proto.routing_token + if redirect_proto.read_handle: + read_handle = redirect_proto.read_handle + break + except Exception as e: + logging.error(f"Error unpacking redirect: {e}") + + return routing_token, read_handle diff --git a/google/cloud/storage/_experimental/asyncio/retry/base_strategy.py b/google/cloud/storage/_experimental/asyncio/retry/base_strategy.py index e32125069..ff193f109 100644 --- a/google/cloud/storage/_experimental/asyncio/retry/base_strategy.py +++ b/google/cloud/storage/_experimental/asyncio/retry/base_strategy.py @@ -1,3 +1,17 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import abc from typing import Any, Iterable diff --git a/google/cloud/storage/_experimental/asyncio/retry/bidi_stream_retry_manager.py b/google/cloud/storage/_experimental/asyncio/retry/bidi_stream_retry_manager.py new file mode 100644 index 000000000..a8caae4eb --- /dev/null +++ b/google/cloud/storage/_experimental/asyncio/retry/bidi_stream_retry_manager.py @@ -0,0 +1,69 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, AsyncIterator, Callable + +from google.cloud.storage._experimental.asyncio.retry.base_strategy import ( + _BaseResumptionStrategy, +) + +logger = logging.getLogger(__name__) + + +class _BidiStreamRetryManager: + """Manages the generic retry loop for a bidi streaming operation.""" + + def __init__( + self, + strategy: _BaseResumptionStrategy, + send_and_recv: Callable[..., AsyncIterator[Any]], + ): + """Initializes the retry manager. + Args: + strategy: The strategy for managing the state of a specific + bidi operation (e.g., reads or writes). + send_and_recv: An async callable that opens a new gRPC stream. + """ + self._strategy = strategy + self._send_and_recv = send_and_recv + + async def execute(self, initial_state: Any, retry_policy): + """ + Executes the bidi operation with the configured retry policy. + Args: + initial_state: An object containing all state for the operation. + retry_policy: The `google.api_core.retry.AsyncRetry` object to + govern the retry behavior for this specific operation. + """ + state = initial_state + + async def attempt(): + requests = self._strategy.generate_requests(state) + stream = self._send_and_recv(requests, state) + try: + async for response in stream: + self._strategy.update_state_from_response(response, state) + return + except Exception as e: + if retry_policy._predicate(e): + logger.info( + f"Bidi stream operation failed: {e}. Attempting state recovery and retry." 
+ ) + await self._strategy.recover_state_on_failure(e, state) + raise e + + wrapped_attempt = retry_policy(attempt) + + await wrapped_attempt() diff --git a/google/cloud/storage/_experimental/asyncio/retry/reads_resumption_strategy.py b/google/cloud/storage/_experimental/asyncio/retry/reads_resumption_strategy.py index d5d080358..916b82e6e 100644 --- a/google/cloud/storage/_experimental/asyncio/retry/reads_resumption_strategy.py +++ b/google/cloud/storage/_experimental/asyncio/retry/reads_resumption_strategy.py @@ -1,11 +1,35 @@ -from typing import Any, List, IO +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, List, IO +import logging + +from google_crc32c import Checksum from google.cloud import _storage_v2 as storage_v2 from google.cloud.storage.exceptions import DataCorruption +from google.cloud.storage._experimental.asyncio.retry._helpers import ( + _handle_redirect, +) from google.cloud.storage._experimental.asyncio.retry.base_strategy import ( _BaseResumptionStrategy, ) -from google.cloud._storage_v2.types.storage import BidiReadObjectRedirectedError + + +_BIDI_READ_REDIRECTED_TYPE_URL = ( + "type.googleapis.com/google.storage.v2.BidiReadObjectRedirectedError" +) +logger = logging.getLogger(__name__) class _DownloadState: @@ -25,7 +49,7 @@ def __init__( class _ReadResumptionStrategy(_BaseResumptionStrategy): """The concrete resumption strategy for bidi reads.""" - def generate_requests(self, state: dict) -> List[storage_v2.ReadRange]: + def generate_requests(self, state: Dict[str, Any]) -> List[storage_v2.ReadRange]: """Generates new ReadRange requests for all incomplete downloads. :type state: dict @@ -33,10 +57,17 @@ def generate_requests(self, state: dict) -> List[storage_v2.ReadRange]: _DownloadState object. """ pending_requests = [] - for read_id, read_state in state.items(): + download_states: Dict[int, _DownloadState] = state["download_states"] + + for read_id, read_state in download_states.items(): if not read_state.is_complete: new_offset = read_state.initial_offset + read_state.bytes_written - new_length = read_state.initial_length - read_state.bytes_written + + # Calculate remaining length. If initial_length is 0 (read to end), + # it stays 0. Otherwise, subtract bytes_written. + new_length = 0 + if read_state.initial_length > 0: + new_length = read_state.initial_length - read_state.bytes_written new_request = storage_v2.ReadRange( read_offset=new_offset, @@ -47,39 +78,80 @@ def generate_requests(self, state: dict) -> List[storage_v2.ReadRange]: return pending_requests def update_state_from_response( - self, response: storage_v2.BidiReadObjectResponse, state: dict + self, response: storage_v2.BidiReadObjectResponse, state: Dict[str, Any] ) -> None: """Processes a server response, performs integrity checks, and updates state.""" + + # Capture read_handle if provided. 
+ if response.read_handle: + state["read_handle"] = response.read_handle + + download_states = state["download_states"] + for object_data_range in response.object_data_ranges: + # Ignore empty ranges or ranges for IDs not in our state + # (e.g., from a previously cancelled request on the same stream). + if not object_data_range.read_range: + logger.warning( + "Received response with missing read_range field; ignoring." + ) + continue + read_id = object_data_range.read_range.read_id - read_state = state[read_id] + if read_id not in download_states: + logger.warning( + f"Received data for unknown or stale read_id {read_id}; ignoring." + ) + continue + + read_state = download_states[read_id] # Offset Verification chunk_offset = object_data_range.read_range.read_offset if chunk_offset != read_state.next_expected_offset: - raise DataCorruption(response, f"Offset mismatch for read_id {read_id}") + raise DataCorruption( + response, + f"Offset mismatch for read_id {read_id}. " + f"Expected {read_state.next_expected_offset}, got {chunk_offset}", + ) + # Checksum Verification + # We must validate data before updating state or writing to buffer. data = object_data_range.checksummed_data.content + server_checksum = object_data_range.checksummed_data.crc32c + + if server_checksum is not None: + client_checksum = int.from_bytes(Checksum(data).digest(), "big") + if server_checksum != client_checksum: + raise DataCorruption( + response, + f"Checksum mismatch for read_id {read_id}. " + f"Server sent {server_checksum}, client calculated {client_checksum}.", + ) + + # Update State & Write Data chunk_size = len(data) + read_state.user_buffer.write(data) read_state.bytes_written += chunk_size read_state.next_expected_offset += chunk_size - read_state.user_buffer.write(data) # Final Byte Count Verification if object_data_range.range_end: read_state.is_complete = True if ( read_state.initial_length != 0 - and read_state.bytes_written != read_state.initial_length + and read_state.bytes_written > read_state.initial_length ): raise DataCorruption( - response, f"Byte count mismatch for read_id {read_id}" + response, + f"Byte count mismatch for read_id {read_id}. " + f"Expected {read_state.initial_length}, got {read_state.bytes_written}", ) async def recover_state_on_failure(self, error: Exception, state: Any) -> None: """Handles BidiReadObjectRedirectedError for reads.""" - # This would parse the gRPC error details, extract the routing_token, - # and store it on the shared state object. 
- cause = getattr(error, "cause", error) - if isinstance(cause, BidiReadObjectRedirectedError): - state["routing_token"] = cause.routing_token + routing_token, read_handle = _handle_redirect(error) + if routing_token: + state["routing_token"] = routing_token + if read_handle: + state["read_handle"] = read_handle diff --git a/google/cloud/storage/_helpers.py b/google/cloud/storage/_helpers.py index 682f8784d..24f72ad71 100644 --- a/google/cloud/storage/_helpers.py +++ b/google/cloud/storage/_helpers.py @@ -111,10 +111,6 @@ def _virtual_hosted_style_base_url(url, bucket, trailing_slash=False): return base_url -def _use_client_cert(): - return os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE") == "true" - - def _get_environ_project(): return os.getenv( environment_vars.PROJECT, diff --git a/google/cloud/storage/_media/_upload.py b/google/cloud/storage/_media/_upload.py index 765716882..4a919d18a 100644 --- a/google/cloud/storage/_media/_upload.py +++ b/google/cloud/storage/_media/_upload.py @@ -688,6 +688,13 @@ def _prepare_request(self): _CONTENT_TYPE_HEADER: self._content_type, _helpers.CONTENT_RANGE_HEADER: content_range, } + if (start_byte + len(payload) == self._total_bytes) and ( + self._checksum_object is not None + ): + local_checksum = _helpers.prepare_checksum_digest( + self._checksum_object.digest() + ) + headers["x-goog-hash"] = f"{self._checksum_type}={local_checksum}" return _PUT, self.resumable_url, payload, headers def _update_checksum(self, start_byte, payload): diff --git a/google/cloud/storage/_media/requests/download.py b/google/cloud/storage/_media/requests/download.py index b8e2758e1..c5686fcb7 100644 --- a/google/cloud/storage/_media/requests/download.py +++ b/google/cloud/storage/_media/requests/download.py @@ -711,7 +711,7 @@ def __init__(self, checksum): super().__init__() self._checksum = checksum - def decompress(self, data): + def decompress(self, data, max_length=-1): """Decompress the bytes. Args: @@ -721,7 +721,11 @@ def decompress(self, data): bytes: The decompressed bytes from ``data``. """ self._checksum.update(data) - return super().decompress(data) + try: + return super().decompress(data, max_length=max_length) + except TypeError: + # Fallback for urllib3 < 2.6.0 which lacks `max_length` support. + return super().decompress(data) # urllib3.response.BrotliDecoder might not exist depending on whether brotli is @@ -747,7 +751,7 @@ def __init__(self, checksum): self._decoder = urllib3.response.BrotliDecoder() self._checksum = checksum - def decompress(self, data): + def decompress(self, data, max_length=-1): """Decompress the bytes. Args: @@ -757,10 +761,18 @@ def decompress(self, data): bytes: The decompressed bytes from ``data``. """ self._checksum.update(data) - return self._decoder.decompress(data) + try: + return self._decoder.decompress(data, max_length=max_length) + except TypeError: + # Fallback for urllib3 < 2.6.0 which lacks `max_length` support. 
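+            # (older decoders accept only the positional data argument, which
+            # is what triggers the TypeError above)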
+ return self._decoder.decompress(data) def flush(self): return self._decoder.flush() + @property + def has_unconsumed_tail(self) -> bool: + return self._decoder.has_unconsumed_tail + else: # pragma: NO COVER _BrotliDecoder = None # type: ignore # pragma: NO COVER diff --git a/google/cloud/storage/client.py b/google/cloud/storage/client.py index 3764c7a53..85575f067 100644 --- a/google/cloud/storage/client.py +++ b/google/cloud/storage/client.py @@ -20,11 +20,12 @@ import datetime import functools import json +import os import warnings import google.api_core.client_options from google.auth.credentials import AnonymousCredentials - +from google.auth.transport import mtls from google.api_core import page_iterator from google.cloud._helpers import _LocalStack from google.cloud.client import ClientWithProject @@ -35,7 +36,6 @@ from google.cloud.storage._helpers import _get_api_endpoint_override from google.cloud.storage._helpers import _get_environ_project from google.cloud.storage._helpers import _get_storage_emulator_override -from google.cloud.storage._helpers import _use_client_cert from google.cloud.storage._helpers import _virtual_hosted_style_base_url from google.cloud.storage._helpers import _DEFAULT_UNIVERSE_DOMAIN from google.cloud.storage._helpers import _DEFAULT_SCHEME @@ -218,7 +218,15 @@ def __init__( # The final decision of whether to use mTLS takes place in # google-auth-library-python. We peek at the environment variable # here only to issue an exception in case of a conflict. - if _use_client_cert(): + use_client_cert = False + if hasattr(mtls, "should_use_client_cert"): + use_client_cert = mtls.should_use_client_cert() + else: + use_client_cert = ( + os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE") == "true" + ) + + if use_client_cert: raise ValueError( 'The "GOOGLE_API_USE_CLIENT_CERTIFICATE" env variable is ' 'set to "true" and a non-default universe domain is ' diff --git a/google/cloud/storage/version.py b/google/cloud/storage/version.py index 102b96095..dc87b3c5b 100644 --- a/google/cloud/storage/version.py +++ b/google/cloud/storage/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "3.6.0" +__version__ = "3.7.0" diff --git a/noxfile.py b/noxfile.py index 16cf97b01..4f77d1e2d 100644 --- a/noxfile.py +++ b/noxfile.py @@ -26,9 +26,16 @@ BLACK_VERSION = "black==23.7.0" BLACK_PATHS = ["docs", "google", "tests", "noxfile.py", "setup.py"] -DEFAULT_PYTHON_VERSION = "3.12" -SYSTEM_TEST_PYTHON_VERSIONS = ["3.12"] -UNIT_TEST_PYTHON_VERSIONS = ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] +DEFAULT_PYTHON_VERSION = "3.14" +SYSTEM_TEST_PYTHON_VERSIONS = ["3.9", "3.14"] +UNIT_TEST_PYTHON_VERSIONS = [ + "3.9", + "3.10", + "3.11", + "3.12", + "3.13", + "3.14", +] CONFORMANCE_TEST_PYTHON_VERSIONS = ["3.12"] CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute() @@ -51,6 +58,7 @@ "unit-3.11", "unit-3.12", "unit-3.13", + "unit-3.14", # cover must be last to avoid error `No data to report` "cover", ] @@ -65,7 +73,7 @@ def lint(session): """ # Pin flake8 to 6.0.0 # See https://github.com/googleapis/python-storage/issues/1102 - session.install("flake8==6.0.0", BLACK_VERSION) + session.install("flake8", BLACK_VERSION) session.run( "black", "--check", @@ -118,6 +126,8 @@ def default(session, install_extras=True): session.install("-e", ".", "-c", constraints_path) + session.run("python", "-m", "pip", "freeze") + # This dependency is included in setup.py for backwards compatibility only # and the client library is expected to pass all tests without it. See # setup.py and README for details. @@ -215,10 +225,22 @@ def conftest_retry(session): if not conformance_test_folder_exists: session.skip("Conformance tests were not found") + constraints_path = str( + CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt" + ) + # Install all test dependencies and pytest plugin to run tests in parallel. # Then install this package in-place. - session.install("pytest", "pytest-xdist") - session.install("-e", ".") + session.install( + "pytest", + "pytest-xdist", + "grpcio", + "grpcio-status", + "grpc-google-iam-v1", + "-c", + constraints_path, + ) + session.install("-e", ".", "-c", constraints_path) # Run #CPU processes in parallel if no test session arguments are passed in. if session.posargs: @@ -232,7 +254,7 @@ def conftest_retry(session): test_cmd = ["py.test", "-n", "auto", "--quiet", conformance_test_folder_path] # Run py.test against the conformance tests. 
- session.run(*test_cmd) + session.run(*test_cmd, env={"DOCKER_API_VERSION": "1.39"}) @nox.session(python=DEFAULT_PYTHON_VERSION) diff --git a/samples/snippets/snippets_test.py b/samples/snippets/snippets_test.py index 91018f3dd..1d3c8c1c4 100644 --- a/samples/snippets/snippets_test.py +++ b/samples/snippets/snippets_test.py @@ -18,6 +18,7 @@ import tempfile import time import uuid +import sys from google.cloud import storage import google.cloud.exceptions @@ -99,8 +100,10 @@ import storage_upload_with_kms_key KMS_KEY = os.environ.get("CLOUD_KMS_KEY") +IS_PYTHON_3_14 = sys.version_info[:2] == (3, 14) +@pytest.mark.skipif(IS_PYTHON_3_14, reason="b/470276398") def test_enable_default_kms_key(test_bucket): storage_set_bucket_default_kms_key.enable_default_kms_key( bucket_name=test_bucket.name, kms_key_name=KMS_KEY @@ -305,6 +308,7 @@ def test_upload_blob_from_stream(test_bucket, capsys): assert "Stream data uploaded to test_upload_blob" in out +@pytest.mark.skipif(IS_PYTHON_3_14, reason="b/470276398") def test_upload_blob_with_kms(test_bucket): blob_name = f"test_upload_with_kms_{uuid.uuid4().hex}" with tempfile.NamedTemporaryFile() as source_file: @@ -399,6 +403,7 @@ def test_delete_blob(test_blob): storage_delete_file.delete_blob(test_blob.bucket.name, test_blob.name) +@pytest.mark.xfail(reason="wait until b/469643064 is fixed") def test_make_blob_public(test_public_blob): storage_make_public.make_blob_public( test_public_blob.bucket.name, test_public_blob.name @@ -597,6 +602,7 @@ def test_create_bucket_dual_region(test_bucket_create, capsys): assert "dual-region" in out +@pytest.mark.skipif(IS_PYTHON_3_14, reason="b/470276398") def test_bucket_delete_default_kms_key(test_bucket, capsys): test_bucket.default_kms_key_name = KMS_KEY test_bucket.patch() @@ -620,6 +626,7 @@ def test_get_service_account(capsys): assert "@gs-project-accounts.iam.gserviceaccount.com" in out +@pytest.mark.xfail(reason="wait until b/469643064 is fixed") def test_download_public_file(test_public_blob): storage_make_public.make_blob_public( test_public_blob.bucket.name, test_public_blob.name @@ -644,6 +651,7 @@ def test_define_bucket_website_configuration(test_bucket): assert bucket._properties["website"] == website_val +@pytest.mark.skipif(IS_PYTHON_3_14, reason="b/470276398") def test_object_get_kms_key(test_bucket): with tempfile.NamedTemporaryFile() as source_file: storage_upload_with_kms_key.upload_blob_with_kms( diff --git a/setup.py b/setup.py index 2c4504749..b45053856 100644 --- a/setup.py +++ b/setup.py @@ -43,6 +43,18 @@ "google-crc32c >= 1.1.3, < 2.0.0", ] extras = { + # TODO: Make these extra dependencies as mandatory once gRPC out of + # experimental in this SDK. 
More info in b/465352227 + "grpc": [ + "google-api-core[grpc] >= 2.27.0, < 3.0.0", + "grpcio >= 1.33.2, < 2.0.0; python_version < '3.14'", + "grpcio >= 1.75.1, < 2.0.0; python_version >= '3.14'", + "grpcio-status >= 1.76.0, < 2.0.0", + "proto-plus >= 1.22.3, <2.0.0; python_version < '3.13'", + "proto-plus >= 1.25.0, <2.0.0; python_version >= '3.13'", + "protobuf>=3.20.2,<7.0.0,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5", + "grpc-google-iam-v1 >= 0.14.0, <1.0.0", + ], "protobuf": ["protobuf >= 3.20.2, < 7.0.0"], "tracing": [ "opentelemetry-api >= 1.1.0, < 2.0.0", @@ -94,6 +106,7 @@ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Operating System :: OS Independent", "Topic :: Internet", ], diff --git a/testing/constraints-3.14.txt b/testing/constraints-3.14.txt index 2ae5a677e..62739fc5d 100644 --- a/testing/constraints-3.14.txt +++ b/testing/constraints-3.14.txt @@ -7,7 +7,7 @@ # Then this file should have google-cloud-foo>=1 google-api-core>=2 google-auth>=2 -grpcio>=1 +grpcio>=1.75.1 proto-plus>=1 protobuf>=6 grpc-google-iam-v1>=0 diff --git a/tests/conformance/test_bidi_reads.py b/tests/conformance/test_bidi_reads.py new file mode 100644 index 000000000..4157182cb --- /dev/null +++ b/tests/conformance/test_bidi_reads.py @@ -0,0 +1,266 @@ +import asyncio +import io +import uuid +import grpc +import requests + +from google.api_core import exceptions +from google.auth import credentials as auth_credentials +from google.cloud import _storage_v2 as storage_v2 + +from google.cloud.storage._experimental.asyncio.async_multi_range_downloader import ( + AsyncMultiRangeDownloader, +) + +# --- Configuration --- +PROJECT_NUMBER = "12345" # A dummy project number is fine for the testbench. +GRPC_ENDPOINT = "localhost:8888" +HTTP_ENDPOINT = "http://localhost:9000" +CONTENT_LENGTH = 1024 * 10 # 10 KB + + +def _is_retriable(exc): + """Predicate for identifying retriable errors.""" + return isinstance( + exc, + ( + exceptions.ServiceUnavailable, + exceptions.Aborted, # Required to retry on redirect + exceptions.InternalServerError, + exceptions.ResourceExhausted, + ), + ) + + +async def run_test_scenario( + gapic_client, http_client, bucket_name, object_name, scenario +): + """Runs a single fault-injection test scenario.""" + print(f"\n--- RUNNING SCENARIO: {scenario['name']} ---") + + retry_test_id = None + try: + # 1. Create a Retry Test resource on the testbench. + retry_test_config = { + "instructions": {scenario["method"]: [scenario["instruction"]]}, + "transport": "GRPC", + } + resp = http_client.post(f"{HTTP_ENDPOINT}/retry_test", json=retry_test_config) + resp.raise_for_status() + retry_test_id = resp.json()["id"] + + # 2. Set up downloader and metadata for fault injection. + downloader = await AsyncMultiRangeDownloader.create_mrd( + gapic_client, bucket_name, object_name + ) + fault_injection_metadata = (("x-retry-test-id", retry_test_id),) + + buffer = io.BytesIO() + + # 3. Execute the download and assert the outcome. + try: + await downloader.download_ranges( + [(0, 5 * 1024, buffer), (6 * 1024, 4 * 1024, buffer)], + metadata=fault_injection_metadata, + ) + # If an exception was expected, this line should not be reached. + if scenario["expected_error"] is not None: + raise AssertionError( + f"Expected exception {scenario['expected_error']} was not raised." 
+ ) + + assert len(buffer.getvalue()) == 9 * 1024 + + except scenario["expected_error"] as e: + print(f"Caught expected exception for {scenario['name']}: {e}") + + await downloader.close() + + finally: + # 4. Clean up the Retry Test resource. + if retry_test_id: + http_client.delete(f"{HTTP_ENDPOINT}/retry_test/{retry_test_id}") + + +async def main(): + """Main function to set up resources and run all test scenarios.""" + channel = grpc.aio.insecure_channel(GRPC_ENDPOINT) + creds = auth_credentials.AnonymousCredentials() + transport = storage_v2.services.storage.transports.StorageGrpcAsyncIOTransport( + channel=channel, credentials=creds + ) + gapic_client = storage_v2.StorageAsyncClient(transport=transport) + http_client = requests.Session() + + bucket_name = f"grpc-test-bucket-{uuid.uuid4().hex[:8]}" + object_name = "retry-test-object" + + # Define all test scenarios + test_scenarios = [ + { + "name": "Retry on Service Unavailable (503)", + "method": "storage.objects.get", + "instruction": "return-503", + "expected_error": None, + }, + { + "name": "Retry on 500", + "method": "storage.objects.get", + "instruction": "return-500", + "expected_error": None, + }, + { + "name": "Retry on 504", + "method": "storage.objects.get", + "instruction": "return-504", + "expected_error": None, + }, + { + "name": "Retry on 429", + "method": "storage.objects.get", + "instruction": "return-429", + "expected_error": None, + }, + { + "name": "Smarter Resumption: Retry 503 after partial data", + "method": "storage.objects.get", + "instruction": "return-broken-stream-after-2K", + "expected_error": None, + }, + { + "name": "Retry on BidiReadObjectRedirectedError", + "method": "storage.objects.get", + "instruction": "redirect-send-handle-and-token-tokenval", # Testbench instruction for redirect + "expected_error": None, + }, + ] + + try: + # Create a single bucket and object for all tests to use. + content = b"A" * CONTENT_LENGTH + bucket_resource = storage_v2.Bucket(project=f"projects/{PROJECT_NUMBER}") + create_bucket_request = storage_v2.CreateBucketRequest( + parent="projects/_", bucket_id=bucket_name, bucket=bucket_resource + ) + await gapic_client.create_bucket(request=create_bucket_request) + + write_spec = storage_v2.WriteObjectSpec( + resource=storage_v2.Object( + bucket=f"projects/_/buckets/{bucket_name}", name=object_name + ) + ) + + async def write_req_gen(): + yield storage_v2.WriteObjectRequest( + write_object_spec=write_spec, + checksummed_data={"content": content}, + finish_write=True, + ) + + await gapic_client.write_object(requests=write_req_gen()) + + # Run all defined test scenarios. 
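+        # Each scenario registers its own retry-test resource on the testbench
+        # and deletes it in a finally block, so injected faults cannot leak
+        # across runs.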
+ for scenario in test_scenarios: + await run_test_scenario( + gapic_client, http_client, bucket_name, object_name, scenario + ) + + # Define and run test scenarios specifically for the open() method + open_test_scenarios = [ + { + "name": "Open: Retry on 503", + "method": "storage.objects.get", + "instruction": "return-503", + "expected_error": None, + }, + { + "name": "Open: Retry on BidiReadObjectRedirectedError", + "method": "storage.objects.get", + "instruction": "redirect-send-handle-and-token-tokenval", + "expected_error": None, + }, + { + "name": "Open: Fail Fast on 401", + "method": "storage.objects.get", + "instruction": "return-401", + "expected_error": exceptions.Unauthorized, + }, + ] + for scenario in open_test_scenarios: + await run_open_test_scenario( + gapic_client, http_client, bucket_name, object_name, scenario + ) + + except Exception: + import traceback + + traceback.print_exc() + finally: + # Clean up the test bucket. + try: + delete_object_req = storage_v2.DeleteObjectRequest( + bucket="projects/_/buckets/" + bucket_name, object=object_name + ) + await gapic_client.delete_object(request=delete_object_req) + + delete_bucket_req = storage_v2.DeleteBucketRequest( + name=f"projects/_/buckets/{bucket_name}" + ) + await gapic_client.delete_bucket(request=delete_bucket_req) + except Exception as e: + print(f"Warning: Cleanup failed: {e}") + + +async def run_open_test_scenario( + gapic_client, http_client, bucket_name, object_name, scenario +): + """Runs a fault-injection test scenario specifically for the open() method.""" + print(f"\n--- RUNNING SCENARIO: {scenario['name']} ---") + + retry_test_id = None + try: + # 1. Create a Retry Test resource on the testbench. + retry_test_config = { + "instructions": {scenario["method"]: [scenario["instruction"]]}, + "transport": "GRPC", + } + resp = http_client.post(f"{HTTP_ENDPOINT}/retry_test", json=retry_test_config) + resp.raise_for_status() + retry_test_id = resp.json()["id"] + print(f"Retry Test created with ID: {retry_test_id}") + + # 2. Set up metadata for fault injection. + fault_injection_metadata = (("x-retry-test-id", retry_test_id),) + + # 3. Execute the open (via create_mrd) and assert the outcome. + try: + downloader = await AsyncMultiRangeDownloader.create_mrd( + gapic_client, + bucket_name, + object_name, + metadata=fault_injection_metadata, + ) + + # If open was successful, perform a simple download to ensure the stream is usable. + buffer = io.BytesIO() + await downloader.download_ranges([(0, 1024, buffer)]) + await downloader.close() + assert len(buffer.getvalue()) == 1024 + + # If an exception was expected, this line should not be reached. + if scenario["expected_error"] is not None: + raise AssertionError( + f"Expected exception {scenario['expected_error']} was not raised." + ) + + except scenario["expected_error"] as e: + print(f"Caught expected exception for {scenario['name']}: {e}") + + finally: + # 4. Clean up the Retry Test resource. + if retry_test_id: + http_client.delete(f"{HTTP_ENDPOINT}/retry_test/{retry_test_id}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/resumable_media/system/requests/test_upload.py b/tests/resumable_media/system/requests/test_upload.py index dd90aa53b..47f4f6003 100644 --- a/tests/resumable_media/system/requests/test_upload.py +++ b/tests/resumable_media/system/requests/test_upload.py @@ -27,7 +27,6 @@ import google.cloud.storage._media.requests as resumable_requests from google.cloud.storage._media import _helpers from .. 
import utils -from google.cloud.storage._media import _upload from google.cloud.storage.exceptions import InvalidResponse from google.cloud.storage.exceptions import DataCorruption @@ -372,29 +371,6 @@ def test_resumable_upload_with_headers( _resumable_upload_helper(authorized_transport, img_stream, cleanup, headers=headers) -@pytest.mark.parametrize("checksum", ["md5", "crc32c"]) -def test_resumable_upload_with_bad_checksum( - authorized_transport, img_stream, bucket, cleanup, checksum -): - fake_checksum_object = _helpers._get_checksum_object(checksum) - fake_checksum_object.update(b"bad data") - fake_prepared_checksum_digest = _helpers.prepare_checksum_digest( - fake_checksum_object.digest() - ) - with mock.patch.object( - _helpers, "prepare_checksum_digest", return_value=fake_prepared_checksum_digest - ): - with pytest.raises(DataCorruption) as exc_info: - _resumable_upload_helper( - authorized_transport, img_stream, cleanup, checksum=checksum - ) - expected_checksums = {"md5": "1bsd83IYNug8hd+V1ING3Q==", "crc32c": "YQGPxA=="} - expected_message = _upload._UPLOAD_CHECKSUM_MISMATCH_MESSAGE.format( - checksum.upper(), fake_prepared_checksum_digest, expected_checksums[checksum] - ) - assert exc_info.value.args[0] == expected_message - - def test_resumable_upload_bad_chunk_size(authorized_transport, img_stream): blob_name = os.path.basename(img_stream.name) # Create the actual upload object. diff --git a/tests/system/test_notification.py b/tests/system/test_notification.py index 9b631c29b..48c6c4ba8 100644 --- a/tests/system/test_notification.py +++ b/tests/system/test_notification.py @@ -60,13 +60,19 @@ def topic_path(storage_client, topic_name): @pytest.fixture(scope="session") def notification_topic(storage_client, publisher_client, topic_path, no_mtls): _helpers.retry_429(publisher_client.create_topic)(request={"name": topic_path}) - policy = publisher_client.get_iam_policy(request={"resource": topic_path}) - binding = policy.bindings.add() - binding.role = "roles/pubsub.publisher" - binding.members.append( - f"serviceAccount:{storage_client.get_service_account_email()}" - ) - publisher_client.set_iam_policy(request={"resource": topic_path, "policy": policy}) + try: + policy = publisher_client.get_iam_policy(request={"resource": topic_path}) + binding = policy.bindings.add() + binding.role = "roles/pubsub.publisher" + binding.members.append( + f"serviceAccount:{storage_client.get_service_account_email()}" + ) + publisher_client.set_iam_policy( + request={"resource": topic_path, "policy": policy} + ) + yield topic_path + finally: + publisher_client.delete_topic(request={"topic": topic_path}) def test_notification_create_minimal( diff --git a/tests/system/test_zonal.py b/tests/system/test_zonal.py new file mode 100644 index 000000000..d8d20ba36 --- /dev/null +++ b/tests/system/test_zonal.py @@ -0,0 +1,257 @@ +# py standard imports +import os +import uuid +from io import BytesIO + +# python additional imports +import google_crc32c + +import pytest +import gc + +# current library imports +from google.cloud.storage._experimental.asyncio.async_grpc_client import AsyncGrpcClient +from google.cloud.storage._experimental.asyncio.async_appendable_object_writer import ( + AsyncAppendableObjectWriter, + _DEFAULT_FLUSH_INTERVAL_BYTES, +) +from google.cloud.storage._experimental.asyncio.async_multi_range_downloader import ( + AsyncMultiRangeDownloader, +) + + +pytestmark = pytest.mark.skipif( + os.getenv("RUN_ZONAL_SYSTEM_TESTS") != "True", + reason="Zonal system tests need to be explicitly enabled. 
This helps with scheduling tests in Kokoro and Cloud Build.",
+)
+
+
+# TODO: replace this with a fixture once zonal bucket creation / deletion
+# is supported in the grpc client or the json client.
+_ZONAL_BUCKET = os.getenv("ZONAL_BUCKET")
+_BYTES_TO_UPLOAD = b"dummy_bytes_to_write_read_and_delete_appendable_object"
+
+
+def _get_equal_dist(a: int, b: int) -> tuple[int, int]:
+    step = (b - a) // 3
+    return a + step, a + 2 * step
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "object_size",
+    [
+        256,  # less than _chunk size
+        10 * 1024 * 1024,  # less than _MAX_BUFFER_SIZE_BYTES
+        20 * 1024 * 1024,  # greater than _MAX_BUFFER_SIZE_BYTES
+    ],
+)
+@pytest.mark.parametrize(
+    "attempt_direct_path",
+    [True, False],
+)
+async def test_basic_wrd(
+    storage_client, blobs_to_delete, attempt_direct_path, object_size
+):
+    object_name = f"test_basic_wrd-{str(uuid.uuid4())}"
+
+    # Client instantiation; it cannot be part of a fixture because
+    # grpc_client's event loop and the event loop of the coroutine running it
+    # (i.e. this test) must be the same.
+    # Note:
+    # 1. @pytest.mark.asyncio ensures a new event loop for each test.
+    # 2. we could keep the same event loop for the entire module, but that may
+    #    create issues if tests are run in parallel and one test hogs the event
+    #    loop, slowing down other tests.
+    object_data = os.urandom(object_size)
+    object_checksum = google_crc32c.value(object_data)
+    grpc_client = AsyncGrpcClient(attempt_direct_path=attempt_direct_path).grpc_client
+
+    writer = AsyncAppendableObjectWriter(grpc_client, _ZONAL_BUCKET, object_name)
+    await writer.open()
+    await writer.append(object_data)
+    object_metadata = await writer.close(finalize_on_close=True)
+    assert object_metadata.size == object_size
+    assert int(object_metadata.checksums.crc32c) == object_checksum
+
+    mrd = AsyncMultiRangeDownloader(grpc_client, _ZONAL_BUCKET, object_name)
+    buffer = BytesIO()
+    await mrd.open()
+    # (0, 0) means read the whole object
+    await mrd.download_ranges([(0, 0, buffer)])
+    await mrd.close()
+    assert buffer.getvalue() == object_data
+    assert mrd.persisted_size == object_size
+
+    # Clean up; use the json client (i.e. the `storage_client` fixture) to delete.
+    blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name))
+    del writer
+    del mrd
+    gc.collect()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "object_size",
+    [
+        10,  # less than _chunk size
+        10 * 1024 * 1024,  # less than _MAX_BUFFER_SIZE_BYTES
+        20 * 1024 * 1024,  # greater than _MAX_BUFFER_SIZE_BYTES
+    ],
+)
+async def test_basic_wrd_in_slices(storage_client, blobs_to_delete, object_size):
+    object_name = f"test_basic_wrd-{str(uuid.uuid4())}"
+
+    # Client instantiation; it cannot be part of a fixture because
+    # grpc_client's event loop and the event loop of the coroutine running it
+    # (i.e. this test) must be the same.
+    # Note:
+    # 1. @pytest.mark.asyncio ensures a new event loop for each test.
+    # 2. we could keep the same event loop for the entire module, but that may
+    #    create issues if tests are run in parallel and one test hogs the event
+    #    loop, slowing down other tests.
+    object_data = os.urandom(object_size)
+    object_checksum = google_crc32c.value(object_data)
+    grpc_client = AsyncGrpcClient().grpc_client
+
+    writer = AsyncAppendableObjectWriter(grpc_client, _ZONAL_BUCKET, object_name)
+    await writer.open()
+    mark1, mark2 = _get_equal_dist(0, object_size)
+    await writer.append(object_data[0:mark1])
+    await writer.append(object_data[mark1:mark2])
+    await writer.append(object_data[mark2:])
+    object_metadata = await writer.close(finalize_on_close=True)
+    assert object_metadata.size == object_size
+    assert int(object_metadata.checksums.crc32c) == object_checksum
+
+    mrd = AsyncMultiRangeDownloader(grpc_client, _ZONAL_BUCKET, object_name)
+    buffer = BytesIO()
+    await mrd.open()
+    # (0, 0) means read the whole object
+    await mrd.download_ranges([(0, 0, buffer)])
+    await mrd.close()
+    assert buffer.getvalue() == object_data
+    assert mrd.persisted_size == object_size
+
+    # Clean up; use the json client (i.e. the `storage_client` fixture) to delete.
+    blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name))
+    del writer
+    del mrd
+    gc.collect()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "flush_interval",
+    [
+        2 * 1024 * 1024,
+        4 * 1024 * 1024,
+        8 * 1024 * 1024,
+        _DEFAULT_FLUSH_INTERVAL_BYTES,
+    ],
+)
+async def test_wrd_with_non_default_flush_interval(
+    storage_client,
+    blobs_to_delete,
+    flush_interval,
+):
+    object_name = f"test_basic_wrd-{str(uuid.uuid4())}"
+    object_size = 9 * 1024 * 1024
+
+    # Client instantiation; it cannot be part of a fixture because
+    # grpc_client's event loop and the event loop of the coroutine running it
+    # (i.e. this test) must be the same.
+    # Note:
+    # 1. @pytest.mark.asyncio ensures a new event loop for each test.
+    # 2. we could keep the same event loop for the entire module, but that may
+    #    create issues if tests are run in parallel and one test hogs the event
+    #    loop, slowing down other tests.
+    object_data = os.urandom(object_size)
+    object_checksum = google_crc32c.value(object_data)
+    grpc_client = AsyncGrpcClient().grpc_client
+
+    writer = AsyncAppendableObjectWriter(
+        grpc_client,
+        _ZONAL_BUCKET,
+        object_name,
+        writer_options={"FLUSH_INTERVAL_BYTES": flush_interval},
+    )
+    await writer.open()
+    mark1, mark2 = _get_equal_dist(0, object_size)
+    await writer.append(object_data[0:mark1])
+    await writer.append(object_data[mark1:mark2])
+    await writer.append(object_data[mark2:])
+    object_metadata = await writer.close(finalize_on_close=True)
+    assert object_metadata.size == object_size
+    assert int(object_metadata.checksums.crc32c) == object_checksum
+
+    mrd = AsyncMultiRangeDownloader(grpc_client, _ZONAL_BUCKET, object_name)
+    buffer = BytesIO()
+    await mrd.open()
+    # (0, 0) means read the whole object
+    await mrd.download_ranges([(0, 0, buffer)])
+    await mrd.close()
+    assert buffer.getvalue() == object_data
+    assert mrd.persisted_size == object_size
+
+    # Clean up; use the json client (i.e. the `storage_client` fixture) to delete.
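+    # (blobs queued on blobs_to_delete are assumed to be removed during the
+    # fixture's teardown)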
+ blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) + del writer + del mrd + gc.collect() + + +@pytest.mark.asyncio +async def test_read_unfinalized_appendable_object(storage_client, blobs_to_delete): + object_name = f"read_unfinalized_appendable_object-{str(uuid.uuid4())[:4]}" + grpc_client = AsyncGrpcClient(attempt_direct_path=True).grpc_client + + writer = AsyncAppendableObjectWriter(grpc_client, _ZONAL_BUCKET, object_name) + await writer.open() + await writer.append(_BYTES_TO_UPLOAD) + await writer.flush() + + mrd = AsyncMultiRangeDownloader(grpc_client, _ZONAL_BUCKET, object_name) + buffer = BytesIO() + await mrd.open() + assert mrd.persisted_size == len(_BYTES_TO_UPLOAD) + # (0, 0) means read the whole object + await mrd.download_ranges([(0, 0, buffer)]) + await mrd.close() + assert buffer.getvalue() == _BYTES_TO_UPLOAD + + # Clean up; use json client (i.e. `storage_client` fixture) to delete. + blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) + del writer + del mrd + gc.collect() + + +@pytest.mark.asyncio +async def test_mrd_open_with_read_handle(): + grpc_client = AsyncGrpcClient().grpc_client + object_name = f"test_read_handl-{str(uuid.uuid4())[:4]}" + writer = AsyncAppendableObjectWriter(grpc_client, _ZONAL_BUCKET, object_name) + await writer.open() + await writer.append(_BYTES_TO_UPLOAD) + await writer.close() + + mrd = AsyncMultiRangeDownloader(grpc_client, _ZONAL_BUCKET, object_name) + await mrd.open() + read_handle = mrd.read_handle + await mrd.close() + + # Open a new MRD using the `read_handle` obtained above + new_mrd = AsyncMultiRangeDownloader( + grpc_client, _ZONAL_BUCKET, object_name, read_handle=read_handle + ) + await new_mrd.open() + # persisted_size not set when opened with read_handle + assert new_mrd.persisted_size is None + buffer = BytesIO() + await new_mrd.download_ranges([(0, 0, buffer)]) + await new_mrd.close() + assert buffer.getvalue() == _BYTES_TO_UPLOAD + del mrd + del new_mrd + gc.collect() diff --git a/tests/unit/asyncio/retry/test_bidi_stream_retry_manager.py b/tests/unit/asyncio/retry/test_bidi_stream_retry_manager.py new file mode 100644 index 000000000..6c837ec5c --- /dev/null +++ b/tests/unit/asyncio/retry/test_bidi_stream_retry_manager.py @@ -0,0 +1,156 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
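+
+# Tests for _BidiStreamRetryManager.execute(): first-attempt success, empty
+# streams, retriable failures raised before and mid-stream, and fast failure
+# on non-retriable errors.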
+ +from unittest import mock + +import pytest +from google.api_core import exceptions +from google.api_core.retry_async import AsyncRetry + +from google.cloud.storage._experimental.asyncio.retry import ( + bidi_stream_retry_manager as manager, +) +from google.cloud.storage._experimental.asyncio.retry import base_strategy + + +def _is_retriable(exc): + return isinstance(exc, exceptions.ServiceUnavailable) + + +DEFAULT_TEST_RETRY = AsyncRetry(predicate=_is_retriable, deadline=1) + + +class TestBidiStreamRetryManager: + @pytest.mark.asyncio + async def test_execute_success_on_first_try(self): + mock_strategy = mock.AsyncMock(spec=base_strategy._BaseResumptionStrategy) + + async def mock_send_and_recv(*args, **kwargs): + yield "response_1" + + retry_manager = manager._BidiStreamRetryManager( + strategy=mock_strategy, send_and_recv=mock_send_and_recv + ) + await retry_manager.execute(initial_state={}, retry_policy=DEFAULT_TEST_RETRY) + mock_strategy.generate_requests.assert_called_once() + mock_strategy.update_state_from_response.assert_called_once_with( + "response_1", {} + ) + mock_strategy.recover_state_on_failure.assert_not_called() + + @pytest.mark.asyncio + async def test_execute_success_on_empty_stream(self): + mock_strategy = mock.AsyncMock(spec=base_strategy._BaseResumptionStrategy) + + async def mock_send_and_recv(*args, **kwargs): + if False: + yield + + retry_manager = manager._BidiStreamRetryManager( + strategy=mock_strategy, send_and_recv=mock_send_and_recv + ) + await retry_manager.execute(initial_state={}, retry_policy=DEFAULT_TEST_RETRY) + + mock_strategy.generate_requests.assert_called_once() + mock_strategy.update_state_from_response.assert_not_called() + mock_strategy.recover_state_on_failure.assert_not_called() + + @pytest.mark.asyncio + async def test_execute_retries_on_initial_failure_and_succeeds(self): + mock_strategy = mock.AsyncMock(spec=base_strategy._BaseResumptionStrategy) + attempt_count = 0 + + async def mock_send_and_recv(*args, **kwargs): + nonlocal attempt_count + attempt_count += 1 + if attempt_count == 1: + raise exceptions.ServiceUnavailable("Service is down") + else: + yield "response_2" + + retry_manager = manager._BidiStreamRetryManager( + strategy=mock_strategy, send_and_recv=mock_send_and_recv + ) + retry_policy = AsyncRetry(predicate=_is_retriable, initial=0.01) + + with mock.patch("asyncio.sleep", new_callable=mock.AsyncMock): + await retry_manager.execute(initial_state={}, retry_policy=retry_policy) + + assert attempt_count == 2 + assert mock_strategy.generate_requests.call_count == 2 + mock_strategy.recover_state_on_failure.assert_called_once() + mock_strategy.update_state_from_response.assert_called_once_with( + "response_2", {} + ) + + @pytest.mark.asyncio + async def test_execute_retries_and_succeeds_mid_stream(self): + """Test retry logic for a stream that fails after yielding some data.""" + mock_strategy = mock.AsyncMock(spec=base_strategy._BaseResumptionStrategy) + attempt_count = 0 + # Use a list to simulate stream content for each attempt + stream_content = [ + ["response_1", exceptions.ServiceUnavailable("Service is down")], + ["response_2"], + ] + + async def mock_send_and_recv(*args, **kwargs): + nonlocal attempt_count + content = stream_content[attempt_count] + attempt_count += 1 + for item in content: + if isinstance(item, Exception): + raise item + else: + yield item + + retry_manager = manager._BidiStreamRetryManager( + strategy=mock_strategy, send_and_recv=mock_send_and_recv + ) + retry_policy = AsyncRetry(predicate=_is_retriable, 
initial=0.01) + + with mock.patch("asyncio.sleep", new_callable=mock.AsyncMock) as mock_sleep: + await retry_manager.execute(initial_state={}, retry_policy=retry_policy) + + assert attempt_count == 2 + mock_sleep.assert_called_once() + + assert mock_strategy.generate_requests.call_count == 2 + mock_strategy.recover_state_on_failure.assert_called_once() + assert mock_strategy.update_state_from_response.call_count == 2 + mock_strategy.update_state_from_response.assert_has_calls( + [ + mock.call("response_1", {}), + mock.call("response_2", {}), + ] + ) + + @pytest.mark.asyncio + async def test_execute_fails_immediately_on_non_retriable_error(self): + mock_strategy = mock.AsyncMock(spec=base_strategy._BaseResumptionStrategy) + + async def mock_send_and_recv(*args, **kwargs): + if False: + yield + raise exceptions.PermissionDenied("Auth error") + + retry_manager = manager._BidiStreamRetryManager( + strategy=mock_strategy, send_and_recv=mock_send_and_recv + ) + with pytest.raises(exceptions.PermissionDenied): + await retry_manager.execute( + initial_state={}, retry_policy=DEFAULT_TEST_RETRY + ) + + mock_strategy.recover_state_on_failure.assert_not_called() diff --git a/tests/unit/asyncio/retry/test_reads_resumption_strategy.py b/tests/unit/asyncio/retry/test_reads_resumption_strategy.py index e6b343f86..2ddd87f1f 100644 --- a/tests/unit/asyncio/retry/test_reads_resumption_strategy.py +++ b/tests/unit/asyncio/retry/test_reads_resumption_strategy.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import io import unittest -import pytest +from google_crc32c import Checksum from google.cloud.storage.exceptions import DataCorruption from google.api_core import exceptions @@ -26,6 +27,7 @@ from google.cloud._storage_v2.types.storage import BidiReadObjectRedirectedError _READ_ID = 1 +LOGGER_NAME = "google.cloud.storage._experimental.asyncio.retry.reads_resumption_strategy" class TestDownloadState(unittest.TestCase): @@ -45,14 +47,67 @@ def test_initialization(self): class TestReadResumptionStrategy(unittest.TestCase): + def setUp(self): + self.strategy = _ReadResumptionStrategy() + + self.state = {"download_states": {}, "read_handle": None, "routing_token": None} + + def _add_download(self, read_id, offset=0, length=100, buffer=None): + """Helper to inject a download state into the correct nested location.""" + if buffer is None: + buffer = io.BytesIO() + state = _DownloadState( + initial_offset=offset, initial_length=length, user_buffer=buffer + ) + self.state["download_states"][read_id] = state + return state + + def _create_response( + self, + content, + read_id, + offset, + crc=None, + range_end=False, + handle=None, + has_read_range=True, + ): + """Helper to create a response object.""" + checksummed_data = None + if content is not None: + if crc is None: + c = Checksum(content) + crc = int.from_bytes(c.digest(), "big") + checksummed_data = storage_v2.ChecksummedData(content=content, crc32c=crc) + + read_range = None + if has_read_range: + read_range = storage_v2.ReadRange(read_id=read_id, read_offset=offset) + + read_handle_message = None + if handle: + read_handle_message = storage_v2.BidiReadHandle(handle=handle) + self.state["read_handle"] = handle + + return storage_v2.BidiReadObjectResponse( + object_data_ranges=[ + storage_v2.ObjectRangeData( + checksummed_data=checksummed_data, + read_range=read_range, + range_end=range_end, + ) + ], + read_handle=read_handle_message, + ) + + # --- Request Generation 
Tests --- + def test_generate_requests_single_incomplete(self): """Test generating a request for a single incomplete download.""" - read_state = _DownloadState(0, 100, io.BytesIO()) + read_state = self._add_download(_READ_ID, offset=0, length=100) read_state.bytes_written = 20 - state = {_READ_ID: read_state} - read_strategy = _ReadResumptionStrategy() - requests = read_strategy.generate_requests(state) + requests = self.strategy.generate_requests(self.state) self.assertEqual(len(requests), 1) self.assertEqual(requests[0].read_offset, 20) @@ -62,173 +117,238 @@ def test_generate_requests_single_incomplete(self): def test_generate_requests_multiple_incomplete(self): """Test generating requests for multiple incomplete downloads.""" read_id2 = 2 - read_state1 = _DownloadState(0, 100, io.BytesIO()) - read_state1.bytes_written = 50 - read_state2 = _DownloadState(200, 100, io.BytesIO()) - state = {_READ_ID: read_state1, read_id2: read_state2} + rs1 = self._add_download(_READ_ID, offset=0, length=100) + rs1.bytes_written = 50 + + self._add_download(read_id2, offset=200, length=100) - read_strategy = _ReadResumptionStrategy() - requests = read_strategy.generate_requests(state) + requests = self.strategy.generate_requests(self.state) self.assertEqual(len(requests), 2) - req1 = next(request for request in requests if request.read_id == _READ_ID) - req2 = next(request for request in requests if request.read_id == read_id2) + requests.sort(key=lambda r: r.read_id) + req1 = requests[0] + req2 = requests[1] + + self.assertEqual(req1.read_id, _READ_ID) self.assertEqual(req1.read_offset, 50) self.assertEqual(req1.read_length, 50) + + self.assertEqual(req2.read_id, read_id2) self.assertEqual(req2.read_offset, 200) self.assertEqual(req2.read_length, 100) + def test_generate_requests_read_to_end_resumption(self): + """Test resumption for 'read to end' (length=0) requests.""" + read_state = self._add_download(_READ_ID, offset=0, length=0) + read_state.bytes_written = 500 + + requests = self.strategy.generate_requests(self.state) + + self.assertEqual(len(requests), 1) + self.assertEqual(requests[0].read_offset, 500) + self.assertEqual(requests[0].read_length, 0) + def test_generate_requests_with_complete(self): """Test that no request is generated for a completed download.""" - read_state = _DownloadState(0, 100, io.BytesIO()) + read_state = self._add_download(_READ_ID) read_state.is_complete = True - state = {_READ_ID: read_state} - - read_strategy = _ReadResumptionStrategy() - requests = read_strategy.generate_requests(state) + requests = self.strategy.generate_requests(self.state) self.assertEqual(len(requests), 0) + def test_generate_requests_multiple_mixed_states(self): + """Test generating requests with mixed complete, partial, and fresh states.""" + s1 = self._add_download(1, length=100) + s1.is_complete = True + + s2 = self._add_download(2, offset=0, length=100) + s2.bytes_written = 50 + + s3 = self._add_download(3, offset=200, length=100) + s3.bytes_written = 0 + + requests = self.strategy.generate_requests(self.state) + + self.assertEqual(len(requests), 2) + requests.sort(key=lambda r: r.read_id) + + self.assertEqual(requests[0].read_id, 2) + self.assertEqual(requests[1].read_id, 3) + def test_generate_requests_empty_state(self): """Test generating requests with an empty state.""" - read_strategy = _ReadResumptionStrategy() - requests = read_strategy.generate_requests({}) + requests = self.strategy.generate_requests(self.state) self.assertEqual(len(requests), 0) + # --- Update State and response 
processing Tests --- + def test_update_state_processes_single_chunk_successfully(self): """Test updating state from a successful response.""" - buffer = io.BytesIO() - read_state = _DownloadState(0, 100, buffer) - state = {_READ_ID: read_state} + read_state = self._add_download(_READ_ID, offset=0, length=100) data = b"test_data" - read_strategy = _ReadResumptionStrategy() - response = storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - storage_v2.types.ObjectRangeData( - read_range=storage_v2.ReadRange( - read_id=_READ_ID, read_offset=0, read_length=len(data) - ), - checksummed_data=storage_v2.ChecksummedData(content=data), - ) - ] - ) + response = self._create_response(data, _READ_ID, offset=0) - read_strategy.update_state_from_response(response, state) + self.strategy.update_state_from_response(response, self.state) self.assertEqual(read_state.bytes_written, len(data)) self.assertEqual(read_state.next_expected_offset, len(data)) self.assertFalse(read_state.is_complete) - self.assertEqual(buffer.getvalue(), data) + self.assertEqual(read_state.user_buffer.getvalue(), data) + + def test_update_state_accumulates_chunks(self): + """Verify that state updates correctly over multiple chunks.""" + read_state = self._add_download(_READ_ID, offset=0, length=8) + + resp1 = self._create_response(b"test", _READ_ID, offset=0) + self.strategy.update_state_from_response(resp1, self.state) + + self.assertEqual(read_state.bytes_written, 4) + self.assertEqual(read_state.user_buffer.getvalue(), b"test") + + resp2 = self._create_response(b"data", _READ_ID, offset=4, range_end=True) + self.strategy.update_state_from_response(resp2, self.state) + + self.assertEqual(read_state.bytes_written, 8) + self.assertTrue(read_state.is_complete) + self.assertEqual(read_state.user_buffer.getvalue(), b"testdata") + + def test_update_state_captures_read_handle(self): + """Verify read_handle is extracted from the response.""" + self._add_download(_READ_ID) + + new_handle = b"optimized_handle" + response = self._create_response(b"data", _READ_ID, 0, handle=new_handle) - def test_update_state_from_response_offset_mismatch(self): + self.strategy.update_state_from_response(response, self.state) + self.assertEqual(self.state["read_handle"].handle, new_handle) + + def test_update_state_unknown_id(self): + """Verify we ignore data for IDs not in our tracking state.""" + self._add_download(_READ_ID) + response = self._create_response(b"ghost", read_id=999, offset=0) + + self.strategy.update_state_from_response(response, self.state) + self.assertEqual(self.state["download_states"][_READ_ID].bytes_written, 0) + + def test_update_state_missing_read_range(self): + """Verify we ignore ranges without read_range metadata.""" + response = self._create_response(b"data", _READ_ID, 0, has_read_range=False) + self.strategy.update_state_from_response(response, self.state) + + def test_update_state_offset_mismatch(self): """Test that an offset mismatch raises DataCorruption.""" - read_state = _DownloadState(0, 100, io.BytesIO()) + read_state = self._add_download(_READ_ID, offset=0) read_state.next_expected_offset = 10 - state = {_READ_ID: read_state} - read_strategy = _ReadResumptionStrategy() - response = storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - storage_v2.types.ObjectRangeData( - read_range=storage_v2.ReadRange( - read_id=_READ_ID, read_offset=0, read_length=4 - ), - checksummed_data=storage_v2.ChecksummedData(content=b"data"), - ) - ] - ) + response = self._create_response(b"data", _READ_ID, offset=0) - with 
pytest.raises(DataCorruption) as exc_info: - read_strategy.update_state_from_response(response, state) - assert "Offset mismatch" in str(exc_info.value) + with self.assertRaisesRegex(DataCorruption, "Offset mismatch"): + self.strategy.update_state_from_response(response, self.state) - def test_update_state_from_response_final_byte_count_mismatch(self): - """Test that a final byte count mismatch raises DataCorruption.""" - read_state = _DownloadState(0, 100, io.BytesIO()) - state = {_READ_ID: read_state} - read_strategy = _ReadResumptionStrategy() + def test_update_state_checksum_mismatch(self): + """Test that a CRC32C mismatch raises DataCorruption.""" + self._add_download(_READ_ID) + response = self._create_response(b"data", _READ_ID, offset=0, crc=999999) - response = storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - storage_v2.types.ObjectRangeData( - read_range=storage_v2.ReadRange( - read_id=_READ_ID, read_offset=0, read_length=4 - ), - checksummed_data=storage_v2.ChecksummedData(content=b"data"), - range_end=True, - ) - ] - ) + with self.assertRaisesRegex(DataCorruption, "Checksum mismatch"): + self.strategy.update_state_from_response(response, self.state) + + def test_update_state_final_byte_count_mismatch(self): + """Test mismatch between expected length and actual bytes written on completion.""" + self._add_download(_READ_ID, length=100) - with pytest.raises(DataCorruption) as exc_info: - read_strategy.update_state_from_response(response, state) - assert "Byte count mismatch" in str(exc_info.value) + data = b"data" * 30 + response = self._create_response(data, _READ_ID, offset=0, range_end=True) - def test_update_state_from_response_completes_download(self): + with self.assertRaisesRegex(DataCorruption, "Byte count mismatch"): + self.strategy.update_state_from_response(response, self.state) + + def test_update_state_completes_download(self): """Test that the download is marked complete on range_end.""" - buffer = io.BytesIO() data = b"test_data" - read_state = _DownloadState(0, len(data), buffer) - state = {_READ_ID: read_state} - read_strategy = _ReadResumptionStrategy() + read_state = self._add_download(_READ_ID, length=len(data)) - response = storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - storage_v2.types.ObjectRangeData( - read_range=storage_v2.ReadRange( - read_id=_READ_ID, read_offset=0, read_length=len(data) - ), - checksummed_data=storage_v2.ChecksummedData(content=data), - range_end=True, - ) - ] - ) + response = self._create_response(data, _READ_ID, offset=0, range_end=True) - read_strategy.update_state_from_response(response, state) + self.strategy.update_state_from_response(response, self.state) self.assertTrue(read_state.is_complete) self.assertEqual(read_state.bytes_written, len(data)) - self.assertEqual(buffer.getvalue(), data) - def test_update_state_from_response_completes_download_zero_length(self): + def test_update_state_completes_download_zero_length(self): """Test completion for a download with initial_length of 0.""" - buffer = io.BytesIO() + read_state = self._add_download(_READ_ID, length=0) data = b"test_data" - read_state = _DownloadState(0, 0, buffer) - state = {_READ_ID: read_state} - read_strategy = _ReadResumptionStrategy() - response = storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - storage_v2.types.ObjectRangeData( - read_range=storage_v2.ReadRange( - read_id=_READ_ID, read_offset=0, read_length=len(data) - ), - checksummed_data=storage_v2.ChecksummedData(content=data), - range_end=True, - ) - ] - ) + response = 
self._create_response(data, _READ_ID, offset=0, range_end=True)
-        read_strategy.update_state_from_response(response, state)
+        self.strategy.update_state_from_response(response, self.state)
 
         self.assertTrue(read_state.is_complete)
         self.assertEqual(read_state.bytes_written, len(data))
 
-    async def test_recover_state_on_failure_handles_redirect(self):
+    def test_update_state_zero_byte_file(self):
+        """Test downloading a completely empty file."""
+        read_state = self._add_download(_READ_ID, length=0)
+
+        response = self._create_response(b"", _READ_ID, offset=0, range_end=True)
+
+        self.strategy.update_state_from_response(response, self.state)
+
+        self.assertTrue(read_state.is_complete)
+        self.assertEqual(read_state.bytes_written, 0)
+        self.assertEqual(read_state.user_buffer.getvalue(), b"")
+
+    def test_update_state_missing_read_range_logs_warning(self):
+        """Verify we log a warning and continue when read_range is missing."""
+        response = self._create_response(b"data", _READ_ID, 0, has_read_range=False)
+
+        # assertLogs captures logs for the given logger name and minimum level
+        with self.assertLogs(LOGGER_NAME, level="WARNING") as cm:
+            self.strategy.update_state_from_response(response, self.state)
+
+        self.assertTrue(any("missing read_range field" in output for output in cm.output))
+
+    def test_update_state_unknown_id_logs_warning(self):
+        """Verify we log a warning and continue when read_id is unknown."""
+        unknown_id = 999
+        self._add_download(_READ_ID)
+        response = self._create_response(b"ghost", read_id=unknown_id, offset=0)
+
+        with self.assertLogs(LOGGER_NAME, level="WARNING") as cm:
+            self.strategy.update_state_from_response(response, self.state)
+
+        self.assertTrue(any(f"unknown or stale read_id {unknown_id}" in output for output in cm.output))
+
+
+    # --- Recovery Tests ---
+
+    def test_recover_state_on_failure_handles_redirect(self):
         """Verify recover_state_on_failure correctly extracts routing_token."""
-        strategy = _ReadResumptionStrategy()
+        token = "dummy-routing-token"
+        redirect_error = BidiReadObjectRedirectedError(routing_token=token)
+        final_error = exceptions.Aborted("Retry failed", errors=[redirect_error])
+
+        async def run():
+            await self.strategy.recover_state_on_failure(final_error, self.state)
+
+        asyncio.run(run())
+
+        self.assertEqual(self.state["routing_token"], token)
 
-        state = {}
-        self.assertIsNone(state.get("routing_token"))
+    def test_recover_state_ignores_standard_errors(self):
+        """Verify that non-redirect errors do not corrupt the routing token."""
+        self.state["routing_token"] = "existing-token"
 
-        dummy_token = "dummy-routing-token"
-        redirect_error = BidiReadObjectRedirectedError(routing_token=dummy_token)
+        std_error = exceptions.ServiceUnavailable("Maintenance")
+        final_error = exceptions.RetryError("Retry failed", cause=std_error)
 
-        final_error = exceptions.RetryError("Retry failed", cause=redirect_error)
+        async def run():
+            await self.strategy.recover_state_on_failure(final_error, self.state)
 
-        await strategy.recover_state_on_failure(final_error, state)
+        asyncio.run(run())
 
-        self.assertEqual(state.get("routing_token"), dummy_token)
+        # Token should remain unchanged
+        self.assertEqual(self.state["routing_token"], "existing-token")
diff --git a/tests/unit/asyncio/test_async_appendable_object_writer.py b/tests/unit/asyncio/test_async_appendable_object_writer.py
index a75824f8b..31013f9a7 100644
--- a/tests/unit/asyncio/test_async_appendable_object_writer.py
+++ b/tests/unit/asyncio/test_async_appendable_object_writer.py
@@ -12,12 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from io import BytesIO
 import pytest
 from unittest import mock
 
+from google_crc32c import Checksum
+
+from google.api_core import exceptions
 from google.cloud.storage._experimental.asyncio.async_appendable_object_writer import (
     AsyncAppendableObjectWriter,
 )
+from google.cloud.storage._experimental.asyncio.async_appendable_object_writer import (
+    _MAX_CHUNK_SIZE_BYTES,
+    _DEFAULT_FLUSH_INTERVAL_BYTES,
+)
 from google.cloud import _storage_v2
 
 
@@ -26,6 +34,7 @@
 GENERATION = 123
 WRITE_HANDLE = b"test-write-handle"
 PERSISTED_SIZE = 456
+EIGHT_MIB = 8 * 1024 * 1024
 
 
 @pytest.fixture
@@ -49,6 +58,7 @@ def test_init(mock_write_object_stream, mock_client):
     assert not writer._is_stream_open
     assert writer.offset is None
     assert writer.persisted_size is None
+    assert writer.bytes_appended_since_last_flush == 0
 
     mock_write_object_stream.assert_called_once_with(
         client=mock_client,
@@ -75,6 +85,7 @@ def test_init_with_optional_args(mock_write_object_stream, mock_client):
 
     assert writer.generation == GENERATION
     assert writer.write_handle == WRITE_HANDLE
+    assert writer.bytes_appended_since_last_flush == 0
 
     mock_write_object_stream.assert_called_once_with(
         client=mock_client,
@@ -85,6 +96,77 @@ def test_init_with_optional_args(mock_write_object_stream, mock_client):
     )
 
 
+@mock.patch(
+    "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream"
+)
+def test_init_with_writer_options(mock_write_object_stream, mock_client):
+    """Test the constructor with writer_options."""
+    writer = AsyncAppendableObjectWriter(
+        mock_client,
+        BUCKET,
+        OBJECT,
+        writer_options={"FLUSH_INTERVAL_BYTES": EIGHT_MIB},
+    )
+
+    assert writer.flush_interval == EIGHT_MIB
+    assert writer.bytes_appended_since_last_flush == 0
+
+    mock_write_object_stream.assert_called_once_with(
+        client=mock_client,
+        bucket_name=BUCKET,
+        object_name=OBJECT,
+        generation_number=None,
+        write_handle=None,
+    )
+
+
+@mock.patch(
+    "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream"
+)
+def test_init_with_flush_interval_less_than_chunk_size_raises_error(
+    mock_write_object_stream, mock_client
+):
+    """Test that an OutOfRange error is raised if flush_interval is less than the chunk size."""
+
+    with pytest.raises(exceptions.OutOfRange):
+        AsyncAppendableObjectWriter(
+            mock_client,
+            BUCKET,
+            OBJECT,
+            writer_options={"FLUSH_INTERVAL_BYTES": _MAX_CHUNK_SIZE_BYTES - 1},
+        )
+
+
+@mock.patch(
+    "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream"
+)
+def test_init_with_flush_interval_not_multiple_of_chunk_size_raises_error(
+    mock_write_object_stream, mock_client
+):
+    """Test that an OutOfRange error is raised if flush_interval is not a multiple of the chunk size."""
+
+    with pytest.raises(exceptions.OutOfRange):
+        AsyncAppendableObjectWriter(
+            mock_client,
+            BUCKET,
+            OBJECT,
+            writer_options={"FLUSH_INTERVAL_BYTES": _MAX_CHUNK_SIZE_BYTES + 1},
+        )
+
+
+@mock.patch("google.cloud.storage._experimental.asyncio._utils.google_crc32c")
+@mock.patch(
+    "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client"
+)
+def test_init_raises_if_crc32c_c_extension_is_missing(
+    mock_grpc_client, mock_google_crc32c
+):
+    mock_google_crc32c.implementation = "python"
+
+    with pytest.raises(exceptions.FailedPrecondition) as exc_info:
+        AsyncAppendableObjectWriter(mock_grpc_client, "bucket", "object")
+
+    assert "The google-crc32c package is not installed with C support" in str(
+        exc_info.value
+    )
+
+
 @pytest.mark.asyncio
 @mock.patch(
     "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream"
@@ -133,15 +215,10 @@ async def test_open_appendable_object_writer(mock_write_object_stream, mock_clie
     writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT)
     mock_stream = mock_write_object_stream.return_value
     mock_stream.open = mock.AsyncMock()
-    mock_stream.send = mock.AsyncMock()
-    mock_stream.recv = mock.AsyncMock()
-
-    mock_state_response = mock.MagicMock()
-    mock_state_response.persisted_size = 1024
-    mock_stream.recv.return_value = mock_state_response
 
     mock_stream.generation_number = GENERATION
     mock_stream.write_handle = WRITE_HANDLE
+    mock_stream.persisted_size = 0
 
     # Act
     await writer.open()
@@ -151,11 +228,37 @@ async def test_open_appendable_object_writer(mock_write_object_stream, mock_clie
     assert writer._is_stream_open
     assert writer.generation == GENERATION
     assert writer.write_handle == WRITE_HANDLE
+    assert writer.persisted_size == 0
 
-    expected_request = _storage_v2.BidiWriteObjectRequest(state_lookup=True)
-    mock_stream.send.assert_awaited_once_with(expected_request)
-    mock_stream.recv.assert_awaited_once()
-    assert writer.persisted_size == 1024
+
+@pytest.mark.asyncio
+@mock.patch(
+    "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream"
+)
+async def test_open_appendable_object_writer_existing_object(
+    mock_write_object_stream, mock_client
+):
+    """Test the open method for an existing object."""
+    # Arrange
+    writer = AsyncAppendableObjectWriter(
+        mock_client, BUCKET, OBJECT, generation=GENERATION
+    )
+    mock_stream = mock_write_object_stream.return_value
+    mock_stream.open = mock.AsyncMock()
+
+    mock_stream.generation_number = GENERATION
+    mock_stream.write_handle = WRITE_HANDLE
+    mock_stream.persisted_size = PERSISTED_SIZE
+
+    # Act
+    await writer.open()
+
+    # Assert
+    mock_stream.open.assert_awaited_once()
+    assert writer._is_stream_open
+    assert writer.generation == GENERATION
+    assert writer.write_handle == WRITE_HANDLE
+    assert writer.persisted_size == PERSISTED_SIZE
 
 
 @pytest.mark.asyncio
@@ -186,9 +289,6 @@ async def test_unimplemented_methods_raise_error(mock_client):
     with pytest.raises(NotImplementedError):
         await writer.append_from_stream(mock.Mock())
 
-    with pytest.raises(NotImplementedError):
-        await writer.append_from_file("file.txt")
-
 
 @pytest.mark.asyncio
 @mock.patch(
@@ -313,7 +413,7 @@ async def test_finalize_on_close(mock_write_object_stream, mock_client):
     result = await writer.close(finalize_on_close=True)
 
     # Assert
-    mock_stream.close.assert_not_awaited()  # Based on new implementation
+    mock_stream.close.assert_awaited_once()
     assert not writer._is_stream_open
     assert writer.offset is None
     assert writer.object_resource == mock_resource
@@ -413,10 +513,15 @@ async def test_append_sends_data_in_chunks(mock_write_object_stream, mock_client
     # First chunk
     assert first_call[0][0].write_offset == 100
     assert len(first_call[0][0].checksummed_data.content) == _MAX_CHUNK_SIZE_BYTES
-
+    assert first_call[0][0].checksummed_data.crc32c == int.from_bytes(
+        Checksum(data[:_MAX_CHUNK_SIZE_BYTES]).digest(), byteorder="big"
+    )
     # Second chunk
     assert second_call[0][0].write_offset == 100 + _MAX_CHUNK_SIZE_BYTES
     assert len(second_call[0][0].checksummed_data.content) == 1
+    assert second_call[0][0].checksummed_data.crc32c == int.from_bytes(
+        Checksum(data[_MAX_CHUNK_SIZE_BYTES:]).digest(), byteorder="big"
+    )
 
     assert writer.offset == 100 + len(data)
     writer.simple_flush.assert_not_awaited()
@@ -430,9 +535,6 @@ async def test_append_flushes_when_buffer_is_full(
     mock_write_object_stream, mock_client
 ):
     """Test that append flushes the stream when the buffer size is reached."""
-    from google.cloud.storage._experimental.asyncio.async_appendable_object_writer import (
-        _MAX_BUFFER_SIZE_BYTES,
-    )
 
     writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT)
     writer._is_stream_open = True
@@ -441,7 +543,7 @@ async def test_append_flushes_when_buffer_is_full(
     mock_stream.send = mock.AsyncMock()
     writer.simple_flush = mock.AsyncMock()
 
-    data = b"a" * _MAX_BUFFER_SIZE_BYTES
+    data = b"a" * _DEFAULT_FLUSH_INTERVAL_BYTES
     await writer.append(data)
 
     writer.simple_flush.assert_awaited_once()
@@ -453,9 +555,6 @@ async def test_append_flushes_when_buffer_is_full(
 )
 async def test_append_handles_large_data(mock_write_object_stream, mock_client):
     """Test that append handles data larger than the buffer size."""
-    from google.cloud.storage._experimental.asyncio.async_appendable_object_writer import (
-        _MAX_BUFFER_SIZE_BYTES,
-    )
 
     writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT)
     writer._is_stream_open = True
@@ -464,7 +563,7 @@ async def test_append_handles_large_data(mock_write_object_stream, mock_client):
     mock_stream.send = mock.AsyncMock()
     writer.simple_flush = mock.AsyncMock()
 
-    data = b"a" * (_MAX_BUFFER_SIZE_BYTES * 2 + 1)
+    data = b"a" * (_DEFAULT_FLUSH_INTERVAL_BYTES * 2 + 1)
    await writer.append(data)
 
     assert writer.simple_flush.await_count == 2
@@ -496,3 +595,32 @@ async def test_append_data_two_times(mock_write_object_stream, mock_client):
     total_data_length = len(data1) + len(data2)
     assert writer.offset == total_data_length
     assert writer.simple_flush.await_count == 0
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "file_size, block_size",
+    [
+        (10, 4 * 1024),
+        (0, _DEFAULT_FLUSH_INTERVAL_BYTES),
+        (20 * 1024 * 1024, _DEFAULT_FLUSH_INTERVAL_BYTES),
+        (16 * 1024 * 1024, _DEFAULT_FLUSH_INTERVAL_BYTES),
+    ],
+)
+async def test_append_from_file(file_size, block_size, mock_client):
+    # arrange
+    fp = BytesIO(b"a" * file_size)
+    writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT)
+    writer._is_stream_open = True
+    writer.append = mock.AsyncMock()
+
+    # act
+    await writer.append_from_file(fp, block_size=block_size)
+
+    # assert
+    expected_calls = (
+        file_size // block_size
+        if file_size % block_size == 0
+        else file_size // block_size + 1
+    )
+    assert writer.append.await_count == expected_calls
diff --git a/tests/unit/asyncio/test_async_multi_range_downloader.py b/tests/unit/asyncio/test_async_multi_range_downloader.py
index 668006627..2f0600f8d 100644
--- a/tests/unit/asyncio/test_async_multi_range_downloader.py
+++ b/tests/unit/asyncio/test_async_multi_range_downloader.py
@@ -30,6 +30,7 @@
 
 _TEST_BUCKET_NAME = "test-bucket"
 _TEST_OBJECT_NAME = "test-object"
+_TEST_OBJECT_SIZE = 1024 * 1024  # 1 MiB
 _TEST_GENERATION_NUMBER = 123456789
 _TEST_READ_HANDLE = b"test-handle"
 
@@ -38,9 +39,7 @@ class TestAsyncMultiRangeDownloader:
     def create_read_ranges(self, num_ranges):
         ranges = []
         for i in range(num_ranges):
-            ranges.append(
-                _storage_v2.ReadRange(read_offset=i, read_length=1, read_id=i)
-            )
+            ranges.append((i, 1, BytesIO()))
         return ranges
 
     # helper method
@@ -57,6 +56,7 @@ async def _make_mock_mrd(
         mock_stream = mock_cls_async_read_object_stream.return_value
         mock_stream.open = AsyncMock()
         mock_stream.generation_number = _TEST_GENERATION_NUMBER
+        mock_stream.persisted_size = _TEST_OBJECT_SIZE
         mock_stream.read_handle = _TEST_READ_HANDLE
 
         mrd = await AsyncMultiRangeDownloader.create_mrd(
@@ -89,16 +89,6 @@
             read_handle=_TEST_READ_HANDLE,
         )
 
-        mrd.read_obj_str.open.assert_called_once()
-        # Assert
-        mock_cls_async_read_object_stream.assert_called_once_with(
-            client=mock_grpc_client,
-            bucket_name=_TEST_BUCKET_NAME,
-            object_name=_TEST_OBJECT_NAME,
-            generation_number=_TEST_GENERATION_NUMBER,
-            read_handle=_TEST_READ_HANDLE,
-        )
-
         mrd.read_obj_str.open.assert_called_once()
         assert mrd.client == mock_grpc_client
 
@@ -106,6 +96,7 @@
         assert mrd.object_name == _TEST_OBJECT_NAME
         assert mrd.generation_number == _TEST_GENERATION_NUMBER
         assert mrd.read_handle == _TEST_READ_HANDLE
+        assert mrd.persisted_size == _TEST_OBJECT_SIZE
         assert mrd.is_stream_open
 
     @mock.patch(
@@ -132,7 +123,9 @@ async def test_download_ranges_via_async_gather(
         mock_mrd = await self._make_mock_mrd(
             mock_grpc_client, mock_cls_async_read_object_stream
         )
-        mock_random_int.side_effect = [123, 456, 789, 91011]  # for _func_id and read_id
+
+        mock_random_int.side_effect = [456, 91011]
+
         mock_mrd.read_obj_str.send = AsyncMock()
         mock_mrd.read_obj_str.recv = AsyncMock()
 
@@ -164,12 +157,14 @@
                     )
                 ],
             ),
+            None,
         ]
 
         # Act
         buffer = BytesIO()
         second_buffer = BytesIO()
         lock = asyncio.Lock()
+
         task1 = asyncio.create_task(mock_mrd.download_ranges([(0, 18, buffer)], lock))
         task2 = asyncio.create_task(
             mock_mrd.download_ranges([(10, 6, second_buffer)], lock)
@@ -177,18 +172,6 @@
         )
         await asyncio.gather(task1, task2)
 
         # Assert
-        mock_mrd.read_obj_str.send.side_effect = [
-            _storage_v2.BidiReadObjectRequest(
-                read_ranges=[
-                    _storage_v2.ReadRange(read_offset=0, read_length=18, read_id=456)
-                ]
-            ),
-            _storage_v2.BidiReadObjectRequest(
-                read_ranges=[
-                    _storage_v2.ReadRange(read_offset=10, read_length=6, read_id=91011)
-                ]
-            ),
-        ]
         assert buffer.getvalue() == data
         assert second_buffer.getvalue() == data[10:16]
 
@@ -213,22 +196,27 @@ async def test_download_ranges(
         mock_mrd = await self._make_mock_mrd(
             mock_grpc_client, mock_cls_async_read_object_stream
        )
-        mock_random_int.side_effect = [123, 456]  # for _func_id and read_id
+
+        mock_random_int.side_effect = [456]
+
         mock_mrd.read_obj_str.send = AsyncMock()
         mock_mrd.read_obj_str.recv = AsyncMock()
-        mock_mrd.read_obj_str.recv.return_value = _storage_v2.BidiReadObjectResponse(
-            object_data_ranges=[
-                _storage_v2.ObjectRangeData(
-                    checksummed_data=_storage_v2.ChecksummedData(
-                        content=data, crc32c=crc32c_int
-                    ),
-                    range_end=True,
-                    read_range=_storage_v2.ReadRange(
-                        read_offset=0, read_length=18, read_id=456
-                    ),
-                )
-            ],
-        )
+        mock_mrd.read_obj_str.recv.side_effect = [
+            _storage_v2.BidiReadObjectResponse(
+                object_data_ranges=[
+                    _storage_v2.ObjectRangeData(
+                        checksummed_data=_storage_v2.ChecksummedData(
+                            content=data, crc32c=crc32c_int
+                        ),
+                        range_end=True,
+                        read_range=_storage_v2.ReadRange(
+                            read_offset=0, read_length=18, read_id=456
+                        ),
+                    )
+                ],
+            ),
+            None,
+        ]
 
         # Act
         buffer = BytesIO()
@@ -317,7 +305,6 @@ async def test_close_mrd_not_opened_should_throw_error(self, mock_grpc_client):
         mrd = AsyncMultiRangeDownloader(
             mock_grpc_client, _TEST_BUCKET_NAME, _TEST_OBJECT_NAME
         )
-
         # Act + Assert
         with pytest.raises(ValueError) as exc:
             await mrd.close()
@@ -346,9 +333,7 @@ async def test_downloading_without_opening_should_throw_error(
         assert str(exc.value) == "Underlying bidi-gRPC stream is not open"
         assert not mrd.is_stream_open
 
-    @mock.patch(
-        "google.cloud.storage._experimental.asyncio.async_multi_range_downloader.google_crc32c"
-    )
+    @mock.patch("google.cloud.storage._experimental.asyncio._utils.google_crc32c")
     @mock.patch(
         "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client"
     )
@@ -357,7 +342,7 @@ def test_init_raises_if_crc32c_c_extension_is_missing(
     ):
         mock_google_crc32c.implementation = "python"
 
-        with pytest.raises(exceptions.NotFound) as exc_info:
+        with pytest.raises(exceptions.FailedPrecondition) as exc_info:
             AsyncMultiRangeDownloader(mock_grpc_client, "bucket", "object")
 
         assert "The google-crc32c package is not installed with C support" in str(
@@ -366,7 +351,7 @@
 
     @pytest.mark.asyncio
     @mock.patch(
-        "google.cloud.storage._experimental.asyncio.async_multi_range_downloader.Checksum"
+        "google.cloud.storage._experimental.asyncio.retry.reads_resumption_strategy.Checksum"
     )
     @mock.patch(
         "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client"
@@ -374,6 +359,10 @@ def test_init_raises_if_crc32c_c_extension_is_missing(
     )
     async def test_download_ranges_raises_on_checksum_mismatch(
         self, mock_client, mock_checksum_class
     ):
+        from google.cloud.storage._experimental.asyncio.async_multi_range_downloader import (
+            AsyncMultiRangeDownloader,
+        )
+
         mock_stream = mock.AsyncMock(
             spec=async_read_object_stream._AsyncReadObjectStream
         )
@@ -389,7 +378,9 @@ async def test_download_ranges_raises_on_checksum_mismatch(
                     checksummed_data=_storage_v2.ChecksummedData(
                         content=test_data, crc32c=server_checksum
                     ),
-                    read_range=_storage_v2.ReadRange(read_id=0),
+                    read_range=_storage_v2.ReadRange(
+                        read_id=0, read_offset=0, read_length=len(test_data)
+                    ),
                     range_end=True,
                 )
             ]
@@ -402,7 +393,11 @@ async def test_download_ranges_raises_on_checksum_mismatch(
         mrd._is_stream_open = True
 
         with pytest.raises(DataCorruption) as exc_info:
-            await mrd.download_ranges([(0, len(test_data), BytesIO())])
+            with mock.patch(
+                "google.cloud.storage._experimental.asyncio.async_multi_range_downloader.generate_random_56_bit_integer",
+                return_value=0,
+            ):
+                await mrd.download_ranges([(0, len(test_data), BytesIO())])
 
         assert "Checksum mismatch" in str(exc_info.value)
         mock_checksum_class.assert_called_once_with(test_data)
diff --git a/tests/unit/asyncio/test_async_read_object_stream.py b/tests/unit/asyncio/test_async_read_object_stream.py
index 4e4c93dd3..4ba8d34a1 100644
--- a/tests/unit/asyncio/test_async_read_object_stream.py
+++ b/tests/unit/asyncio/test_async_read_object_stream.py
@@ -25,7 +25,9 @@
 
 _TEST_BUCKET_NAME = "test-bucket"
 _TEST_OBJECT_NAME = "test-object"
 _TEST_GENERATION_NUMBER = 12345
+_TEST_OBJECT_SIZE = 1024 * 1024  # 1 MiB
 _TEST_READ_HANDLE = b"test-read-handle"
+_TEST_READ_HANDLE_NEW = b"test-read-handle-new"
 
 
 async def instantiate_read_obj_stream(mock_client, mock_cls_async_bidi_rpc, open=True):
@@ -37,6 +39,7 @@ async def instantiate_read_obj_stream(mock_client, mock_cls_async_bidi_rpc, open
     recv_response = mock.MagicMock(spec=_storage_v2.BidiReadObjectResponse)
     recv_response.metadata = mock.MagicMock(spec=_storage_v2.Object)
     recv_response.metadata.generation = _TEST_GENERATION_NUMBER
+    recv_response.metadata.size = _TEST_OBJECT_SIZE
     recv_response.read_handle = _TEST_READ_HANDLE
 
     socket_like_rpc.recv = AsyncMock(return_value=recv_response)
@@ -52,6 +55,30 @@ async def instantiate_read_obj_stream(mock_client, mock_cls_async_bidi_rpc, open
     return read_obj_stream
 
 
+async def instantiate_read_obj_stream_with_read_handle(
+    mock_client, mock_cls_async_bidi_rpc, open=True
+):
+    """Helper to create an instance of _AsyncReadObjectStream and open it by default."""
+    socket_like_rpc = AsyncMock()
+    mock_cls_async_bidi_rpc.return_value = socket_like_rpc
+    socket_like_rpc.open = AsyncMock()
+
+    recv_response = mock.MagicMock(spec=_storage_v2.BidiReadObjectResponse)
+    recv_response.read_handle = _TEST_READ_HANDLE_NEW
+    socket_like_rpc.recv = AsyncMock(return_value=recv_response)
+
+    read_obj_stream = _AsyncReadObjectStream(
+        client=mock_client,
+        bucket_name=_TEST_BUCKET_NAME,
+        object_name=_TEST_OBJECT_NAME,
+    )
+
+    if open:
+        await read_obj_stream.open()
+
+    return read_obj_stream
+
+
 @mock.patch(
     "google.cloud.storage._experimental.asyncio.async_read_object_stream.AsyncBidiRpc"
 )
@@ -65,12 +92,6 @@ def test_init_with_bucket_object_generation(mock_client, mock_async_bidi_rpc):
     mock_client._client._transport._wrapped_methods = {
         "bidi_read_object_rpc": rpc_sentinel,
     }
-    full_bucket_name = f"projects/_/buckets/{_TEST_BUCKET_NAME}"
-    first_bidi_read_req = _storage_v2.BidiReadObjectRequest(
-        read_object_spec=_storage_v2.BidiReadObjectSpec(
-            bucket=full_bucket_name, object=_TEST_OBJECT_NAME
-        ),
-    )
 
     # Act
     read_obj_stream = _AsyncReadObjectStream(
@@ -86,7 +107,6 @@ def test_init_with_bucket_object_generation(mock_client, mock_async_bidi_rpc):
     assert read_obj_stream.object_name == _TEST_OBJECT_NAME
     assert read_obj_stream.generation_number == _TEST_GENERATION_NUMBER
     assert read_obj_stream.read_handle == _TEST_READ_HANDLE
-    assert read_obj_stream.first_bidi_read_req == first_bidi_read_req
     assert read_obj_stream.rpc == rpc_sentinel
 
 
@@ -112,6 +132,33 @@ async def test_open(mock_client, mock_cls_async_bidi_rpc):
 
     assert read_obj_stream.generation_number == _TEST_GENERATION_NUMBER
     assert read_obj_stream.read_handle == _TEST_READ_HANDLE
+    assert read_obj_stream.persisted_size == _TEST_OBJECT_SIZE
+    assert read_obj_stream.is_stream_open
+
+
+@mock.patch(
+    "google.cloud.storage._experimental.asyncio.async_read_object_stream.AsyncBidiRpc"
+)
+@mock.patch(
+    "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client"
+)
+@pytest.mark.asyncio
+async def test_open_with_read_handle(mock_client, mock_cls_async_bidi_rpc):
+    # arrange
+    read_obj_stream = await instantiate_read_obj_stream_with_read_handle(
+        mock_client, mock_cls_async_bidi_rpc, open=False
+    )
+
+    # act
+    await read_obj_stream.open()
+
+    # assert
+    read_obj_stream.socket_like_rpc.open.assert_called_once()
+    read_obj_stream.socket_like_rpc.recv.assert_called_once()
+
+    assert read_obj_stream.generation_number is None
+    assert read_obj_stream.persisted_size is None
+    assert read_obj_stream.read_handle == _TEST_READ_HANDLE_NEW
     assert read_obj_stream.is_stream_open
diff --git a/tests/unit/asyncio/test_async_write_object_stream.py b/tests/unit/asyncio/test_async_write_object_stream.py
index 7fa2123c5..c6ea8a8ff 100644
--- a/tests/unit/asyncio/test_async_write_object_stream.py
+++ b/tests/unit/asyncio/test_async_write_object_stream.py
@@ -55,6 +55,7 @@ async def instantiate_write_obj_stream(mock_client, mock_cls_async_bidi_rpc, ope
     mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse)
     mock_response.resource = mock.MagicMock(spec=_storage_v2.Object)
     mock_response.resource.generation = GENERATION
+    mock_response.resource.size = 0
     mock_response.write_handle = WRITE_HANDLE
     socket_like_rpc.recv = AsyncMock(return_value=mock_response)
 
@@ -129,6 +130,7 @@ async def test_open_for_new_object(mock_async_bidi_rpc, mock_client):
     mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse)
     mock_response.resource = mock.MagicMock(spec=_storage_v2.Object)
     mock_response.resource.generation = GENERATION
+    mock_response.resource.size = 0
     mock_response.write_handle = WRITE_HANDLE
     socket_like_rpc.recv = mock.AsyncMock(return_value=mock_response)
 
@@ -143,6 +145,7 @@ async def test_open_for_new_object(mock_async_bidi_rpc, mock_client):
     socket_like_rpc.recv.assert_called_once()
     assert stream.generation_number == GENERATION
     assert stream.write_handle == WRITE_HANDLE
+    assert stream.persisted_size == 0
 
 
 @pytest.mark.asyncio
@@ -158,6 +161,7 @@ async def test_open_for_existing_object(mock_async_bidi_rpc, mock_client):
 
     mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse)
     mock_response.resource = mock.MagicMock(spec=_storage_v2.Object)
+    mock_response.resource.size = 1024
     mock_response.resource.generation = GENERATION
     mock_response.write_handle = WRITE_HANDLE
     socket_like_rpc.recv = mock.AsyncMock(return_value=mock_response)
@@ -175,6 +179,7 @@ async def test_open_for_existing_object(mock_async_bidi_rpc, mock_client):
     socket_like_rpc.recv.assert_called_once()
     assert stream.generation_number == GENERATION
     assert stream.write_handle == WRITE_HANDLE
+    assert stream.persisted_size == 1024
 
 
 @pytest.mark.asyncio
@@ -191,6 +196,7 @@ async def test_open_when_already_open_raises_error(mock_async_bidi_rpc, mock_cli
     mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse)
     mock_response.resource = mock.MagicMock(spec=_storage_v2.Object)
     mock_response.resource.generation = GENERATION
+    mock_response.resource.size = 0
     mock_response.write_handle = WRITE_HANDLE
     socket_like_rpc.recv = mock.AsyncMock(return_value=mock_response)
diff --git a/tests/unit/gapic/storage_v2/test_storage.py b/tests/unit/gapic/storage_v2/test_storage.py
index 20b680341..7b6340aa7 100644
--- a/tests/unit/gapic/storage_v2/test_storage.py
+++ b/tests/unit/gapic/storage_v2/test_storage.py
@@ -148,12 +148,19 @@ def test__read_environment_variables():
     with mock.patch.dict(
         os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}
     ):
-        with pytest.raises(ValueError) as excinfo:
-            StorageClient._read_environment_variables()
-        assert (
-            str(excinfo.value)
-            == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`"
-        )
+        if not hasattr(google.auth.transport.mtls, "should_use_client_cert"):
+            with pytest.raises(ValueError) as excinfo:
+                StorageClient._read_environment_variables()
+            assert (
+                str(excinfo.value)
+                == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`"
+            )
+        else:
+            assert StorageClient._read_environment_variables() == (
+                False,
+                "auto",
+                None,
+            )
 
     with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
         assert StorageClient._read_environment_variables() == (False, "never", None)
@@ -176,6 +183,105 @@ def test__read_environment_variables():
     assert StorageClient._read_environment_variables() == (False, "auto", "foo.com")
 
 
+def test_use_client_cert_effective():
+    # Test case 1: Test when `should_use_client_cert` returns True.
+    # We mock the `should_use_client_cert` function to simulate a scenario where
+    # the google-auth library supports automatic mTLS and determines that a
+    # client certificate should be used.
+    if hasattr(google.auth.transport.mtls, "should_use_client_cert"):
+        with mock.patch(
+            "google.auth.transport.mtls.should_use_client_cert", return_value=True
+        ):
+            assert StorageClient._use_client_cert_effective() is True
+
+    # Test case 2: Test when `should_use_client_cert` returns False.
+    # We mock the `should_use_client_cert` function to simulate a scenario where
+    # the google-auth library supports automatic mTLS and determines that a
+    # client certificate should NOT be used.
+    if hasattr(google.auth.transport.mtls, "should_use_client_cert"):
+        with mock.patch(
+            "google.auth.transport.mtls.should_use_client_cert", return_value=False
+        ):
+            assert StorageClient._use_client_cert_effective() is False
+
+    # Test case 3: Test when `should_use_client_cert` is unavailable and the
+    # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "true".
+    if not hasattr(google.auth.transport.mtls, "should_use_client_cert"):
+        with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}):
+            assert StorageClient._use_client_cert_effective() is True
+
+    # Test case 4: Test when `should_use_client_cert` is unavailable and the
+    # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "false".
+    if not hasattr(google.auth.transport.mtls, "should_use_client_cert"):
+        with mock.patch.dict(
+            os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}
+        ):
+            assert StorageClient._use_client_cert_effective() is False
+
+    # Test case 5: Test when `should_use_client_cert` is unavailable and the
+    # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "True".
+    if not hasattr(google.auth.transport.mtls, "should_use_client_cert"):
+        with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "True"}):
+            assert StorageClient._use_client_cert_effective() is True
+
+    # Test case 6: Test when `should_use_client_cert` is unavailable and the
+    # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "False".
+    if not hasattr(google.auth.transport.mtls, "should_use_client_cert"):
+        with mock.patch.dict(
+            os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "False"}
+        ):
+            assert StorageClient._use_client_cert_effective() is False
+
+    # Test case 7: Test when `should_use_client_cert` is unavailable and the
+    # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "TRUE".
+    if not hasattr(google.auth.transport.mtls, "should_use_client_cert"):
+        with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "TRUE"}):
+            assert StorageClient._use_client_cert_effective() is True
+
+    # Test case 8: Test when `should_use_client_cert` is unavailable and the
+    # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "FALSE".
+    if not hasattr(google.auth.transport.mtls, "should_use_client_cert"):
+        with mock.patch.dict(
+            os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "FALSE"}
+        ):
+            assert StorageClient._use_client_cert_effective() is False
+
+    # Test case 9: Test when `should_use_client_cert` is unavailable and the
+    # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not set.
+    # In this case, the method should return False, which is the default value.
+    if not hasattr(google.auth.transport.mtls, "should_use_client_cert"):
+        with mock.patch.dict(os.environ, clear=True):
+            assert StorageClient._use_client_cert_effective() is False
+
+    # Test case 10: Test when `should_use_client_cert` is unavailable and the
+    # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to an invalid value.
+    # The method should raise a ValueError as the environment variable must be either
+    # "true" or "false".
+    if not hasattr(google.auth.transport.mtls, "should_use_client_cert"):
+        with mock.patch.dict(
+            os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "unsupported"}
+        ):
+            with pytest.raises(ValueError):
+                StorageClient._use_client_cert_effective()
+
+    # Test case 11: Test when `should_use_client_cert` is available and the
+    # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to an invalid value.
+    # The method should return False as the environment variable is set to an invalid value.
+    if hasattr(google.auth.transport.mtls, "should_use_client_cert"):
+        with mock.patch.dict(
+            os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "unsupported"}
+        ):
+            assert StorageClient._use_client_cert_effective() is False
+
+    # Test case 12: Test when `should_use_client_cert` is available and the
+    # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to empty. Also,
+    # the GOOGLE_API_CERTIFICATE_CONFIG environment variable is set to empty.
+    if hasattr(google.auth.transport.mtls, "should_use_client_cert"):
+        with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": ""}):
+            with mock.patch.dict(os.environ, {"GOOGLE_API_CERTIFICATE_CONFIG": ""}):
+                assert StorageClient._use_client_cert_effective() is False
+
+
 def test__get_client_cert_source():
     mock_provided_cert_source = mock.Mock()
     mock_default_cert_source = mock.Mock()
@@ -515,17 +621,6 @@ def test_storage_client_client_options(client_class, transport_class, transport_
         == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`"
     )
 
-    # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value.
-    with mock.patch.dict(
-        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}
-    ):
-        with pytest.raises(ValueError) as excinfo:
-            client = client_class(transport=transport_name)
-        assert (
-            str(excinfo.value)
-            == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`"
-        )
-
     # Check the case quota_project_id is provided
     options = client_options.ClientOptions(quota_project_id="octopus")
     with mock.patch.object(transport_class, "__init__") as patched:
@@ -733,6 +828,119 @@ def test_storage_client_get_mtls_endpoint_and_cert_source(client_class):
     assert api_endpoint == mock_api_endpoint
     assert cert_source is None
 
+    # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "Unsupported".
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}
+    ):
+        if hasattr(google.auth.transport.mtls, "should_use_client_cert"):
+            mock_client_cert_source = mock.Mock()
+            mock_api_endpoint = "foo"
+            options = client_options.ClientOptions(
+                client_cert_source=mock_client_cert_source,
+                api_endpoint=mock_api_endpoint,
+            )
+            api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(
+                options
+            )
+            assert api_endpoint == mock_api_endpoint
+            assert cert_source is None
+
+    # Test cases for mTLS enablement when GOOGLE_API_USE_CLIENT_CERTIFICATE is unset.
+    test_cases = [
+        (
+            # With workloads present in config, mTLS is enabled.
+            {
+                "version": 1,
+                "cert_configs": {
+                    "workload": {
+                        "cert_path": "path/to/cert/file",
+                        "key_path": "path/to/key/file",
+                    }
+                },
+            },
+            mock_client_cert_source,
+        ),
+        (
+            # With workloads not present in config, mTLS is disabled.
+            {
+                "version": 1,
+                "cert_configs": {},
+            },
+            None,
+        ),
+    ]
+    if hasattr(google.auth.transport.mtls, "should_use_client_cert"):
+        for config_data, expected_cert_source in test_cases:
+            env = os.environ.copy()
+            env.pop("GOOGLE_API_USE_CLIENT_CERTIFICATE", None)
+            with mock.patch.dict(os.environ, env, clear=True):
+                config_filename = "mock_certificate_config.json"
+                config_file_content = json.dumps(config_data)
+                m = mock.mock_open(read_data=config_file_content)
+                with mock.patch("builtins.open", m):
+                    with mock.patch.dict(
+                        os.environ, {"GOOGLE_API_CERTIFICATE_CONFIG": config_filename}
+                    ):
+                        mock_api_endpoint = "foo"
+                        options = client_options.ClientOptions(
+                            client_cert_source=mock_client_cert_source,
+                            api_endpoint=mock_api_endpoint,
+                        )
+                        (
+                            api_endpoint,
+                            cert_source,
+                        ) = client_class.get_mtls_endpoint_and_cert_source(options)
+                        assert api_endpoint == mock_api_endpoint
+                        assert cert_source is expected_cert_source
+
+    # Test cases for mTLS enablement when GOOGLE_API_USE_CLIENT_CERTIFICATE is unset (empty).
+    test_cases = [
+        (
+            # With workloads present in config, mTLS is enabled.
+            {
+                "version": 1,
+                "cert_configs": {
+                    "workload": {
+                        "cert_path": "path/to/cert/file",
+                        "key_path": "path/to/key/file",
+                    }
+                },
+            },
+            mock_client_cert_source,
+        ),
+        (
+            # With workloads not present in config, mTLS is disabled.
+            {
+                "version": 1,
+                "cert_configs": {},
+            },
+            None,
+        ),
+    ]
+    if hasattr(google.auth.transport.mtls, "should_use_client_cert"):
+        for config_data, expected_cert_source in test_cases:
+            env = os.environ.copy()
+            env.pop("GOOGLE_API_USE_CLIENT_CERTIFICATE", "")
+            with mock.patch.dict(os.environ, env, clear=True):
+                config_filename = "mock_certificate_config.json"
+                config_file_content = json.dumps(config_data)
+                m = mock.mock_open(read_data=config_file_content)
+                with mock.patch("builtins.open", m):
+                    with mock.patch.dict(
+                        os.environ, {"GOOGLE_API_CERTIFICATE_CONFIG": config_filename}
+                    ):
+                        mock_api_endpoint = "foo"
+                        options = client_options.ClientOptions(
+                            client_cert_source=mock_client_cert_source,
+                            api_endpoint=mock_api_endpoint,
+                        )
+                        (
+                            api_endpoint,
+                            cert_source,
+                        ) = client_class.get_mtls_endpoint_and_cert_source(options)
+                        assert api_endpoint == mock_api_endpoint
+                        assert cert_source is expected_cert_source
+
     # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never".
     with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
         api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source()
@@ -783,18 +991,6 @@ def test_storage_client_get_mtls_endpoint_and_cert_source(client_class):
         == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`"
     )
 
-    # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value.
-    with mock.patch.dict(
-        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}
-    ):
-        with pytest.raises(ValueError) as excinfo:
-            client_class.get_mtls_endpoint_and_cert_source()
-
-        assert (
-            str(excinfo.value)
-            == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`"
-        )
-
 
 @pytest.mark.parametrize("client_class", [StorageClient, StorageAsyncClient])
 @mock.patch.object(
diff --git a/tests/unit/test_blob.py b/tests/unit/test_blob.py
index f3b6da5d1..a8abb1571 100644
--- a/tests/unit/test_blob.py
+++ b/tests/unit/test_blob.py
@@ -3049,7 +3049,14 @@ def test__initiate_resumable_upload_with_client_custom_headers(self):
         self._initiate_resumable_helper(client=client)
 
     def _make_resumable_transport(
-        self, headers1, headers2, headers3, total_bytes, data_corruption=False
+        self,
+        headers1,
+        headers2,
+        headers3,
+        total_bytes,
+        data_corruption=False,
+        md5_checksum_value=None,
+        crc32c_checksum_value=None,
     ):
         fake_transport = mock.Mock(spec=["request"])
 
         fake_response1 = self._mock_requests_response(http.client.OK, headers1)
         fake_response2 = self._mock_requests_response(
             http.client.PERMANENT_REDIRECT, headers2
         )
-        json_body = f'{{"size": "{total_bytes:d}"}}'
+        json_body = json.dumps(
+            {
+                "size": str(total_bytes),
+                "md5Hash": md5_checksum_value,
+                "crc32c": crc32c_checksum_value,
+            }
+        )
         if data_corruption:
             fake_response3 = DataCorruption(None)
         else:
@@ -3151,6 +3164,9 @@ def _do_resumable_upload_call2(
         if_metageneration_match=None,
         if_metageneration_not_match=None,
         timeout=None,
+        checksum=None,
+        crc32c_checksum_value=None,
+        md5_checksum_value=None,
     ):
         # Third mock transport.request() does sends last chunk.
         content_range = f"bytes {blob.chunk_size:d}-{total_bytes - 1:d}/{total_bytes:d}"
@@ -3161,6 +3177,11 @@ def _do_resumable_upload_call2(
             "content-type": content_type,
             "content-range": content_range,
         }
+        if checksum == "crc32c":
+            expected_headers["x-goog-hash"] = f"crc32c={crc32c_checksum_value}"
+        elif checksum == "md5":
+            expected_headers["x-goog-hash"] = f"md5={md5_checksum_value}"
+
         payload = data[blob.chunk_size :]
         return mock.call(
             "PUT",
@@ -3181,12 +3202,17 @@ def _do_resumable_helper(
         timeout=None,
         data_corruption=False,
         retry=None,
+        checksum=None,  # None is also a valid value, used when the user disables checksum validation.
     ):
         CHUNK_SIZE = 256 * 1024
         USER_AGENT = "testing 1.2.3"
         content_type = "text/html"
         # Data to be uploaded.
        data = b"<html>" + (b"A" * CHUNK_SIZE) + b"</html>"
+
+        # Data calculated offline and entered here (unit-test best practice).
+        crc32c_checksum_value = "mQ30hg=="
+        md5_checksum_value = "wajHeg1f2Q2u9afI6fjPOw=="
         total_bytes = len(data)
         if use_size:
             size = total_bytes
@@ -3213,6 +3239,8 @@ def _do_resumable_helper(
             headers3,
             total_bytes,
             data_corruption=data_corruption,
+            md5_checksum_value=md5_checksum_value,
+            crc32c_checksum_value=crc32c_checksum_value,
         )
 
         # Create some mock arguments and call the method under test.
@@ -3247,7 +3275,7 @@ def _do_resumable_helper(
             if_generation_not_match,
             if_metageneration_match,
             if_metageneration_not_match,
-            checksum=None,
+            checksum=checksum,
             retry=retry,
             **timeout_kwarg,
         )
@@ -3296,6 +3324,9 @@ def _do_resumable_helper(
             if_metageneration_match=if_metageneration_match,
             if_metageneration_not_match=if_metageneration_not_match,
             timeout=expected_timeout,
+            checksum=checksum,
+            crc32c_checksum_value=crc32c_checksum_value,
+            md5_checksum_value=md5_checksum_value,
         )
         self.assertEqual(transport.request.mock_calls, [call0, call1, call2])
 
@@ -3308,6 +3339,12 @@ def test__do_resumable_upload_no_size(self):
     def test__do_resumable_upload_with_size(self):
         self._do_resumable_helper(use_size=True)
 
+    def test__do_resumable_upload_with_size_with_crc32c_checksum(self):
+        self._do_resumable_helper(use_size=True, checksum="crc32c")
+
+    def test__do_resumable_upload_with_size_with_md5_checksum(self):
+        self._do_resumable_helper(use_size=True, checksum="md5")
+
     def test__do_resumable_upload_with_retry(self):
         self._do_resumable_helper(retry=DEFAULT_RETRY)