diff --git a/requirements-test.txt b/requirements-test.txt index 692d53953e..fed1a6de24 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -13,3 +13,4 @@ isort pylint mypy bandit +types-requests diff --git a/setup.cfg b/setup.cfg index 26c0d092eb..5baa2d2173 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,7 +15,13 @@ warn_unreachable = True strict_equality = True disallow_untyped_defs = True disallow_untyped_calls = True -files = tuf/api/, tuf/exceptions.py +files = + tuf/api/, + tuf/ngclient, + tuf/exceptions.py [mypy-securesystemslib.*] ignore_missing_imports = True + +[mypy-urllib3.*] +ignore_missing_imports = True diff --git a/tuf/api/metadata.py b/tuf/api/metadata.py index 8f5f48299d..ba16a68a17 100644 --- a/tuf/api/metadata.py +++ b/tuf/api/metadata.py @@ -24,8 +24,8 @@ from collections import OrderedDict from datetime import datetime, timedelta from typing import ( + IO, Any, - BinaryIO, ClassVar, Dict, Generic, @@ -98,7 +98,7 @@ def __init__(self, signed: T, signatures: "OrderedDict[str, Signature]"): self.signatures = signatures @classmethod - def from_dict(cls, metadata: Dict[str, Any]) -> "Metadata": + def from_dict(cls, metadata: Dict[str, Any]) -> "Metadata[T]": """Creates Metadata object from its dict representation. Arguments: @@ -753,7 +753,7 @@ class BaseFile: @staticmethod def _verify_hashes( - data: Union[bytes, BinaryIO], expected_hashes: Dict[str, str] + data: Union[bytes, IO[bytes]], expected_hashes: Dict[str, str] ) -> None: """Verifies that the hash of 'data' matches 'expected_hashes'""" is_bytes = isinstance(data, bytes) @@ -782,7 +782,7 @@ def _verify_hashes( @staticmethod def _verify_length( - data: Union[bytes, BinaryIO], expected_length: int + data: Union[bytes, IO[bytes]], expected_length: int ) -> None: """Verifies that the length of 'data' matches 'expected_length'""" if isinstance(data, bytes): @@ -867,7 +867,7 @@ def to_dict(self) -> Dict[str, Any]: return res_dict - def verify_length_and_hashes(self, data: Union[bytes, BinaryIO]) -> None: + def verify_length_and_hashes(self, data: Union[bytes, IO[bytes]]) -> None: """Verifies that the length and hashes of "data" match expected values. Args: @@ -1182,7 +1182,7 @@ def to_dict(self) -> Dict[str, Any]: **self.unrecognized_fields, } - def verify_length_and_hashes(self, data: Union[bytes, BinaryIO]) -> None: + def verify_length_and_hashes(self, data: Union[bytes, IO[bytes]]) -> None: """Verifies that length and hashes of "data" match expected values. Args: diff --git a/tuf/exceptions.py b/tuf/exceptions.py index 2a24a0429e..8ebc92c7d1 100755 --- a/tuf/exceptions.py +++ b/tuf/exceptions.py @@ -24,7 +24,7 @@ from urllib import parse -from typing import Any, Dict +from typing import Any, Dict, Optional import logging logger = logging.getLogger(__name__) @@ -206,16 +206,19 @@ def __repr__(self) -> str: class SlowRetrievalError(DownloadError): """"Indicate that downloading a file took an unreasonably long time.""" - def __init__(self, average_download_speed: int): + def __init__(self, average_download_speed: Optional[int] = None): super(SlowRetrievalError, self).__init__() self.__average_download_speed = average_download_speed #bytes/second def __str__(self) -> str: - return ( - 'Download was too slow. Average speed: ' + + msg = 'Download was too slow.' + if self.__average_download_speed is not None: + msg = ('Download was too slow. Average speed: ' + repr(self.__average_download_speed) + ' bytes per second.') + return msg + def __repr__(self) -> str: return self.__class__.__name__ + ' : ' + str(self) diff --git a/tuf/ngclient/_internal/requests_fetcher.py b/tuf/ngclient/_internal/requests_fetcher.py index a26231c5bb..ae68f1a3be 100644 --- a/tuf/ngclient/_internal/requests_fetcher.py +++ b/tuf/ngclient/_internal/requests_fetcher.py @@ -7,7 +7,7 @@ import logging import time -from typing import Iterator, Optional +from typing import Dict, Iterator, Optional from urllib import parse # Imports @@ -31,7 +31,7 @@ class RequestsFetcher(FetcherInterface): session per scheme+hostname combination. """ - def __init__(self): + def __init__(self) -> None: # http://docs.python-requests.org/en/master/user/advanced/#session-objects: # # "The Session object allows you to persist certain parameters across @@ -46,7 +46,7 @@ def __init__(self): # improve efficiency, but avoiding sharing state between different # hosts-scheme combinations to minimize subtle security issues. # Some cookies may not be HTTP-safe. - self._sessions = {} + self._sessions: Dict[str, requests.Session] = {} # Default settings self.socket_timeout: int = 4 # seconds @@ -141,12 +141,12 @@ def _chunks( ) except urllib3.exceptions.ReadTimeoutError as e: - raise exceptions.SlowRetrievalError(str(e)) + raise exceptions.SlowRetrievalError from e finally: response.close() - def _get_session(self, url): + def _get_session(self, url: str) -> requests.Session: """Returns a different customized requests.Session per schema+hostname combination. """ diff --git a/tuf/ngclient/_internal/trusted_metadata_set.py b/tuf/ngclient/_internal/trusted_metadata_set.py index 608889fb2a..be1b0b44ed 100644 --- a/tuf/ngclient/_internal/trusted_metadata_set.py +++ b/tuf/ngclient/_internal/trusted_metadata_set.py @@ -66,7 +66,7 @@ from typing import Dict, Iterator, Optional from tuf import exceptions -from tuf.api.metadata import Metadata +from tuf.api.metadata import Metadata, Root, Snapshot, Targets, Timestamp from tuf.api.serialization import DeserializationError logger = logging.getLogger(__name__) @@ -92,13 +92,13 @@ def __init__(self, root_data: bytes): RepositoryError: Metadata failed to load or verify. The actual error type and content will contain more details. """ - self._trusted_set = {} # type: Dict[str: Metadata] + self._trusted_set: Dict[str, Metadata] = {} self.reference_time = datetime.utcnow() # Load and validate the local root metadata. Valid initial trusted root # metadata is required logger.debug("Updating initial trusted root") - self.update_root(root_data) + self._load_trusted_root(root_data) def __getitem__(self, role: str) -> Metadata: """Returns current Metadata for 'role'""" @@ -114,27 +114,27 @@ def __iter__(self) -> Iterator[Metadata]: # Helper properties for top level metadata @property - def root(self) -> Optional[Metadata]: - """Current root Metadata or None""" - return self._trusted_set.get("root") + def root(self) -> Metadata[Root]: + """Current root Metadata""" + return self._trusted_set["root"] @property - def timestamp(self) -> Optional[Metadata]: + def timestamp(self) -> Optional[Metadata[Timestamp]]: """Current timestamp Metadata or None""" return self._trusted_set.get("timestamp") @property - def snapshot(self) -> Optional[Metadata]: + def snapshot(self) -> Optional[Metadata[Snapshot]]: """Current snapshot Metadata or None""" return self._trusted_set.get("snapshot") @property - def targets(self) -> Optional[Metadata]: + def targets(self) -> Optional[Metadata[Targets]]: """Current targets Metadata or None""" return self._trusted_set.get("targets") # Methods for updating metadata - def update_root(self, data: bytes): + def update_root(self, data: bytes) -> None: """Verifies and loads 'data' as new root metadata. Note that an expired intermediate root is considered valid: expiry is @@ -152,7 +152,7 @@ def update_root(self, data: bytes): logger.debug("Updating root") try: - new_root = Metadata.from_bytes(data) + new_root = Metadata[Root].from_bytes(data) except DeserializationError as e: raise exceptions.RepositoryError("Failed to load root") from e @@ -161,21 +161,21 @@ def update_root(self, data: bytes): f"Expected 'root', got '{new_root.signed.type}'" ) - if self.root is not None: - # We are not loading initial trusted root: verify the new one - self.root.verify_delegate("root", new_root) + # Verify that new root is signed by trusted root + self.root.verify_delegate("root", new_root) - if new_root.signed.version != self.root.signed.version + 1: - raise exceptions.ReplayedMetadataError( - "root", new_root.signed.version, self.root.signed.version - ) + if new_root.signed.version != self.root.signed.version + 1: + raise exceptions.ReplayedMetadataError( + "root", new_root.signed.version, self.root.signed.version + ) + # Verify that new root is signed by itself new_root.verify_delegate("root", new_root) self._trusted_set["root"] = new_root logger.debug("Updated root") - def update_timestamp(self, data: bytes): + def update_timestamp(self, data: bytes) -> None: """Verifies and loads 'data' as new timestamp metadata. Note that an expired intermediate timestamp is considered valid so it @@ -199,7 +199,7 @@ def update_timestamp(self, data: bytes): # timestamp/snapshot can not yet be loaded at this point try: - new_timestamp = Metadata.from_bytes(data) + new_timestamp = Metadata[Timestamp].from_bytes(data) except DeserializationError as e: raise exceptions.RepositoryError("Failed to load timestamp") from e @@ -237,7 +237,7 @@ def update_timestamp(self, data: bytes): self._trusted_set["timestamp"] = new_timestamp logger.debug("Updated timestamp") - def update_snapshot(self, data: bytes): + def update_snapshot(self, data: bytes) -> None: """Verifies and loads 'data' as new snapshot metadata. Note that intermediate snapshot is considered valid even if it is @@ -276,7 +276,7 @@ def update_snapshot(self, data: bytes): ) from e try: - new_snapshot = Metadata.from_bytes(data) + new_snapshot = Metadata[Snapshot].from_bytes(data) except DeserializationError as e: raise exceptions.RepositoryError("Failed to load snapshot") from e @@ -314,7 +314,11 @@ def update_snapshot(self, data: bytes): self._trusted_set["snapshot"] = new_snapshot logger.debug("Updated snapshot") - def _check_final_snapshot(self): + def _check_final_snapshot(self) -> None: + """Check snapshot expiry and version before targets is updated""" + + assert self.snapshot is not None # nosec + assert self.timestamp is not None # nosec if self.snapshot.signed.is_expired(self.reference_time): raise exceptions.ExpiredMetadataError("snapshot.json is expired") @@ -328,7 +332,7 @@ def _check_final_snapshot(self): f"got {self.snapshot.signed.version}" ) - def update_targets(self, data: bytes): + def update_targets(self, data: bytes) -> None: """Verifies and loads 'data' as new top-level targets metadata. Args: @@ -342,7 +346,7 @@ def update_targets(self, data: bytes): def update_delegated_targets( self, data: bytes, role_name: str, delegator_name: str - ): + ) -> None: """Verifies and loads 'data' as new metadata for target 'role_name'. Args: @@ -383,7 +387,7 @@ def update_delegated_targets( ) from e try: - new_delegate = Metadata.from_bytes(data) + new_delegate = Metadata[Targets].from_bytes(data) except DeserializationError as e: raise exceptions.RepositoryError("Failed to load snapshot") from e @@ -405,3 +409,24 @@ def update_delegated_targets( self._trusted_set[role_name] = new_delegate logger.debug("Updated %s delegated by %s", role_name, delegator_name) + + def _load_trusted_root(self, data: bytes) -> None: + """Verifies and loads 'data' as trusted root metadata. + + Note that an expired initial root is considered valid: expiry is + only checked for the final root in update_timestamp(). + """ + try: + new_root = Metadata[Root].from_bytes(data) + except DeserializationError as e: + raise exceptions.RepositoryError("Failed to load root") from e + + if new_root.signed.type != "root": + raise exceptions.RepositoryError( + f"Expected 'root', got '{new_root.signed.type}'" + ) + + new_root.verify_delegate("root", new_root) + + self._trusted_set["root"] = new_root + logger.debug("Loaded trusted root") diff --git a/tuf/ngclient/updater.py b/tuf/ngclient/updater.py index 1acf2c40bd..a3c7189d75 100644 --- a/tuf/ngclient/updater.py +++ b/tuf/ngclient/updater.py @@ -112,12 +112,7 @@ def __init__( # Read trusted local root metadata data = self._load_local_metadata("root") self._trusted_set = trusted_metadata_set.TrustedMetadataSet(data) - - if fetcher is None: - self._fetcher = requests_fetcher.RequestsFetcher() - else: - self._fetcher = fetcher - + self._fetcher = fetcher or requests_fetcher.RequestsFetcher() self.config = config or UpdaterConfig() def refresh(self) -> None: @@ -225,7 +220,7 @@ def download_target( targetinfo: TargetFile, destination_directory: str, target_base_url: Optional[str] = None, - ): + ) -> None: """Downloads the target file specified by 'targetinfo'. Args: @@ -241,12 +236,14 @@ def download_target( TODO: download-related errors TODO: file write errors """ - if target_base_url is None and self._target_base_url is None: - raise ValueError( - "target_base_url must be set in either download_target() or " - "constructor" - ) + if target_base_url is None: + if self._target_base_url is None: + raise ValueError( + "target_base_url must be set in either " + "download_target() or constructor" + ) + target_base_url = self._target_base_url else: target_base_url = _ensure_trailing_slash(target_base_url) @@ -289,7 +286,7 @@ def _load_local_metadata(self, rolename: str) -> bytes: with open(os.path.join(self._dir, f"{rolename}.json"), "rb") as f: return f.read() - def _persist_metadata(self, rolename: str, data: bytes): + def _persist_metadata(self, rolename: str, data: bytes) -> None: with open(os.path.join(self._dir, f"{rolename}.json"), "wb") as f: f.write(data) @@ -344,6 +341,7 @@ def _load_snapshot(self) -> None: # Local snapshot does not exist or is invalid: update from remote logger.debug("Failed to load local snapshot %s", e) + assert self._trusted_set.timestamp is not None # nosec metainfo = self._trusted_set.timestamp.signed.meta["snapshot.json"] length = metainfo.length or self.config.snapshot_max_length version = None @@ -364,6 +362,7 @@ def _load_targets(self, role: str, parent_role: str) -> None: # Local 'role' does not exist or is invalid: update from remote logger.debug("Failed to load local %s: %s", role, e) + assert self._trusted_set.snapshot is not None # nosec metainfo = self._trusted_set.snapshot.signed.meta[f"{role}.json"] length = metainfo.length or self.config.targets_max_length version = None @@ -450,6 +449,6 @@ def _preorder_depth_first_walk( return None -def _ensure_trailing_slash(url: str): +def _ensure_trailing_slash(url: str) -> str: """Return url guaranteed to end in a slash""" return url if url.endswith("/") else f"{url}/"