diff --git a/README.md b/README.md index 0cac2cf78..70bee8085 100644 --- a/README.md +++ b/README.md @@ -114,6 +114,9 @@ The sample chat app to run is found as [model-chat.py](https://github.com/micros - [Documentation](https://microsoft.github.io/Olive) - [Recipes](https://github.com/microsoft/olive-recipes) +## Data/Telemetry +Distributions of this project may collect usage data and send it to Microsoft to help improve our products and services. See the [privacy statement](docs/Privacy.md) for more details. + ## 🤝 Contributions and Feedback - We welcome contributions! Please read the [contribution guidelines](./CONTRIBUTING.md) for more details on how to contribute to the Olive project. - For feature requests or bug reports, file a [GitHub Issue](https://github.com/microsoft/Olive/issues). diff --git a/docs/Privacy.md b/docs/Privacy.md new file mode 100644 index 000000000..a74cf5b1c --- /dev/null +++ b/docs/Privacy.md @@ -0,0 +1,16 @@ +# Privacy + +## Data Collection +The software may collect information about you and your use of the software and send it to Microsoft. Microsoft may use this information to provide services and improve our products and services. You may turn off the telemetry as described in the repository. There are also some features in the software that may enable Microsoft to collect data from users of your applications. If you use these features, you must comply with applicable law, including providing appropriate notices to users of your applications together with a copy of Microsoft's privacy statement. Our privacy statement can be found [here](https://go.microsoft.com/fwlink/?LinkID=824704). You can learn more about data collection and use in the help documentation and our privacy statement. Your use of the software operates as your consent to these practices. + +*** + +## Technical Details +Olive uses the [OpenTelemetry](https://opentelemetry.io/) API for its implementation. Telemetry is turned ON by default. 
Based on user consent, this data may be periodically sent to Microsoft servers following GDPR and privacy regulations for anonymity and data access controls. Application, device, and version information is collected automatically. + +In addition, Olive may collect additional telemetry data such as: +- Invoked commands +- Performance data +- Exception information + +Collection of this additional telemetry can be disabled by adding the `--disable_telemetry` flag to any Olive CLI command, or by setting the `OLIVE_DISABLE_TELEMETRY` environment variable to `1` before running. If telemetry is enabled, but cannot be sent to Microsoft, it will be stored locally and sent when a connection is available. You can override the default cache location by setting the `OLIVE_TELEMETRY_CACHE_PATH` environment variable to a valid file path. diff --git a/olive/__init__.py b/olive/__init__.py index e3914adde..9a59f61ab 100644 --- a/olive/__init__.py +++ b/olive/__init__.py @@ -14,8 +14,6 @@ _logger.addHandler(_sc) _logger.propagate = False -__version__ = "0.11.0.dev0" - # pylint: disable=C0413 # Import Python API functions @@ -33,10 +31,12 @@ tune_session_params, ) from olive.engine.output import ModelOutput, WorkflowOutput # noqa: E402 +from olive.version import __version__ # noqa: E402 __all__ = [ "ModelOutput", "WorkflowOutput", + "__version__", # Python API functions "benchmark", "capture_onnx_graph", diff --git a/olive/cli/auto_opt.py b/olive/cli/auto_opt.py index a2e7948c1..2e0f73444 100644 --- a/olive/cli/auto_opt.py +++ b/olive/cli/auto_opt.py @@ -14,6 +14,7 @@ add_logging_options, add_save_config_file_options, add_shared_cache_options, + add_telemetry_options, get_input_model_config, update_accelerator_options, update_shared_cache_options, @@ -22,6 +23,7 @@ from olive.constants import Precision from olive.hardware.constants import ExecutionProvider from olive.package_config import OlivePackageConfig +from olive.telemetry import action class 
AutoOptCommand(BaseOliveCLICommand): @@ -167,8 +169,10 @@ def register_subcommand(parser: ArgumentParser): add_shared_cache_options(sub_parser) add_logging_options(sub_parser) add_save_config_file_options(sub_parser) + add_telemetry_options(sub_parser) sub_parser.set_defaults(func=AutoOptCommand) + @action def run(self): return self._run_workflow() diff --git a/olive/cli/base.py b/olive/cli/base.py index a063026fe..ffbea3a09 100644 --- a/olive/cli/base.py +++ b/olive/cli/base.py @@ -631,6 +631,12 @@ def add_search_options(sub_parser: ArgumentParser): search_strategy_group.add_argument("--seed", type=int, default=0, help="Random seed for search sampler") +def add_telemetry_options(sub_parser: ArgumentParser): + """Add telemetry options to the sub_parser.""" + sub_parser.add_argument("--disable_telemetry", action="store_true", help="Disable telemetry for this command.") + return sub_parser + + def update_search_options(args, config): to_replace = [] to_replace.extend( diff --git a/olive/cli/benchmark.py b/olive/cli/benchmark.py index 7c996e53c..753ca2224 100644 --- a/olive/cli/benchmark.py +++ b/olive/cli/benchmark.py @@ -11,10 +11,12 @@ add_logging_options, add_save_config_file_options, add_shared_cache_options, + add_telemetry_options, get_input_model_config, update_shared_cache_options, ) from olive.common.utils import set_nested_dict_value +from olive.telemetry import action class BenchmarkCommand(BaseOliveCLICommand): @@ -69,8 +71,10 @@ def register_subcommand(parser: ArgumentParser): add_logging_options(sub_parser) add_save_config_file_options(sub_parser) add_shared_cache_options(sub_parser) + add_telemetry_options(sub_parser) sub_parser.set_defaults(func=BenchmarkCommand) + @action def run(self): return self._run_workflow() diff --git a/olive/cli/capture_onnx.py b/olive/cli/capture_onnx.py index 198ac4225..87a2b92ea 100644 --- a/olive/cli/capture_onnx.py +++ b/olive/cli/capture_onnx.py @@ -13,12 +13,14 @@ add_logging_options, add_save_config_file_options, 
add_shared_cache_options, + add_telemetry_options, get_diffusers_input_model, get_input_model_config, update_shared_cache_options, ) from olive.common.utils import set_nested_dict_value from olive.model.utils.diffusers_utils import is_valid_diffusers_model +from olive.telemetry import action class ModelBuilderAccuracyLevel(IntEnum): @@ -170,8 +172,10 @@ def register_subcommand(parser: ArgumentParser): add_logging_options(sub_parser) add_save_config_file_options(sub_parser) add_shared_cache_options(sub_parser) + add_telemetry_options(sub_parser) sub_parser.set_defaults(func=CaptureOnnxGraphCommand) + @action def run(self): return self._run_workflow() diff --git a/olive/cli/configure_qualcomm_sdk.py b/olive/cli/configure_qualcomm_sdk.py index 883fef12e..7f42fd58d 100644 --- a/olive/cli/configure_qualcomm_sdk.py +++ b/olive/cli/configure_qualcomm_sdk.py @@ -4,7 +4,8 @@ # -------------------------------------------------------------------------- from argparse import ArgumentParser -from olive.cli.base import BaseOliveCLICommand +from olive.cli.base import BaseOliveCLICommand, add_telemetry_options +from olive.telemetry import action class ConfigureQualcommSDKCommand(BaseOliveCLICommand): @@ -21,9 +22,10 @@ def register_subcommand(parser: ArgumentParser): required=True, choices=["3.6", "3.8"], ) - + add_telemetry_options(sub_parser) sub_parser.set_defaults(func=ConfigureQualcommSDKCommand) + @action def run(self): from olive.platform_sdk.qualcomm.configure.configure import configure diff --git a/olive/cli/convert_adapters.py b/olive/cli/convert_adapters.py index 54ad5a065..558cde884 100644 --- a/olive/cli/convert_adapters.py +++ b/olive/cli/convert_adapters.py @@ -6,8 +6,9 @@ from argparse import ArgumentParser from typing import TYPE_CHECKING -from olive.cli.base import BaseOliveCLICommand, add_logging_options +from olive.cli.base import BaseOliveCLICommand, add_logging_options, add_telemetry_options from olive.common.utils import WeightsFileFormat, save_weights +from 
olive.telemetry import action if TYPE_CHECKING: from numpy.typing import NDArray @@ -75,8 +76,10 @@ def register_subcommand(parser: ArgumentParser): help="Quantization mode for int4 quantization of adapter weights. Default is symmetric.", ) add_logging_options(sub_parser) + add_telemetry_options(sub_parser) sub_parser.set_defaults(func=ConvertAdaptersCommand) + @action def run(self): import torch from peft import LoraConfig, load_peft_weights diff --git a/olive/cli/diffusion_lora.py b/olive/cli/diffusion_lora.py index d0061738a..e51d3ab17 100644 --- a/olive/cli/diffusion_lora.py +++ b/olive/cli/diffusion_lora.py @@ -11,11 +11,13 @@ add_logging_options, add_save_config_file_options, add_shared_cache_options, + add_telemetry_options, update_shared_cache_options, ) from olive.common.utils import set_nested_dict_value from olive.constants import DiffusersModelVariant from olive.passes.diffusers.lora import LRSchedulerType, MixedPrecision +from olive.telemetry import action class DiffusionLoraCommand(BaseOliveCLICommand): @@ -237,8 +239,10 @@ def register_subcommand(parser: ArgumentParser): add_shared_cache_options(sub_parser) add_logging_options(sub_parser) add_save_config_file_options(sub_parser) + add_telemetry_options(sub_parser) sub_parser.set_defaults(func=DiffusionLoraCommand) + @action def run(self): return self._run_workflow() diff --git a/olive/cli/extract_adapters.py b/olive/cli/extract_adapters.py index 47521a51f..2a9f7b0f0 100644 --- a/olive/cli/extract_adapters.py +++ b/olive/cli/extract_adapters.py @@ -6,8 +6,9 @@ from huggingface_hub.constants import HF_HUB_CACHE -from olive.cli.base import BaseOliveCLICommand, add_logging_options +from olive.cli.base import BaseOliveCLICommand, add_logging_options, add_telemetry_options from olive.common.utils import WeightsFileFormat, save_weights +from olive.telemetry import action class ExtractAdaptersCommand(BaseOliveCLICommand): @@ -54,8 +55,10 @@ def register_subcommand(parser: ArgumentParser): help="Cache dir to 
store temporary files in. Default is Hugging Face's default cache dir.", ) add_logging_options(sub_parser) + add_telemetry_options(sub_parser) sub_parser.set_defaults(func=ExtractAdaptersCommand) + @action def run(self): # Reference: https://huggingface.co/microsoft/Phi-4-multimodal-instruct-onnx/blob/05f620b467891affcb00b464e5a73e7cf2de61f9/onnx/builder.py#L318 import os diff --git a/olive/cli/finetune.py b/olive/cli/finetune.py index 0e79b7e02..56afb2e85 100644 --- a/olive/cli/finetune.py +++ b/olive/cli/finetune.py @@ -13,11 +13,13 @@ add_logging_options, add_save_config_file_options, add_shared_cache_options, + add_telemetry_options, get_input_model_config, update_dataset_options, update_shared_cache_options, ) from olive.common.utils import set_nested_dict_value +from olive.telemetry import action class FineTuneCommand(BaseOliveCLICommand): @@ -74,8 +76,10 @@ def register_subcommand(parser: ArgumentParser): add_shared_cache_options(sub_parser) add_logging_options(sub_parser) add_save_config_file_options(sub_parser) + add_telemetry_options(sub_parser) sub_parser.set_defaults(func=FineTuneCommand) + @action def run(self): return self._run_workflow() diff --git a/olive/cli/generate_adapter.py b/olive/cli/generate_adapter.py index 64b01d011..d1bd03659 100644 --- a/olive/cli/generate_adapter.py +++ b/olive/cli/generate_adapter.py @@ -11,11 +11,13 @@ add_logging_options, add_save_config_file_options, add_shared_cache_options, + add_telemetry_options, get_input_model_config, update_shared_cache_options, ) from olive.common.utils import WeightsFileFormat, set_nested_dict_value from olive.passes.onnx.common import AdapterType +from olive.telemetry import action class GenerateAdapterCommand(BaseOliveCLICommand): @@ -45,8 +47,10 @@ def register_subcommand(parser: ArgumentParser): add_logging_options(sub_parser) add_save_config_file_options(sub_parser) add_shared_cache_options(sub_parser) + add_telemetry_options(sub_parser) 
sub_parser.set_defaults(func=GenerateAdapterCommand) + @action def run(self): return self._run_workflow() diff --git a/olive/cli/generate_cost_model.py b/olive/cli/generate_cost_model.py index c8df9c7fb..2e765a145 100644 --- a/olive/cli/generate_cost_model.py +++ b/olive/cli/generate_cost_model.py @@ -5,8 +5,9 @@ import logging from pathlib import Path -from olive.cli.base import BaseOliveCLICommand, add_input_model_options, get_input_model_config +from olive.cli.base import BaseOliveCLICommand, add_input_model_options, add_telemetry_options, get_input_model_config from olive.model import ModelConfig +from olive.telemetry import action logger = logging.getLogger(__name__) @@ -34,8 +35,10 @@ def register_subcommand(parser): choices=PRECISON_TO_BYTES.keys(), help="Weight precision", ) + add_telemetry_options(sub_parser) sub_parser.set_defaults(func=GenerateCostModelCommand) + @action def run(self): import torch diff --git a/olive/cli/launcher.py b/olive/cli/launcher.py index b5452a378..d9088bc89 100644 --- a/olive/cli/launcher.py +++ b/olive/cli/launcher.py @@ -22,6 +22,7 @@ from olive.cli.run_pass import RunPassCommand from olive.cli.session_params_tuning import SessionParamsTuningCommand from olive.cli.shared_cache import SharedCacheCommand +from olive.telemetry import Telemetry def get_cli_parser(called_as_console_script: bool = True) -> ArgumentParser: @@ -61,6 +62,10 @@ def main(raw_args=None, called_as_console_script: bool = True): args, unknown_args = parser.parse_known_args(raw_args) + telemetry = Telemetry() + if args.disable_telemetry: + telemetry.disable_telemetry() + if not hasattr(args, "func"): parser.print_help() sys.exit(1) @@ -68,6 +73,7 @@ def main(raw_args=None, called_as_console_script: bool = True): # Run the command service = args.func(parser, args, unknown_args) service.run() + telemetry.shutdown() def legacy_call(deprecated_module: str, command_name: str, *args): diff --git a/olive/cli/optimize.py b/olive/cli/optimize.py index 
7e1f4b30e..a66919f94 100644 --- a/olive/cli/optimize.py +++ b/olive/cli/optimize.py @@ -15,11 +15,13 @@ add_input_model_options, add_logging_options, add_save_config_file_options, + add_telemetry_options, get_input_model_config, ) from olive.common.utils import set_nested_dict_value from olive.constants import Precision, precision_bits_from_precision from olive.hardware.constants import ExecutionProvider +from olive.telemetry import action class OptimizeCommand(BaseOliveCLICommand): @@ -184,6 +186,7 @@ def register_subcommand(parser: ArgumentParser): add_logging_options(sub_parser) add_save_config_file_options(sub_parser) + add_telemetry_options(sub_parser) sub_parser.set_defaults(func=OptimizeCommand) def __init__(self, parser: ArgumentParser, args: Namespace, unknown_args: Optional[list] = None): @@ -216,6 +219,7 @@ def __init__(self, parser: ArgumentParser, args: Namespace, unknown_args: Option self.enable_compose_onnx_models = False self.enable_openvino_encapsulation = False + @action def run(self): return self._run_workflow() diff --git a/olive/cli/quantize.py b/olive/cli/quantize.py index 36ebc8624..b24f0ab3c 100644 --- a/olive/cli/quantize.py +++ b/olive/cli/quantize.py @@ -17,6 +17,7 @@ add_logging_options, add_save_config_file_options, add_shared_cache_options, + add_telemetry_options, update_dataset_options, update_input_model_options, update_shared_cache_options, @@ -24,6 +25,7 @@ from olive.common.utils import StrEnumBase, set_nested_dict_value from olive.constants import Precision, QuantAlgorithm, precision_bits_from_precision from olive.package_config import OlivePackageConfig +from olive.telemetry import action class ImplName(StrEnumBase): @@ -94,6 +96,7 @@ def register_subcommand(parser: ArgumentParser): add_shared_cache_options(sub_parser) add_logging_options(sub_parser) add_save_config_file_options(sub_parser) + add_telemetry_options(sub_parser) sub_parser.set_defaults(func=QuantizeCommand) def _check_data_name_arg(self, pinfo): @@ -210,6 +213,7 
@@ def _get_run_config(self, tempdir: str) -> dict[str, Any]: self._customize_config(config) return config + @action def run(self): return self._run_workflow() diff --git a/olive/cli/run.py b/olive/cli/run.py index 0db1975e7..6d2a831ae 100644 --- a/olive/cli/run.py +++ b/olive/cli/run.py @@ -4,7 +4,14 @@ # -------------------------------------------------------------------------- from argparse import ArgumentParser -from olive.cli.base import BaseOliveCLICommand, add_input_model_options, add_logging_options, get_input_model_config +from olive.cli.base import ( + BaseOliveCLICommand, + add_input_model_options, + add_logging_options, + add_telemetry_options, + get_input_model_config, +) +from olive.telemetry import action class WorkflowRunCommand(BaseOliveCLICommand): @@ -37,8 +44,10 @@ def register_subcommand(parser: ArgumentParser): enable_onnx=True, required=False, ) + add_telemetry_options(sub_parser) sub_parser.set_defaults(func=WorkflowRunCommand) + @action def run(self): from olive.common.config_utils import load_config_file from olive.workflows import run as olive_run diff --git a/olive/cli/run_pass.py b/olive/cli/run_pass.py index 2f8afeddd..3ed269185 100644 --- a/olive/cli/run_pass.py +++ b/olive/cli/run_pass.py @@ -12,11 +12,14 @@ add_input_model_options, add_logging_options, add_save_config_file_options, + add_telemetry_options, get_input_model_config, update_accelerator_options, ) +from olive.telemetry import action +@action class RunPassCommand(BaseOliveCLICommand): @staticmethod def register_subcommand(parser: ArgumentParser): @@ -62,6 +65,7 @@ def register_subcommand(parser: ArgumentParser): add_logging_options(sub_parser) add_save_config_file_options(sub_parser) + add_telemetry_options(sub_parser) sub_parser.set_defaults(func=RunPassCommand) def _get_run_config(self, tempdir: str) -> dict[str, Any]: diff --git a/olive/cli/session_params_tuning.py b/olive/cli/session_params_tuning.py index a017d7848..69976122a 100644 --- 
a/olive/cli/session_params_tuning.py +++ b/olive/cli/session_params_tuning.py @@ -14,11 +14,13 @@ add_logging_options, add_save_config_file_options, add_shared_cache_options, + add_telemetry_options, get_input_model_config, update_accelerator_options, update_shared_cache_options, ) from olive.common.utils import set_nested_dict_value +from olive.telemetry import action class SessionParamsTuningCommand(BaseOliveCLICommand): @@ -96,6 +98,7 @@ def register_subcommand(parser: ArgumentParser): add_logging_options(sub_parser) add_save_config_file_options(sub_parser) add_shared_cache_options(sub_parser) + add_telemetry_options(sub_parser) sub_parser.set_defaults(func=SessionParamsTuningCommand) def _update_pass_config(self, default_pass_config) -> dict: @@ -135,6 +138,7 @@ def _get_run_config(self, tempdir) -> dict: update_shared_cache_options(config, self.args) return config + @action def run(self): workflow_output = self._run_workflow() diff --git a/olive/cli/shared_cache.py b/olive/cli/shared_cache.py index 40c72e460..78c89b5a5 100644 --- a/olive/cli/shared_cache.py +++ b/olive/cli/shared_cache.py @@ -4,8 +4,9 @@ # -------------------------------------------------------------------------- import logging -from olive.cli.base import BaseOliveCLICommand +from olive.cli.base import BaseOliveCLICommand, add_telemetry_options from olive.common.container_client_factory import AzureContainerClientFactory +from olive.telemetry import action logger = logging.getLogger(__name__) @@ -47,8 +48,10 @@ def register_subcommand(parser): type=str, help="The model hash to remove from the shared cache.", ) + add_telemetry_options(sub_parser) sub_parser.set_defaults(func=SharedCacheCommand) + @action def run(self): container_client_factory = AzureContainerClientFactory(self.args.account, self.args.container) if self.args.delete: diff --git a/olive/engine/engine.py b/olive/engine/engine.py index 515fb4ac1..de6b7019a 100644 --- a/olive/engine/engine.py +++ b/olive/engine/engine.py @@ -29,6 
+29,7 @@ from olive.search.search_strategy import SearchStrategy, SearchStrategyConfig from olive.systems.common import SystemType from olive.systems.system_config import SystemConfig +from olive.telemetry import action if TYPE_CHECKING: from olive.engine.packaging.packaging_config import PackagingConfig @@ -148,6 +149,7 @@ def register( def set_input_passes_configs(self, pass_configs: dict[str, list[RunPassConfig]]): self.input_passes_configs = pass_configs + @action def run( self, input_model_config: ModelConfig, diff --git a/olive/telemetry/__init__.py b/olive/telemetry/__init__.py new file mode 100644 index 000000000..0ecbbc705 --- /dev/null +++ b/olive/telemetry/__init__.py @@ -0,0 +1,8 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from olive.telemetry.telemetry import Telemetry +from olive.telemetry.telemetry_extensions import action + +__all__ = ["Telemetry", "action"] diff --git a/olive/telemetry/constants.py b/olive/telemetry/constants.py new file mode 100644 index 000000000..ca9e150b1 --- /dev/null +++ b/olive/telemetry/constants.py @@ -0,0 +1,8 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +"""OneCollector connection string.""" + +CONNECTION_STRING = "SW5zdHJ1bWVudGF0aW9uS2V5PTlkNWRkYWVjNjFlMjQ1NjdiNzg4YTIwYWVhMzI0NjMxLTcyMzdkN2M2LWVlNjEtNGNmZC1iYjdiLTU5MDNhOTcyYzJlNC03MDQ3" diff --git a/olive/telemetry/deviceid/__init__.py b/olive/telemetry/deviceid/__init__.py new file mode 100644 index 000000000..50698c12d --- /dev/null +++ b/olive/telemetry/deviceid/__init__.py @@ -0,0 +1,3 @@ +from olive.telemetry.deviceid.deviceid import get_encrypted_device_id_and_status + +__all__ = ["get_encrypted_device_id_and_status"] diff --git a/olive/telemetry/deviceid/_store.py b/olive/telemetry/deviceid/_store.py new file mode 100644 index 000000000..7a22736b3 --- /dev/null +++ b/olive/telemetry/deviceid/_store.py @@ -0,0 +1,75 @@ +from pathlib import Path + +from olive.telemetry.utils import get_telemetry_base_dir + +REGISTRY_PATH = r"SOFTWARE\Microsoft\DeveloperTools\.onnxruntime" +REGISTRY_KEY = "deviceid" + + +class Store: + def __init__(self) -> None: + self._file_path: Path = self._build_path + + @property + def _build_path(self) -> Path: + return get_telemetry_base_dir() / "deviceid" + + @property + def retrieve_id(self) -> str: + """Retrieve the device id from the store location. + + :return: The device id. + :rtype: str + """ + # check if file doesnt exist and raise an Exception + if not self._file_path.is_file(): + raise FileExistsError(f"File {self._file_path.stem} does not exist") + + return self._file_path.read_text(encoding="utf-8").strip() + + def store_id(self, device_id: str) -> None: + """Store the device id in the store location. + + :param str device_id: The device id to store. + :type device_id: str + """ + # create the folder location if it does not exist + try: + self._file_path.parent.mkdir(parents=True) + except FileExistsError: + # This is unexpected, but is not an issue, + # since we want this file path to exist. 
+ pass + + self._file_path.touch() + self._file_path.write_text(device_id, encoding="utf-8") + + +class WindowsStore: + @property + def retrieve_id(self) -> str: + """Retrieve the device id from the Windows registry.""" + import winreg + + device_id: str + + with winreg.OpenKeyEx( + winreg.HKEY_CURRENT_USER, REGISTRY_PATH, reserved=0, access=winreg.KEY_READ | winreg.KEY_WOW64_64KEY + ) as key_handle: + device_id = winreg.QueryValueEx(key_handle, REGISTRY_KEY) + return device_id[0].strip() + + def store_id(self, device_id: str) -> None: + """Store the device id in the windows registry. + + :param str device_id: The device id to store. + """ + import winreg + + with winreg.CreateKeyEx( + winreg.HKEY_CURRENT_USER, + REGISTRY_PATH, + reserved=0, + access=winreg.KEY_ALL_ACCESS | winreg.KEY_WOW64_64KEY, + ) as key_handle: + winreg.SetValueEx(key_handle, REGISTRY_KEY, 0, winreg.REG_SZ, device_id) diff --git a/olive/telemetry/deviceid/deviceid.py b/olive/telemetry/deviceid/deviceid.py new file mode 100644 index 000000000..09087f33c --- /dev/null +++ b/olive/telemetry/deviceid/deviceid.py @@ -0,0 +1,101 @@ +import hashlib +import platform +import uuid +from enum import Enum +from typing import Union + +from olive.telemetry.deviceid._store import Store, WindowsStore + + +class DeviceIdStatus(Enum): + NEW = "new" + EXISTING = "existing" + CORRUPTED = "corrupted" + FAILED = "failed" + + +_device_id_state = {"device_id": None, "status": DeviceIdStatus.NEW} + + +def get_device_id() -> str: + r"""Get the device id from the store or create one if it does not exist. + + An empty string is returned if an error occurs during saving or retrieval of the device id. 
+ + Linux id location: $XDG_CACHE_HOME/Microsoft/DeveloperTools/.onnxruntime/deviceid if defined + else $HOME/.cache/Microsoft/DeveloperTools/.onnxruntime/deviceid + MacOS id location: $HOME/Library/Application Support/Microsoft/DeveloperTools/.onnxruntime/deviceid + Windows id location: HKEY_CURRENT_USER\SOFTWARE\Microsoft\DeveloperTools\.onnxruntime\deviceid + + :return: The device id. + :rtype: str + """ + device_id: str = "" + store: Union[Store, WindowsStore] + create_new_id = False + + try: + if platform.system() == "Windows": + store = WindowsStore() + elif platform.system() in ("Linux", "Darwin"): + store = Store() + else: + _device_id_state["status"] = DeviceIdStatus.FAILED + _device_id_state["device_id"] = device_id + return device_id + + device_id = store.retrieve_id + if len(device_id) > 256: + _device_id_state["status"] = DeviceIdStatus.CORRUPTED + _device_id_state["device_id"] = "" + create_new_id = True + else: + try: + uuid.UUID(device_id) + except ValueError: + _device_id_state["status"] = DeviceIdStatus.CORRUPTED + _device_id_state["device_id"] = "" + create_new_id = True + else: + _device_id_state["status"] = DeviceIdStatus.EXISTING + _device_id_state["device_id"] = device_id + return device_id + except (FileExistsError, FileNotFoundError): + _device_id_state["status"] = DeviceIdStatus.NEW + _device_id_state["device_id"] = "" + create_new_id = True + except (PermissionError, ValueError, NotImplementedError): + _device_id_state["status"] = DeviceIdStatus.FAILED + _device_id_state["device_id"] = device_id + return device_id + except Exception: + _device_id_state["status"] = DeviceIdStatus.FAILED + _device_id_state["device_id"] = device_id + return device_id + + if create_new_id: + device_id = str(uuid.uuid4()).lower() + + try: + store.store_id(device_id) + except Exception: + _device_id_state["status"] = DeviceIdStatus.FAILED + device_id = "" + _device_id_state["device_id"] = device_id + + return device_id + + +def get_encrypted_device_id_and_status() -> 
tuple[str, DeviceIdStatus]: + """Generate a FIPS-compliant encrypted device ID using SHA256 and return the DeviceIdStatus. + + This method uses SHA256 which is FIPS 140-2 approved for cryptographic operations. + The device ID is hashed to ensure deterministic but secure device identification. + + Returns: + tuple[str, DeviceIdStatus]: FIPS-compliant encrypted device ID (hex-encoded SHA256, empty if unavailable) and the device id status + + """ + device_id = _device_id_state["device_id"] if _device_id_state["device_id"] is not None else get_device_id() + encrypted_device_id = hashlib.sha256(device_id.encode("utf-8")).digest().hex().upper() if device_id else "" + return encrypted_device_id, _device_id_state["status"] diff --git a/olive/telemetry/library/__init__.py b/olive/telemetry/library/__init__.py new file mode 100644 index 000000000..39831da66 --- /dev/null +++ b/olive/telemetry/library/__init__.py @@ -0,0 +1,84 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +"""OneCollector Exporter for OpenTelemetry Python. + +This package provides an OpenTelemetry exporter that sends telemetry data +to Microsoft OneCollector using the Common Schema JSON format.
+ +Example usage: + + from onecollector_exporter import ( + OneCollectorLogExporter, + OneCollectorExporterOptions, + get_telemetry_logger, + ) + + # Option 1: Use with OpenTelemetry SDK directly + options = OneCollectorExporterOptions( + connection_string="InstrumentationKey=your-key-here" + ) + exporter = OneCollectorLogExporter(options=options) + + # Add to logger provider + from opentelemetry.sdk._logs import LoggerProvider + from opentelemetry.sdk._logs.export import BatchLogRecordProcessor + + provider = LoggerProvider() + provider.add_log_record_processor(BatchLogRecordProcessor(exporter)) + + # Option 2: Use the simplified telemetry logger + logger = get_telemetry_logger( + connection_string="InstrumentationKey=your-key-here" + ) + logger.log("MyEvent", {"key": "value"}) + logger.shutdown() +""" + +from olive.telemetry.library.callback_manager import CallbackManager, PayloadTransmittedCallbackArgs +from olive.telemetry.library.connection_string_parser import ConnectionStringParser +from olive.telemetry.library.event_source import OneCollectorEventId, OneCollectorEventSource, event_source +from olive.telemetry.library.exporter import OneCollectorLogExporter +from olive.telemetry.library.options import ( + CompressionType, + OneCollectorExporterOptions, + OneCollectorExporterValidationError, + OneCollectorTransportOptions, +) +from olive.telemetry.library.payload_builder import PayloadBuilder +from olive.telemetry.library.retry import RetryHandler +from olive.telemetry.library.serialization import CommonSchemaJsonSerializationHelper +from olive.telemetry.library.telemetry_logger import ( + TelemetryLogger, + get_telemetry_logger, + log_event, + shutdown_telemetry, +) +from olive.telemetry.library.transport import HttpJsonPostTransport, ITransport + +__version__ = "0.0.1" + +__all__ = [ + "CallbackManager", + "CommonSchemaJsonSerializationHelper", + "CompressionType", + "ConnectionStringParser", + "HttpJsonPostTransport", + "ITransport", + 
"OneCollectorEventId", + "OneCollectorEventSource", + "OneCollectorExporterOptions", + "OneCollectorExporterValidationError", + "OneCollectorLogExporter", + "OneCollectorTransportOptions", + "PayloadBuilder", + "PayloadTransmittedCallbackArgs", + "RetryHandler", + "TelemetryLogger", + "event_source", + "get_telemetry_logger", + "log_event", + "shutdown_telemetry", +] diff --git a/olive/telemetry/library/callback_manager.py b/olive/telemetry/library/callback_manager.py new file mode 100644 index 000000000..ee6255316 --- /dev/null +++ b/olive/telemetry/library/callback_manager.py @@ -0,0 +1,110 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +"""Callback manager for payload transmission events.""" + +import threading +from dataclasses import dataclass +from typing import Callable, Optional + +from olive.telemetry.library.event_source import event_source + + +@dataclass +class PayloadTransmittedCallbackArgs: + """Arguments passed to payload transmitted callbacks.""" + + succeeded: bool + """Whether the transmission succeeded.""" + + status_code: Optional[int] + """HTTP status code, if available.""" + + payload_size_bytes: int + """Size of the transmitted payload in bytes.""" + + item_count: int + """Number of items in the payload.""" + + payload_bytes: Optional[bytes] = None + """Raw payload bytes (uncompressed), if available.""" + + +class CallbackManager: + """Manages callbacks for payload transmission events. + + Allows registration of callbacks that are invoked when payloads + are successfully transmitted or fail. 
+ """ + + def __init__(self): + """Initialize the callback manager.""" + self._callbacks: list[tuple[Callable[[PayloadTransmittedCallbackArgs], None], bool]] = [] + self._lock = threading.Lock() + self._closed = False + + def register( + self, callback: Callable[[PayloadTransmittedCallbackArgs], None], include_failures: bool = False + ) -> Callable[[], None]: + """Register a callback to be invoked on payload transmission. + + Args: + callback: Function to call when payload is transmitted + include_failures: Whether to invoke callback on transmission failures + + Returns: + Function to call to unregister the callback + + """ + with self._lock: + if self._closed: + return lambda: None # No-op unregister if disposed + entry = (callback, include_failures) + self._callbacks.append(entry) + + def unregister(): + """Unregister this callback.""" + with self._lock: + try: + self._callbacks.remove(entry) + except ValueError: + # The callback was already removed. + pass + + return unregister + + def notify(self, args: PayloadTransmittedCallbackArgs) -> None: + """Notify all registered callbacks. + + Args: + args: Callback arguments + + """ + # Get snapshot of callbacks to avoid holding lock during invocation + with self._lock: + if self._closed: + return + callbacks_snapshot = self._callbacks.copy() + + # Invoke callbacks + for callback, include_failures in callbacks_snapshot: + # Check if we should invoke this callback + if not args.succeeded and not include_failures: + continue + + try: + callback(args) + except Exception as ex: + # Log but don't propagate exceptions from user code + event_source.exception_thrown_from_user_code("PayloadTransmittedCallback", ex) + + def close(self) -> None: + """Close the callback manager and prevent further registrations. + + This method is idempotent and can be called multiple times. 
+ """ + with self._lock: + self._callbacks.clear() + self._closed = True diff --git a/olive/telemetry/library/connection_string_parser.py b/olive/telemetry/library/connection_string_parser.py new file mode 100644 index 000000000..dc6e05ed6 --- /dev/null +++ b/olive/telemetry/library/connection_string_parser.py @@ -0,0 +1,44 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +"""Connection string parser for OneCollector exporter.""" + + +class ConnectionStringParser: + """Parses OneCollector connection strings to extract configuration.""" + + def __init__(self, connection_string: str): + """Initialize the parser with a connection string. + + Args: + connection_string: Connection string in the format "Key1=Value1;Key2=Value2" + + Raises: + ValueError: If the connection string is invalid or missing required fields + + """ + if not connection_string: + raise ValueError("Connection string cannot be empty") + + self.instrumentation_key: str | None = None + self._parse(connection_string) + + if not self.instrumentation_key: + raise ValueError("InstrumentationKey not found in connection string") + + def _parse(self, connection_string: str) -> None: + """Parse the connection string into key-value pairs.""" + parts = connection_string.split(";") + for raw_part in parts: + part = raw_part.strip() + if not part or "=" not in part: + continue + + key, value = part.split("=", 1) + key = key.strip().lower() + value = value.strip() + + if key == "instrumentationkey": + self.instrumentation_key = value diff --git a/olive/telemetry/library/event_source.py b/olive/telemetry/library/event_source.py new file mode 100644 index 000000000..e65d9d546 --- /dev/null +++ b/olive/telemetry/library/event_source.py @@ -0,0 +1,257 @@ +"""EventSource-style logging for OneCollector exporter. 
+ +Provides structured logging similar to .NET EventSource for diagnostics and monitoring. +""" + +import logging +from enum import IntEnum + + +class OneCollectorEventId(IntEnum): + """Event IDs matching .NET EventSource implementation.""" + + EXPORT_EXCEPTION = 1 + TRANSPORT_DATA_SENT = 2 + SINK_DATA_WRITTEN = 3 + DATA_DROPPED = 4 + TRANSPORT_EXCEPTION = 5 + HTTP_ERROR_RESPONSE = 6 + EVENT_FULL_NAME_DISCARDED = 7 + EVENT_NAMESPACE_INVALID = 8 + EVENT_NAME_INVALID = 9 + USER_CODE_EXCEPTION = 10 + ATTRIBUTE_DROPPED = 11 + + +class OneCollectorEventSource: + """EventSource for OneCollector exporter diagnostics. + + Provides structured logging matching the .NET EventSource implementation. + """ + + def __init__(self): + self.logger = logging.getLogger("OpenTelemetry.Exporter.OneCollector") + # Set default level to INFO to match .NET behavior + if not self.logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + handler.setFormatter(formatter) + self.logger.addHandler(handler) + self.logger.setLevel(logging.INFO) + + @property + def is_informational_logging_enabled(self) -> bool: + """Check if informational level logging is enabled.""" + return self.logger.isEnabledFor(logging.INFO) + + @property + def is_warning_logging_enabled(self) -> bool: + """Check if warning level logging is enabled.""" + return self.logger.isEnabledFor(logging.WARNING) + + @property + def is_error_logging_enabled(self) -> bool: + """Check if error level logging is enabled.""" + return self.logger.isEnabledFor(logging.ERROR) + + def export_exception_thrown(self, item_type: str, exception: Exception) -> None: + """Log an exception thrown during export. 
+ + Args: + item_type: Type of item being exported (e.g., 'LogData') + exception: The exception that was thrown + + """ + if self.is_error_logging_enabled: + self.logger.error( + "Exception thrown exporting '%s' batch: %s", + item_type, + exception, + exc_info=exception, + extra={"event_id": OneCollectorEventId.EXPORT_EXCEPTION}, + ) + + def transport_data_sent(self, item_type: str, num_records: int, transport_description: str) -> None: + """Log successful data transmission. + + Args: + item_type: Type of items sent + num_records: Number of records sent + transport_description: Description of transport used + + """ + if self.is_informational_logging_enabled: + self.logger.info( + "Sent '%s' batch of %s item(s) to '%s' transport", + item_type, + num_records, + transport_description, + extra={"event_id": OneCollectorEventId.TRANSPORT_DATA_SENT}, + ) + + def sink_data_written(self, item_type: str, num_records: int, sink_description: str) -> None: + """Log data written to sink. + + Args: + item_type: Type of items written + num_records: Number of records written + sink_description: Description of sink used + + """ + if self.is_informational_logging_enabled: + self.logger.info( + "Wrote '%s' batch of %s item(s) to '%s' sink", + item_type, + num_records, + sink_description, + extra={"event_id": OneCollectorEventId.SINK_DATA_WRITTEN}, + ) + + def data_dropped( + self, item_type: str, num_records: int, during_serialization: int, during_transmission: int + ) -> None: + """Log dropped data. + + Args: + item_type: Type of items dropped + num_records: Total number of records dropped + during_serialization: Number dropped during serialization + during_transmission: Number dropped during transmission + + """ + if self.is_warning_logging_enabled: + self.logger.warning( + "Dropped %s '%s' item(s). %s item(s) dropped during serialization. 
%s item(s) dropped due to " + "transmission failure", + num_records, + item_type, + during_serialization, + during_transmission, + extra={"event_id": OneCollectorEventId.DATA_DROPPED}, + ) + + def transport_exception_thrown(self, transport_type: str, exception: Exception) -> None: + """Log transport exception. + + Args: + transport_type: Type of transport + exception: The exception that was thrown + + """ + if self.is_error_logging_enabled: + self.logger.error( + "Exception thrown by '%s' transport: %s", + transport_type, + exception, + exc_info=exception, + extra={"event_id": OneCollectorEventId.TRANSPORT_EXCEPTION}, + ) + + def http_transport_error_response( + self, transport_type: str, status_code: int, error_message: str, error_details: str + ) -> None: + """Log HTTP error response. + + Args: + transport_type: Type of transport + status_code: HTTP status code + error_message: Error message from response + error_details: Additional error details + + """ + if self.is_error_logging_enabled: + self.logger.error( + "Error response received by '%s' transport. StatusCode: %s, ErrorMessage: '%s', ErrorDetails: '%s'", + transport_type, + status_code, + error_message, + error_details, + extra={"event_id": OneCollectorEventId.HTTP_ERROR_RESPONSE}, + ) + + def event_full_name_discarded(self, event_namespace: str, event_name: str) -> None: + """Log event full name discarded. + + Args: + event_namespace: Event namespace + event_name: Event name + + """ + if self.is_warning_logging_enabled: + self.logger.warning( + "Event full name discarded. EventNamespace: '%s', EventName: '%s'", + event_namespace, + event_name, + extra={"event_id": OneCollectorEventId.EVENT_FULL_NAME_DISCARDED}, + ) + + def event_namespace_invalid(self, event_namespace: str) -> None: + """Log invalid event namespace. + + Args: + event_namespace: The invalid namespace + + """ + if self.is_warning_logging_enabled: + self.logger.warning( + "Event namespace invalid. 
EventNamespace: '%s'", + event_namespace, + extra={"event_id": OneCollectorEventId.EVENT_NAMESPACE_INVALID}, + ) + + def event_name_invalid(self, event_name: str) -> None: + """Log invalid event name. + + Args: + event_name: The invalid event name + + """ + if self.is_warning_logging_enabled: + self.logger.warning( + "Event name invalid. EventName: '%s'", + event_name, + extra={"event_id": OneCollectorEventId.EVENT_NAME_INVALID}, + ) + + def exception_thrown_from_user_code(self, user_code_type: str, exception: Exception) -> None: + """Log exception from user code (e.g., callbacks). + + Args: + user_code_type: Type of user code that threw exception + exception: The exception that was thrown + + """ + if self.is_error_logging_enabled: + self.logger.error( + "Exception thrown by '%s' user code: %s", + user_code_type, + exception, + exc_info=exception, + extra={"event_id": OneCollectorEventId.USER_CODE_EXCEPTION}, + ) + + def attribute_dropped(self, item_type: str, attribute_name: str, reason: str) -> None: + """Log dropped attribute. + + Args: + item_type: Type of item + attribute_name: Name of dropped attribute + reason: Reason for dropping + + """ + if self.is_warning_logging_enabled: + self.logger.warning( + "Dropped %s attribute '%s': %s", + item_type, + attribute_name, + reason, + extra={"event_id": OneCollectorEventId.ATTRIBUTE_DROPPED}, + ) + + def disable(self) -> None: + """Disable telemetry logging.""" + self.logger.disabled = True + + +# Global event source instance +event_source = OneCollectorEventSource() diff --git a/olive/telemetry/library/exporter.py b/olive/telemetry/library/exporter.py new file mode 100644 index 000000000..68c57dccf --- /dev/null +++ b/olive/telemetry/library/exporter.py @@ -0,0 +1,326 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
# --------------------------------------------------------------------------

"""Main OneCollector log exporter implementation."""

import threading
from collections.abc import Sequence
from datetime import datetime, timezone
from time import time
from typing import TYPE_CHECKING, Any, Callable, Optional

import requests
from opentelemetry.sdk._logs import ReadableLogRecord
from opentelemetry.sdk._logs.export import LogExportResult, LogRecordExporter
from opentelemetry.sdk.resources import Resource

from olive.telemetry.library.callback_manager import CallbackManager
from olive.telemetry.library.event_source import event_source
from olive.telemetry.library.options import OneCollectorExporterOptions
from olive.telemetry.library.payload_builder import PayloadBuilder
from olive.telemetry.library.retry import RetryHandler
from olive.telemetry.library.serialization import CommonSchemaJsonSerializationHelper
from olive.telemetry.library.transport import HttpJsonPostTransport

if TYPE_CHECKING:
    from olive.telemetry.library.callback_manager import PayloadTransmittedCallbackArgs


class OneCollectorLogExporter(LogRecordExporter):
    """OpenTelemetry log exporter for Microsoft OneCollector.

    Implements the OpenTelemetry LogRecordExporter interface and sends logs
    to OneCollector using the Common Schema JSON format.
    """

    def __init__(
        self,
        options: Optional[OneCollectorExporterOptions] = None,
        excluded_attributes: Optional[set[str]] = None,
    ):
        """Initialize the OneCollector log exporter.

        Args:
            options: Exporter configuration options. Must carry a valid
                connection string; validated before any resource is created.
            excluded_attributes: Attribute keys to exclude from log attributes.
                Defaults to the source-location attributes.

        Raises:
            ValueError: If *options* is None.
            OneCollectorExporterValidationError: If *options* fails validation.

        """
        # BUG FIX: `options` defaulted to None but was dereferenced
        # unconditionally via `options.validate()`, so using the documented
        # default crashed with an opaque AttributeError. Fail explicitly
        # instead; the parameter keeps its default for API compatibility.
        if options is None:
            raise ValueError("OneCollectorExporterOptions with a connection string is required")

        # Validate options (also populates instrumentation_key / tenant_token).
        options.validate()

        self._options = options
        self._shutdown_lock = threading.Lock()
        self._shutdown = False
        self._shutdown_event = threading.Event()

        # Default exclusions: source-location attributes (old and new
        # OpenTelemetry semantic-convention spellings).
        if excluded_attributes is None:
            self._excluded_attributes = {
                "code.filepath",
                "code.function",
                "code.lineno",
                "code.file.path",
                "code.function.name",
                "code.line.number",
            }
        else:
            self._excluded_attributes = set(excluded_attributes)

        transport_opts = options.transport_options

        # Create an HTTP session, honoring a user-supplied factory if given.
        if transport_opts.http_client_factory:
            self._session = transport_opts.http_client_factory()
        else:
            self._session = requests.Session()

        # iKey with OneCollector tenancy prefix, e.g. "o:<token>".
        self._ikey = f"{CommonSchemaJsonSerializationHelper.ONE_COLLECTOR_TENANCY_SYMBOL}:{options.tenant_token}"

        self._callback_manager = CallbackManager()

        # NOTE(review): the transport receives the raw instrumentation key
        # while self._ikey carries the tenancy prefix — confirm both are
        # intentional before changing either.
        self._transport = HttpJsonPostTransport(
            endpoint=transport_opts.endpoint,
            ikey=options.instrumentation_key,
            compression=transport_opts.compression,
            session=self._session,
            callback_manager=self._callback_manager,
        )

        # Batches serialized items into size/count-limited payloads.
        self._payload_builder = PayloadBuilder(
            max_size_bytes=transport_opts.max_payload_size_bytes, max_items=transport_opts.max_items_per_payload
        )

        self._retry_handler = RetryHandler(max_retries=6)

        # Custom metadata merged into every exported record.
        self._metadata: dict[str, Any] = {}

        # Resource cache, populated on first export.
        self._resource: Optional[Resource] = None

    def add_metadata(self, metadata: dict[str, Any]) -> None:
        """Add custom metadata fields to all exported logs.

        Args:
            metadata: Dictionary of metadata fields to add.

        """
        self._metadata.update(metadata)

    def register_payload_transmitted_callback(
        self, callback: Callable[["PayloadTransmittedCallbackArgs"], None], include_failures: bool = False
    ) -> Callable[[], None]:
        """Register a callback that will be invoked on payload transmission.

        Callbacks are invoked after each HTTP request completes. If retries are
        enabled, callbacks will be invoked for each retry attempt.

        Args:
            callback: Function to call when a payload is transmitted; receives
                PayloadTransmittedCallbackArgs with transmission details.
            include_failures: If True, the callback is invoked on both success
                and failure; if False, only on success.

        Returns:
            Function to call to unregister the callback.

        """
        return self._transport.register_payload_transmitted_callback(callback, include_failures)

    def export(self, batch: Sequence[ReadableLogRecord]) -> LogExportResult:
        """Export a batch of log records.

        Args:
            batch: Sequence of log data records to export.

        Returns:
            LogExportResult indicating success or failure.

        """
        if self._shutdown:
            return LogExportResult.FAILURE

        try:
            # Resolve and cache the Resource once; fall back to a default.
            if self._resource is None:
                first_item = batch[0] if batch else None
                resource = getattr(first_item, "resource", None)
                if resource is None and first_item is not None:
                    resource = getattr(first_item.log_record, "resource", None)
                self._resource = resource or Resource.create()

            # Serialize each record; a bad record is logged and skipped so it
            # cannot poison the rest of the batch.
            serialized_items = []
            for log_data in batch:
                try:
                    serialized_items.append(self._serialize_log_data(log_data))
                except Exception as ex:
                    event_source.export_exception_thrown("ReadableLogRecord", ex)

            if not serialized_items:
                return LogExportResult.FAILURE

            # Split into payloads that respect size/count limits.
            payloads = self._build_payloads(serialized_items)

            # One shared deadline for all payloads of this batch.
            deadline_sec = time() + self._options.transport_options.timeout_seconds

            for payload in payloads:
                # Items are newline-delimited, so count is newlines + 1.
                item_count = payload.count(b"\n") + 1 if payload else 0
                # Bind loop variables as defaults so the lambda does not
                # close over the mutating loop variable.
                success = self._retry_handler.execute_with_retry(
                    operation=lambda payload=payload, item_count=item_count: self._transport.send(
                        payload, deadline_sec - time(), item_count=item_count
                    ),
                    deadline_sec=deadline_sec,
                    shutdown_event=self._shutdown_event,
                )

                if not success:
                    return LogExportResult.FAILURE

                # Abort promptly if shutdown happened mid-export.
                if self._shutdown:
                    return LogExportResult.FAILURE

            event_source.sink_data_written("ReadableLogRecord", len(batch), "OneCollector")

            return LogExportResult.SUCCESS

        except Exception as ex:
            event_source.export_exception_thrown("ReadableLogRecord", ex)
            return LogExportResult.FAILURE

    def _serialize_log_data(self, log_data: ReadableLogRecord) -> bytes:
        """Serialize a single log record to Common Schema JSON bytes.

        Args:
            log_data: Log data to serialize.

        Returns:
            UTF-8 encoded JSON bytes.

        """
        log_record = log_data.log_record

        data = {}

        # Map well-known resource attributes to OneCollector field names;
        # pass the rest through unchanged.
        if self._resource and self._resource.attributes:
            for key, value in self._resource.attributes.items():
                if key == "service.name" and "app_name" not in data:
                    data["app_name"] = value
                elif key == "service.version" and "app_version" not in data:
                    data["app_version"] = value
                elif key == "service.instance.id" and "app_instance_id" not in data:
                    data["app_instance_id"] = value
                else:
                    data[key] = value

        # Record attributes override resource attributes; excluded keys
        # (source-location noise) are dropped.
        if log_record.attributes:
            data.update(
                {key: value for key, value in log_record.attributes.items() if key not in self._excluded_attributes}
            )

        # Custom metadata wins over everything.
        data.update(self._metadata)

        # Record timestamp is in nanoseconds; default to "now" if absent.
        if log_record.timestamp:
            timestamp = datetime.fromtimestamp(log_record.timestamp / 1e9, tz=timezone.utc)
        else:
            timestamp = datetime.now(timezone.utc)

        # The record body doubles as the event name.
        event_name = str(log_record.body) if log_record.body else "UnnamedEvent"

        envelope = CommonSchemaJsonSerializationHelper.create_event_envelope(
            event_name=event_name, timestamp=timestamp, ikey=self._ikey, data=data
        )

        return CommonSchemaJsonSerializationHelper.serialize_to_json_bytes(envelope)

    def _build_payloads(self, serialized_items: list[bytes]) -> list[bytes]:
        """Build payloads from serialized items respecting size and count limits.

        Args:
            serialized_items: List of serialized item bytes.

        Returns:
            List of payload bytes.

        """
        payloads = []
        self._payload_builder.reset()

        for item_bytes in serialized_items:
            # If the item does not fit, flush the current (non-empty) payload
            # first. An oversized item in an empty builder is still added,
            # producing a single-item payload.
            if not self._payload_builder.can_add(item_bytes) and not self._payload_builder.is_empty:
                payloads.append(self._payload_builder.build())
                self._payload_builder.reset()

            self._payload_builder.add(item_bytes)

        # Flush the trailing partial payload.
        if not self._payload_builder.is_empty:
            payloads.append(self._payload_builder.build())

        return payloads

    def force_flush(self, timeout_millis: float = 10_000) -> bool:
        """Force flush any buffered data.

        Note: This exporter doesn't buffer data internally, so this is a no-op.

        Args:
            timeout_millis: Timeout in milliseconds (unused).

        Returns:
            True (always succeeds).

        """
        return True

    def shutdown(self) -> None:
        """Shutdown the exporter and release resources (idempotent)."""
        with self._shutdown_lock:
            if self._shutdown:
                return

            self._shutdown = True
            # Wake any retry waits so in-flight exports abort quickly.
            self._shutdown_event.set()

            # hasattr guards cover partially-constructed instances.
            if hasattr(self, "_session"):
                self._session.close()

            if hasattr(self, "_callback_manager"):
                self._callback_manager.close()

# --- new file: olive/telemetry/library/options.py ---
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
+# -------------------------------------------------------------------------- + +"""Configuration options for OneCollector exporter.""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import Callable, Optional + +import requests + +from olive.telemetry.library.connection_string_parser import ConnectionStringParser + + +class CompressionType(Enum): + """HTTP compression types supported by OneCollector.""" + + NO_COMPRESSION = "none" + DEFLATE = "deflate" + GZIP = "gzip" + + +@dataclass +class OneCollectorTransportOptions: + """Transport configuration options for OneCollector exporter.""" + + DEFAULT_ENDPOINT = "https://mobile.events.data.microsoft.com/OneCollector/1.0/" + DEFAULT_MAX_PAYLOAD_SIZE_BYTES = 4 * 1024 * 1024 # 4MB + DEFAULT_MAX_ITEMS_PER_PAYLOAD = 1500 + + endpoint: str = DEFAULT_ENDPOINT + max_payload_size_bytes: int = DEFAULT_MAX_PAYLOAD_SIZE_BYTES + max_items_per_payload: int = DEFAULT_MAX_ITEMS_PER_PAYLOAD + compression: CompressionType = CompressionType.DEFLATE + timeout_seconds: float = 10.0 + http_client_factory: Optional[Callable[[], requests.Session]] = None + + def validate(self) -> None: + """Validate the transport options. 
+ + Raises: + OneCollectorExporterValidationError: If any option is invalid + + """ + if not self.endpoint: + raise OneCollectorExporterValidationError("Endpoint is required") + + if self.max_payload_size_bytes <= 0 and self.max_payload_size_bytes != -1: + raise OneCollectorExporterValidationError("max_payload_size_bytes must be positive or -1 for unlimited") + + if self.max_items_per_payload <= 0 and self.max_items_per_payload != -1: + raise OneCollectorExporterValidationError("max_items_per_payload must be positive or -1 for unlimited") + + if self.timeout_seconds <= 0: + raise OneCollectorExporterValidationError("timeout_seconds must be positive") + + +@dataclass +class OneCollectorExporterOptions: + """Configuration options for OneCollector exporter.""" + + connection_string: Optional[str] = None + transport_options: OneCollectorTransportOptions = field(default_factory=OneCollectorTransportOptions) + + # Internal fields populated during validation + instrumentation_key: Optional[str] = field(default=None, init=False) + tenant_token: Optional[str] = field(default=None, init=False) + + def validate(self) -> None: + """Validate the exporter options and populate derived fields. 
+ + Raises: + OneCollectorExporterValidationError: If any option is invalid + + """ + if not self.connection_string: + raise OneCollectorExporterValidationError("ConnectionString is required") + + # Parse connection string + try: + parser = ConnectionStringParser(self.connection_string) + except ValueError as ex: + raise OneCollectorExporterValidationError(str(ex)) from ex + + self.instrumentation_key = parser.instrumentation_key + + if not self.instrumentation_key: + raise OneCollectorExporterValidationError("Instrumentation key not found in connection string") + + # Extract tenant token (part before first dash) + dash_pos = self.instrumentation_key.find("-") + if dash_pos < 0: + raise OneCollectorExporterValidationError(f"Invalid instrumentation key format: {self.instrumentation_key}") + + self.tenant_token = self.instrumentation_key[:dash_pos] + + # Validate transport options + self.transport_options.validate() + + +class OneCollectorExporterValidationError(Exception): + """Exception raised when OneCollector exporter options validation fails.""" diff --git a/olive/telemetry/library/payload_builder.py b/olive/telemetry/library/payload_builder.py new file mode 100644 index 000000000..aea601538 --- /dev/null +++ b/olive/telemetry/library/payload_builder.py @@ -0,0 +1,93 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +"""Payload builder for batching telemetry items.""" + + +class PayloadBuilder: + """Builds payloads respecting size and item count limits. + + Matches the batching logic from the .NET implementation. + """ + + NEWLINE_SEPARATOR = b"\n" + + def __init__(self, max_size_bytes: int, max_items: int): + """Initialize payload builder. 
+ + Args: + max_size_bytes: Maximum payload size in bytes (-1 for unlimited) + max_items: Maximum number of items per payload (-1 for unlimited) + + """ + self.max_size_bytes = max_size_bytes + self.max_items = max_items + self.reset() + + def reset(self) -> None: + """Reset the builder to start a new payload.""" + self.items: list[bytes] = [] + self.current_size = 0 + + def can_add(self, item_bytes: bytes) -> bool: + """Check if an item can be added to the current payload. + + Args: + item_bytes: Serialized item bytes + + Returns: + True if item can be added without exceeding limits + + """ + # Check item count limit + if self.max_items != -1 and len(self.items) >= self.max_items: + return False + + # Check size limit + if self.max_size_bytes != -1: + # Calculate new size including newline separator + separator_size = len(self.NEWLINE_SEPARATOR) if self.items else 0 + new_size = self.current_size + len(item_bytes) + separator_size + + if new_size > self.max_size_bytes: + return False + + return True + + def add(self, item_bytes: bytes) -> None: + """Add an item to the current payload. + + Args: + item_bytes: Serialized item bytes + + """ + self.items.append(item_bytes) + self.current_size += len(item_bytes) + + # Account for newline separator (except for first item) + if len(self.items) > 1: + self.current_size += len(self.NEWLINE_SEPARATOR) + + def build(self) -> bytes: + """Build the final payload. 
+ + Returns: + Newline-delimited payload bytes (x-json-stream format) + + """ + if not self.items: + return b"" + + return self.NEWLINE_SEPARATOR.join(self.items) + + @property + def item_count(self) -> int: + """Get the number of items in the current payload.""" + return len(self.items) + + @property + def is_empty(self) -> bool: + """Check if the payload is empty.""" + return len(self.items) == 0 diff --git a/olive/telemetry/library/retry.py b/olive/telemetry/library/retry.py new file mode 100644 index 000000000..9f0cc7cfd --- /dev/null +++ b/olive/telemetry/library/retry.py @@ -0,0 +1,98 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +"""Retry logic with exponential backoff for OneCollector exporter.""" + +import random +import threading +from time import time +from typing import Callable, Optional + +from olive.telemetry.library.event_source import event_source +from olive.telemetry.library.transport import HttpJsonPostTransport + + +class RetryHandler: + """Handles retry logic with exponential backoff and jitter. + + Implements retry strategy matching the .NET implementation. + """ + + def __init__(self, max_retries: int = 6, base_delay: float = 1.0, max_delay: float = 60.0): + """Initialize retry handler. + + Args: + max_retries: Maximum number of retry attempts + base_delay: Base delay for exponential backoff (seconds) + max_delay: Maximum delay between retries (seconds) + + """ + self.max_retries = max_retries + self.base_delay = base_delay + self.max_delay = max_delay + + def execute_with_retry( + self, + operation: Callable[[], tuple[bool, Optional[int]]], + deadline_sec: float, + shutdown_event: threading.Event, + ) -> bool: + """Execute an operation with retry logic. 
+ + Args: + operation: Function that returns (success, status_code) + deadline_sec: Absolute deadline timestamp + shutdown_event: Event to signal shutdown + + Returns: + True if operation succeeded, False otherwise + + """ + for retry_num in range(self.max_retries): + # Check if we've exceeded the deadline + remaining_time = deadline_sec - time() + if remaining_time <= 0: + return False + + try: + # Execute the operation + success, status_code = operation() + + if success: + return True + + # Check if response is retryable + if not HttpJsonPostTransport.is_retryable(status_code): + return False + + except Exception as ex: + event_source.export_exception_thrown("RetryHandler", ex) + + # Last retry - don't wait + if retry_num + 1 == self.max_retries: + return False + + # Last retry - failed + if retry_num + 1 == self.max_retries: + return False + + # Calculate backoff with exponential increase and jitter + backoff = min(self.base_delay * (2**retry_num), self.max_delay) + # Add +/-20% jitter + backoff *= random.uniform(0.8, 1.2) + + # Don't wait longer than remaining time + remaining_time = deadline_sec - time() + wait_time = min(backoff, remaining_time) + + if wait_time <= 0: + return False + + # Wait with ability to interrupt on shutdown + if shutdown_event.wait(wait_time): + # Shutdown occurred + return False + + return False diff --git a/olive/telemetry/library/serialization.py b/olive/telemetry/library/serialization.py new file mode 100644 index 000000000..069f85d7e --- /dev/null +++ b/olive/telemetry/library/serialization.py @@ -0,0 +1,141 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +"""JSON serialization helper for Common Schema format.""" + +import base64 +import json +from datetime import date, datetime, time, timedelta, timezone +from typing import Any +from uuid import UUID + + +class CommonSchemaJsonSerializationHelper: + """Helper class for serializing values to Common Schema JSON format. + + Matches the .NET implementation in CommonSchemaJsonSerializationHelper.cs + """ + + # Common Schema constants + ONE_COLLECTOR_TENANCY_SYMBOL = "o" + SCHEMA_VERSION = "4.0" + + @staticmethod + def serialize_value(value: Any) -> Any: + """Serialize a Python value to JSON-compatible format. + + Args: + value: The value to serialize + + Returns: + JSON-serializable representation of the value + + """ + if value is None: + return None + + # Boolean + if isinstance(value, bool): + return value + + # Numeric types + if isinstance(value, (int, float)): + return value + + # String + if isinstance(value, str): + return value + + # DateTime types + if isinstance(value, datetime): + # Convert to UTC ISO 8601 format with 'Z' suffix + if value.tzinfo is None: + # Assume naive datetime is UTC + return value.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" + utc_value = value.astimezone(timezone.utc) + return utc_value.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" + + if isinstance(value, date): + return value.isoformat() + + if isinstance(value, time): + return value.isoformat() + + if isinstance(value, timedelta): + # Format as ISO 8601 duration + total_seconds = int(value.total_seconds()) + hours, remainder = divmod(abs(total_seconds), 3600) + minutes, seconds = divmod(remainder, 60) + sign = "-" if total_seconds < 0 else "" + return f"{sign}{hours:02d}:{minutes:02d}:{seconds:02d}" + + # UUID/GUID + if isinstance(value, UUID): + return str(value) + + # Bytes - encode as base64 + if isinstance(value, (bytes, bytearray)): + return base64.b64encode(bytes(value)).decode("ascii") + + # 
Arrays/Lists + if isinstance(value, (list, tuple)): + return [CommonSchemaJsonSerializationHelper.serialize_value(item) for item in value] + + # Dictionary/Map + if isinstance(value, dict): + result = {} + for k, v in value.items(): + if k: # Skip empty keys + result[str(k)] = CommonSchemaJsonSerializationHelper.serialize_value(v) + return result + + # Default: convert to string + try: + return str(value) + except Exception: + return f"ERROR: type {type(value).__name__} is not supported" + + @staticmethod + def create_event_envelope( + event_name: str, timestamp: datetime, ikey: str, data: dict[str, Any], extensions: dict[str, Any] | None = None + ) -> dict[str, Any]: + """Create a Common Schema event envelope. + + Args: + event_name: Full event name (namespace.name) + timestamp: Event timestamp + ikey: Instrumentation key with tenant prefix + data: Event data/attributes + extensions: Optional extension fields + + Returns: + Common Schema event envelope as dictionary + + """ + envelope = { + "ver": CommonSchemaJsonSerializationHelper.SCHEMA_VERSION, + "name": event_name, + "time": CommonSchemaJsonSerializationHelper.serialize_value(timestamp), + "iKey": ikey, + "data": CommonSchemaJsonSerializationHelper.serialize_value(data), + } + + if extensions: + envelope["ext"] = CommonSchemaJsonSerializationHelper.serialize_value(extensions) + + return envelope + + @staticmethod + def serialize_to_json_bytes(envelope: dict[str, Any]) -> bytes: + """Serialize an envelope to JSON bytes. 
"""High-level telemetry logger facade for easy usage."""

import logging
import uuid
from typing import Any, Callable, Optional

# Attribute names that logging.Logger.makeRecord rejects in ``extra`` because
# they collide with built-in LogRecord attributes; passing one raises KeyError,
# and telemetry must never crash the host application.
_RESERVED_LOG_RECORD_KEYS = frozenset(vars(logging.LogRecord("", 0, "", 0, "", (), None))) | {"message", "asctime"}


class TelemetryLogger:
    """Singleton telemetry logger for simplified OneCollector integration.

    Provides a simple interface for logging telemetry events without
    needing to configure OpenTelemetry directly.
    """

    _instance: Optional["TelemetryLogger"] = None
    _default_logger: Optional["TelemetryLogger"] = None
    _logger: Optional[logging.Logger] = None
    # Forward references are quoted so the class can still be defined when the
    # exporter modules are unavailable (initialization already tolerates that).
    _logger_exporter: Optional["OneCollectorLogExporter"] = None
    _logger_provider: Optional["LoggerProvider"] = None

    def __new__(cls, options: Optional["OneCollectorExporterOptions"] = None):
        """Create or return the singleton instance.

        Args:
            options: Exporter options (only used on first instantiation)

        """
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialize(options)

        return cls._instance

    def _initialize(self, options: Optional["OneCollectorExporterOptions"]) -> None:
        """Initialize the logger pipeline (called only once).

        Builds exporter -> provider -> batch processor -> logging handler.
        Any failure leaves ``_logger`` as None and all operations become no-ops.

        Args:
            options: Exporter configuration options

        """
        try:
            # Create exporter
            self._logger_exporter = OneCollectorLogExporter(options=options)

            # Create logger provider with service identity metadata
            self._logger_provider = LoggerProvider(
                resource=Resource.create(
                    {
                        "service.name": __name__.split(".", maxsplit=1)[0],
                        "service.version": VERSION,
                        "service.instance.id": str(uuid.uuid4()),  # Unique instance ID; can double as session ID
                    }
                )
            )

            # Set as global logger provider
            set_logger_provider(self._logger_provider)

            # Batch records so network I/O happens off the caller's thread
            self._logger_provider.add_log_record_processor(
                BatchLogRecordProcessor(
                    self._logger_exporter,
                    schedule_delay_millis=1000,
                )
            )

            # Bridge stdlib logging into the OpenTelemetry pipeline
            handler = LoggingHandler(level=logging.INFO, logger_provider=self._logger_provider)

            logger = logging.getLogger(__name__)
            logger.propagate = False
            logger.setLevel(logging.INFO)
            logger.addHandler(handler)

            self._logger = logger

        except Exception:
            # Silently fail initialization - logger will be None
            self._logger = None
            self._logger_provider = None
            self._logger_exporter = None

    def add_global_metadata(self, metadata: dict[str, Any]) -> None:
        """Add metadata fields to all telemetry events.

        Args:
            metadata: Dictionary of metadata to add

        """
        if self._logger_exporter:
            self._logger_exporter.add_metadata(metadata)

    def register_payload_transmitted_callback(
        self, callback, include_failures: bool = False
    ) -> Optional[Callable[[], None]]:
        """Register a callback for payload transmission events.

        Args:
            callback: Function invoked with the transmission result
            include_failures: Whether to invoke the callback on failures too

        Returns:
            Unregister function, or None when the exporter is unavailable

        """
        if self._logger_exporter:
            return self._logger_exporter.register_payload_transmitted_callback(callback, include_failures)
        return None

    def log(self, event_name: str, attributes: Optional[dict[str, Any]] = None) -> None:
        """Log a telemetry event.

        Args:
            event_name: Name of the event
            attributes: Optional event attributes

        """
        if self._logger is None:
            return
        # Drop keys that collide with built-in LogRecord attributes;
        # Logger.makeRecord raises KeyError on such ``extra`` entries.
        extra = {k: v for k, v in (attributes or {}).items() if k not in _RESERVED_LOG_RECORD_KEYS}
        self._logger.info(event_name, extra=extra)

    def disable_telemetry(self) -> None:
        """Disable telemetry logging."""
        if self._logger:
            self._logger.disabled = True

    def enable_telemetry(self) -> None:
        """Enable telemetry logging."""
        if self._logger:
            self._logger.disabled = False

    def shutdown(self) -> None:
        """Shutdown the telemetry logger and flush pending data."""
        if self._logger_provider:
            self._logger_provider.shutdown()

    @classmethod
    def get_default_logger(cls, connection_string: Optional[str] = None) -> "TelemetryLogger":
        """Get or create the default telemetry logger.

        Args:
            connection_string: OneCollector connection string (only used on first call)

        Returns:
            TelemetryLogger instance

        """
        if cls._default_logger is None:
            options = None
            if connection_string:
                options = OneCollectorExporterOptions(connection_string=connection_string)
            cls._default_logger = cls(options=options)

        return cls._default_logger

    @classmethod
    def shutdown_default_logger(cls) -> None:
        """Shutdown the default telemetry logger."""
        if cls._default_logger:
            cls._default_logger.shutdown()
            cls._default_logger = None


def get_telemetry_logger(connection_string: Optional[str] = None) -> "TelemetryLogger":
    """Get or create the default telemetry logger.

    Args:
        connection_string: OneCollector connection string (only used on first call)

    Returns:
        TelemetryLogger instance

    """
    return TelemetryLogger.get_default_logger(connection_string=connection_string)


def log_event(event_name: str, attributes: Optional[dict[str, Any]] = None) -> None:
    """Log a telemetry event using the default logger.

    Args:
        event_name: Name of the event
        attributes: Optional event attributes

    """
    logger = get_telemetry_logger()
    logger.log(event_name, attributes)


def shutdown_telemetry() -> None:
    """Shutdown the default telemetry logger."""
    TelemetryLogger.shutdown_default_logger()
"""HTTP transport implementation for OneCollector exporter."""

import gzip
import zlib
from abc import ABC, abstractmethod
from io import BytesIO
from typing import TYPE_CHECKING, Callable, Optional

if TYPE_CHECKING:
    import requests
    from olive.telemetry.library.callback_manager import CallbackManager, PayloadTransmittedCallbackArgs


class ITransport(ABC):
    """Abstract base class for transports."""

    @abstractmethod
    def send(self, payload: bytes, timeout_sec: float, item_count: int = 1) -> tuple[bool, Optional[int]]:
        """Send a payload.

        Args:
            payload: The data to send
            timeout_sec: Timeout in seconds
            item_count: Number of items in the payload (for callbacks)

        Returns:
            Tuple of (success, status_code)

        """

    @abstractmethod
    def register_payload_transmitted_callback(
        self, callback: Callable[["PayloadTransmittedCallbackArgs"], None], include_failures: bool = False
    ) -> Callable[[], None]:
        """Register a callback for payload transmission events.

        Args:
            callback: Function to call when payload is transmitted
            include_failures: Whether to invoke callback on failures

        Returns:
            Function to call to unregister the callback

        """


class HttpJsonPostTransport(ITransport):
    """HTTP JSON POST transport implementation.

    Sends telemetry data to OneCollector via HTTP POST with JSON payload.
    """

    # HTTP status codes that indicate a transient failure worth retrying.
    _RETRYABLE_STATUS_CODES = frozenset({408, 429, 500, 502, 503, 504})

    def __init__(
        self,
        endpoint: str,
        ikey: str,
        compression: "CompressionType",
        session: "requests.Session",
        callback_manager: Optional["CallbackManager"] = None,
        sdk_version: str = "OTel-python-1.0.0",
    ):
        """Initialize the HTTP transport.

        Args:
            endpoint: OneCollector endpoint URL
            ikey: Instrumentation key
            compression: Compression type to use
            session: Requests session for connection pooling
            callback_manager: Optional callback manager for payload events
            sdk_version: SDK version string

        """
        self.endpoint = endpoint
        self.ikey = ikey
        self.compression = compression
        self.session = session
        self.sdk_version = sdk_version
        self.callback_manager = callback_manager

        # Base headers sent with every request; Content-Length is added per send.
        self.headers = {
            "x-apikey": ikey,
            "User-Agent": "Python/3 HttpClient",
            "Host": "mobile.events.data.microsoft.com",
            "Content-Type": "application/x-json-stream; charset=utf-8",
            "sdk-version": sdk_version,
            "NoResponseBody": "true",
        }

        if compression != CompressionType.NO_COMPRESSION:
            self.headers["Content-Encoding"] = compression.value

    def register_payload_transmitted_callback(
        self, callback: Callable[["PayloadTransmittedCallbackArgs"], None], include_failures: bool = False
    ) -> Callable[[], None]:
        """Register a callback for payload transmission events.

        Args:
            callback: Function to call when payload is transmitted
            include_failures: Whether to invoke callback on failures

        Returns:
            Function to call to unregister the callback

        """
        if self.callback_manager is None:
            # Import here to avoid circular dependency
            from olive.telemetry.library.callback_manager import CallbackManager

            self.callback_manager = CallbackManager()

        return self.callback_manager.register(callback, include_failures)

    def _notify(self, succeeded: bool, status_code: Optional[int], payload: bytes, item_count: int) -> None:
        """Notify registered callbacks about a transmission result (no-op without a manager)."""
        if self.callback_manager is None:
            return
        # Import here to avoid circular dependency
        from olive.telemetry.library.callback_manager import PayloadTransmittedCallbackArgs

        self.callback_manager.notify(
            PayloadTransmittedCallbackArgs(
                succeeded=succeeded,
                status_code=status_code,
                payload_size_bytes=len(payload),
                item_count=item_count,
                payload_bytes=payload,
            )
        )

    def send(self, payload: bytes, timeout_sec: float, item_count: int = 1) -> tuple[bool, Optional[int]]:
        """Send payload via HTTP POST.

        Args:
            payload: Uncompressed payload bytes
            timeout_sec: Request timeout in seconds
            item_count: Number of items in the payload (for callbacks)

        Returns:
            Tuple of (success, status_code); status_code is None on network errors

        """
        try:
            compressed_payload = self._compress(payload)
            headers = {**self.headers, "Content-Length": str(len(compressed_payload))}

            try:
                response = self.session.post(
                    url=self.endpoint, data=compressed_payload, headers=headers, timeout=timeout_sec
                )
            except requests.exceptions.ConnectionError:
                # Retry once on transient connection errors
                response = self.session.post(
                    url=self.endpoint, data=compressed_payload, headers=headers, timeout=timeout_sec
                )

            success = response.ok
            status_code = response.status_code

            self._notify(success, status_code, payload, item_count)

            if not success and event_source.is_error_logging_enabled:
                # Log error response details (truncated body) for diagnostics
                collector_error = response.headers.get("Collector-Error", "")
                error_details = response.text[:100] if response.text else ""
                event_source.http_transport_error_response("HttpJsonPost", status_code, collector_error, error_details)

            return success, status_code

        except Exception as ex:
            # Covers timeouts and any other failure; report the actual
            # exception instead of a synthetic placeholder.
            self._notify(False, None, payload, item_count)
            event_source.transport_exception_thrown("HttpJsonPost", ex)
            return False, None

    def _compress(self, data: bytes) -> bytes:
        """Compress data according to configured compression type.

        Args:
            data: Uncompressed data

        Returns:
            Compressed data (or the input unchanged for NO_COMPRESSION)

        """
        if self.compression == CompressionType.DEFLATE:
            # Raw DEFLATE stream (no zlib header), as expected by OneCollector
            compressor = zlib.compressobj(wbits=-zlib.MAX_WBITS)
            return compressor.compress(data) + compressor.flush()

        if self.compression == CompressionType.GZIP:
            buffer = BytesIO()
            with gzip.GzipFile(fileobj=buffer, mode="w") as gzip_file:
                gzip_file.write(data)
            return buffer.getvalue()

        return data  # NO_COMPRESSION

    @staticmethod
    def is_retryable(status_code: Optional[int]) -> bool:
        """Check if a response status code indicates the request should be retried.

        Args:
            status_code: HTTP status code, or None if request failed

        Returns:
            True if request should be retried

        """
        if status_code is None:
            return True  # Network errors are retryable

        return status_code in HttpJsonPostTransport._RETRYABLE_STATUS_CODES
"""Thin wrapper around the OneCollector telemetry logger with event helpers."""

import errno
import json
import os
import threading
import time
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional

if TYPE_CHECKING:
    from olive.telemetry.library.callback_manager import PayloadTransmittedCallbackArgs

# Default event names used by the high-level telemetry helpers.
HEARTBEAT_EVENT_NAME = "OliveHeartbeat"
ACTION_EVENT_NAME = "OliveAction"
ERROR_EVENT_NAME = "OliveError"

# Whitelist of cacheable attributes per event (privacy/security filter).
ALLOWED_KEYS = {
    HEARTBEAT_EVENT_NAME: {
        "device_id",
        "id_status",
        "os.name",
        "os.version",
        "os.release",
        "os.arch",
        "app_version",
        "app_instance_id",
        "initTs",
    },
    ACTION_EVENT_NAME: {
        "invoked_from",
        "action_name",
        "duration_ms",
        "success",
        "app_version",
        "app_instance_id",
        "initTs",
    },
    ERROR_EVENT_NAME: {
        "exception_type",
        "exception_message",
        "app_version",
        "app_instance_id",
        "initTs",
    },
}

# Events preserved when the cache approaches its soft size limit.
CRITICAL_EVENTS = {HEARTBEAT_EVENT_NAME}
MAX_CACHE_SIZE_BYTES = 5 * 1024 * 1024
HARD_MAX_CACHE_SIZE_BYTES = 10 * 1024 * 1024
CACHE_FILE_NAME = "olive.json"


class TelemetryCacheHandler:
    """Handles caching of failed telemetry events for offline resilience.

    Design decisions:
    - Single shared cache file (olive.json) for simplicity
    - Cache writes are synchronous (fast JSON operations don't need async)
    - Cache flush runs in a separate thread (slow network I/O)
    - Flush triggered on success when cached events exist
    - All critical sections protected by lock to prevent race conditions
    - Newline-delimited JSON format for human readability and partial corruption recovery

    Assumptions:
    - File I/O (JSON lines) is fast enough for synchronous execution (~microseconds)
    - Network I/O is slow and should not block the callback thread
    - Successful send indicates network is available to retry cached events
    - Cache persists across sessions for offline resilience
    """

    def __init__(self, telemetry: "Telemetry") -> None:
        """Initialize the cache handler.

        Args:
            telemetry: Owning Telemetry instance, used to replay cached events.

        """
        self._telemetry = telemetry
        # Single shared cache file for all processes
        self._cache_file_name = CACHE_FILE_NAME
        self._shutdown = False
        # Protects all shared state to prevent race conditions
        self._lock = threading.Lock()
        # Tracks delivered callbacks so wait_for_callbacks can block on them
        self._callback_condition = threading.Condition()
        self._callbacks_item_count = 0
        self._events_logged = 0
        # Prevents concurrent flush operations
        self._is_flushing = False

    def shutdown(self) -> None:
        """Signal shutdown to prevent new operations.

        Note: Does NOT flush the cache. Cache persists across sessions for
        offline resilience. If network is working, success callbacks already
        flushed. If network is down, flushing would fail anyway.
        """
        with self._lock:
            self._shutdown = True

    def __del__(self):
        """Cleanup cache handler resources on garbage collection.

        Safety net to ensure shutdown is called even if not done explicitly.
        """
        try:
            self.shutdown()
        except Exception:
            # Silently ignore errors during cleanup
            pass

    def on_payload_transmitted(self, args: "PayloadTransmittedCallbackArgs") -> None:
        """Telemetry payload transmission callback.

        Design decisions:
        - Ignore callbacks during flush (unlikely to fail during successful flush)
        - On success: flush cache if any cached events exist
        - On failure: write to cache immediately (synchronous for simplicity)
        - The callback counter is advanced exactly once per invocation, in the
          ``finally`` block, so every path (shutdown, flushing, success,
          failure) is counted the same way.

        Assumptions:
        - Successful transmission indicates network is available to retry cached events
        - If flush is in progress, we already successfully sent an event, so unlikely an event would suddenly fail
        - Multiple concurrent successes don't need multiple flush operations
        - Failed payloads should be cached immediately to avoid loss
        """
        try:
            payload = None
            should_flush = False

            with self._lock:
                if self._shutdown:
                    return

                # Skip callbacks from replayed events during flush.
                # If a flush is in progress it means we successfully sent an
                # event, so it's unlikely that an event would suddenly fail and
                # need to be cached, and we don't need to flush again. The
                # counter update happens once, in the finally block below.
                if self._is_flushing:
                    return

                if args.succeeded:
                    # Only flush if cache exists and no flush is in progress
                    cache_path = self.cache_path
                    should_flush = cache_path is not None and cache_path.exists()
                else:
                    payload = args.payload_bytes

            if should_flush:
                # Release lock before scheduling (flush runs in separate thread)
                self._schedule_flush()
            elif payload:
                # Write synchronously - JSON operations are fast enough
                self._write_payload_to_cache(payload)
        except Exception:
            # Fail silently - telemetry should never crash the application
            pass
        finally:
            with self._callback_condition:
                self._callbacks_item_count += args.item_count
                self._callback_condition.notify_all()

    def wait_for_callbacks(self, timeout_sec: float) -> bool:
        """Block until every logged event has received a transmission callback.

        Args:
            timeout_sec: Maximum time to wait.

        Returns:
            True if all callbacks arrived (and no flush is running) before the
            timeout, False otherwise.

        """
        deadline = time.time() + timeout_sec
        while True:
            # is_flushing takes self._lock; keep it outside the condition to
            # preserve the _lock -> _callback_condition acquisition order used
            # by on_payload_transmitted.
            flushing = self.is_flushing
            with self._callback_condition:
                if not flushing and self._callbacks_item_count >= self._events_logged:
                    return True
                remaining = deadline - time.time()
                if remaining <= 0:
                    return False
                # Wait under the same acquisition as the check so a notify
                # between check and wait cannot be missed.
                self._callback_condition.wait(timeout=remaining)

    def record_event_logged(self, count: int = 1) -> None:
        """Record that ``count`` events were handed to the logger (expected callbacks)."""
        with self._callback_condition:
            self._events_logged += count

    def _schedule_flush(self) -> None:
        """Schedule cache flush in a separate thread to avoid blocking the callback.

        Design decisions:
        - Check _is_flushing before spawning thread to avoid unnecessary threads
        - Run flush in daemon thread (don't block process exit)
        - Acquire lock at start to set _is_flushing flag atomically
        - Always clear _is_flushing flag even if flush fails

        Assumptions:
        - Flush operations are slow (network I/O) and should not block callbacks
        - Daemon thread is acceptable (flush is best-effort)
        """
        # Check before spawning thread to avoid unnecessary thread creation
        with self._lock:
            if self._shutdown or self._is_flushing:
                return
            self._is_flushing = True

        def flush_task():
            try:
                self._flush_cache()
            except Exception:
                # Fail silently
                pass
            finally:
                # Always clear flag, even on exception
                with self._lock:
                    self._is_flushing = False

        thread = threading.Thread(target=flush_task, daemon=True)
        thread.start()

    @property
    def cache_path(self) -> Optional[Path]:
        """Get the path to the telemetry cache file.

        The OLIVE_TELEMETRY_CACHE_DIR environment variable, when set to a
        non-empty value, overrides the default cache directory.

        Returns:
            Optional[Path]: Path to cache file, or None if base directory unavailable.

        """
        cache_dir = os.environ.get("OLIVE_TELEMETRY_CACHE_DIR") or None
        if cache_dir is None:
            base_dir = get_telemetry_base_dir()
            if base_dir is None:
                return None
            cache_dir = base_dir / "cache"
        # Coerce to Path: the environment variable is a plain string.
        return Path(cache_dir) / self._cache_file_name

    def _write_payload_to_cache(self, payload: bytes) -> None:
        """Write failed telemetry payload to cache for later retry.

        Design decisions:
        - Parse payload to extract individual events (allows filtering)
        - Filter to only critical events near size limit (preserves important data)
        - Use file locking for multi-process safety (prevents corruption)
        - Use exponential backoff for file contention (avoids spinning)
        - Fail silently on errors (telemetry should never crash app)

        Assumptions:
        - JSON operations are fast enough for synchronous execution
        - File contention is rare and transient (retry a few times)
        - Cache size limits prevent unbounded growth
        - Critical events (heartbeat) are more important than others
        """
        try:
            cache_path = self.cache_path
            if cache_path is None:
                return

            # Parse payload into individual events for filtering
            entries = _parse_payload(payload)
            if not entries:
                return

            cache_path.parent.mkdir(parents=True, exist_ok=True)

            max_retries = 3
            for attempt in range(max_retries + 1):
                try:
                    cache_size = cache_path.stat().st_size if cache_path.exists() else 0

                    # Hard limit: stop caching entirely to prevent unbounded growth
                    if cache_size >= HARD_MAX_CACHE_SIZE_BYTES:
                        return

                    # Soft limit: keep only critical events to preserve space
                    if cache_size >= MAX_CACHE_SIZE_BYTES:
                        entries = [entry for entry in entries if entry["event_name"] in CRITICAL_EVENTS]
                        if not entries:
                            return

                    # Append newline-delimited JSON (human-readable, partial corruption recovery)
                    # Use exclusive file lock for multi-process safety
                    with _exclusive_file_lock(cache_path, mode="a") as cache_file:
                        for entry in entries:
                            # Write compact JSON on single line
                            json.dump(entry, cache_file, ensure_ascii=False, separators=(",", ":"))
                            cache_file.write("\n")
                    return
                except OSError as exc:
                    # Retry only on transient access errors (file locked by another process)
                    if exc.errno not in {errno.EACCES, errno.EAGAIN, errno.EWOULDBLOCK, errno.EBUSY}:
                        return
                    if attempt >= max_retries:
                        return
                    # Exponential backoff: 50ms, 100ms, 200ms (aligned with C# implementation)
                    time.sleep(0.05 * (2**attempt))
        except Exception:
            # Fail silently - telemetry errors should not crash the application
            return

    def _flush_cache(self) -> None:
        """Flush this process's cached events back to telemetry service."""
        cache_path = self.cache_path
        if cache_path is None or not cache_path.exists():
            return

        self._flush_cache_file(cache_path)

    def _flush_cache_file(self, cache_path: Path) -> None:
        """Flush cached events back to telemetry service.

        Approach:
        1. Atomically rename cache -> .flush (claims ownership, prevents concurrent flushes)
        2. Read all events from .flush file
        3. Queue all events for sending via telemetry logger
        4. On success (all callbacks accounted for): delete .flush file
        5. On failure or shutdown: restore .flush -> cache for retry

        Multi-process coordination:
        - ``replace()`` is atomic; only one process can successfully rename the cache file
        - If another process already renamed it, we get FileNotFoundError and abort
        - Stale .flush files from crashes are overwritten by the atomic rename

        Shutdown handling:
        - If shutdown flag set during flush, restore cache before returning
        - This preserves events even if callbacks don't fire during shutdown

        Callback behavior:
        - Queued events trigger callbacks with success/failure
        - Failed events are automatically re-cached via callbacks (unless shutting down)
        - The _is_flushing flag prevents re-caching of replayed events during flush
        """
        flush_path = None
        try:
            # Check shutdown before starting (under lock to prevent race)
            with self._lock:
                if self._shutdown:
                    return

            if not cache_path.exists():
                return

            # Atomically rename to .flush file to claim ownership.
            # Overwrite any stale .flush file from crashed process (C# pattern).
            flush_path = cache_path.with_name(f"{cache_path.name}.flush")
            try:
                # On Windows/POSIX, replace() overwrites existing files atomically
                cache_path.replace(flush_path)
            except FileNotFoundError:
                # Cache already claimed by another flush or doesn't exist
                return

            # Read all cached entries
            entries = _read_cache_entries(flush_path)

            if not entries:
                # Empty cache, just delete the flush file
                flush_path.unlink(missing_ok=True)
                return

            # Replay all events through telemetry logger.
            # Note: _is_flushing flag (set by caller) prevents these callbacks
            # from re-caching or triggering nested flushes.
            for entry in entries:
                try:
                    event_name = entry["event_name"]
                    event_data = entry["event_data"]
                    if not event_name or not event_data:
                        continue
                    attributes = json.loads(event_data)
                    if not isinstance(attributes, dict):
                        continue
                    # Preserve original timestamp
                    attributes["initTs"] = entry.get("initTs", entry["ts"])
                    self._telemetry.log(event_name, attributes, None)
                except Exception:
                    # Skip malformed entries
                    continue

            # Check if shutdown happened during flush
            with self._lock:
                if self._shutdown:
                    # Restore cache to avoid data loss during shutdown
                    if flush_path and flush_path.exists():
                        try:
                            cache_path.parent.mkdir(parents=True, exist_ok=True)
                            flush_path.replace(cache_path)
                        except Exception:
                            # Silently ignore errors during cleanup
                            pass
                    return

            # Cleanup based on flush result: success means every logged event
            # has been acknowledged by a callback.
            flush_success = False
            with self._callback_condition:
                if self._callbacks_item_count >= self._events_logged:
                    flush_success = True
            if flush_success:
                # Success: delete the flush file (events were sent)
                if flush_path:
                    flush_path.unlink(missing_ok=True)
            elif flush_path and flush_path.exists():
                # Failure: restore cache for retry later
                cache_path.parent.mkdir(parents=True, exist_ok=True)
                flush_path.replace(cache_path)
        except Exception:
            # Best-effort restore on any exception to prevent data loss
            if flush_path and flush_path.exists():
                try:
                    cache_path.parent.mkdir(parents=True, exist_ok=True)
                    flush_path.replace(cache_path)
                except Exception:
                    # If restore fails, we lose the data (acceptable for telemetry)
                    pass
            return

    @property
    def is_flushing(self) -> bool:
        """Whether a cache flush is currently in progress."""
        with self._lock:
            return self._is_flushing
class Telemetry:
    """Wrapper that wires environment configuration into the library logger.

    This is a singleton class - all instances share the same state.
    Use Telemetry() to get the singleton instance.
    """

    _instance: Optional["Telemetry"] = None
    _lock = threading.Lock()

    def __new__(cls):
        """Create or return the singleton instance.

        Thread-safe singleton implementation using double-checked locking.
        """
        if cls._instance is None:
            with cls._lock:
                # Double-check pattern to prevent race conditions
                if cls._instance is None:
                    instance = super().__new__(cls)
                    instance._initialized = False
                    cls._instance = instance
        return cls._instance

    def __init__(self):
        """Initialize the telemetry logger (only runs once for singleton)."""
        # Prevent re-initialization
        if self._initialized:
            return

        # _create_logger may return None on failure; everything below must
        # tolerate a missing logger so telemetry never crashes the app.
        self._logger = self._create_logger()
        event_source.disable()

        self._cache_handler = TelemetryCacheHandler(self)
        self._initialized = True
        if self._logger is not None:
            self._setup_payload_callbacks()
            self._log_heartbeat()
        if os.environ.get("OLIVE_DISABLE_TELEMETRY") == "1":
            self.disable_telemetry()

    def _create_logger(self) -> Optional["TelemetryLogger"]:
        """Build the underlying logger; returns None if construction fails."""
        try:
            return get_telemetry_logger(base64.b64decode(CONNECTION_STRING).decode())
        except Exception:
            return None

    def _setup_payload_callbacks(self) -> None:
        """Register the cache handler for payload transmission results.

        No need to store the unregister function - logger shutdown will clean
        up callbacks. Only called when a logger exists.
        """
        self._logger.register_payload_transmitted_callback(
            self._cache_handler.on_payload_transmitted,
            include_failures=True,
        )

    def add_global_metadata(self, metadata: dict[str, Any]) -> None:
        """Add metadata to all telemetry events.

        Args:
            metadata: Dictionary of metadata key-value pairs to add to all events.
                These will be included in every telemetry event sent.

        Example:
            >>> telemetry = Telemetry()
            >>> telemetry.add_global_metadata({"user_id": "12345", "environment": "production"})

        """
        if self._logger is not None:
            self._logger.add_global_metadata(metadata)

    def log(
        self,
        event_name: str,
        attributes: Optional[dict[str, Any]] = None,
        metadata: Optional[dict[str, Any]] = None,
    ) -> None:
        """Log a telemetry event.

        Args:
            event_name: Name of the event to log (e.g., "UserLogin", "ModelTrained").
            attributes: Optional dictionary of event-specific attributes.
            metadata: Optional dictionary of additional metadata to merge with attributes.

        Example:
            >>> telemetry = Telemetry()
            >>> telemetry.log("ModelOptimized", {"model_type": "bert", "duration_ms": 1500})

        """
        if self._logger is None:
            return
        attrs = _merge_metadata(attributes, metadata)
        self._logger.log(event_name, attrs)
        # Count only events actually handed to the logger so
        # wait_for_callbacks does not wait for callbacks that cannot arrive.
        if self._cache_handler:
            self._cache_handler.record_event_logged()

    def _log_heartbeat(
        self,
        metadata: Optional[dict[str, Any]] = None,
    ) -> None:
        """Log a heartbeat event with system information.

        Args:
            metadata: Optional additional metadata to include.

        """
        encrypted_device_id, device_id_status = get_encrypted_device_id_and_status()
        attributes = {
            "device_id": encrypted_device_id,
            "id_status": device_id_status.value,
            "os": {
                "name": platform.system().lower(),
                "version": platform.version(),
                "release": platform.release(),
                "arch": platform.machine(),
            },
        }
        self.log(HEARTBEAT_EVENT_NAME, attributes, metadata)

    def disable_telemetry(self) -> None:
        """Disable all telemetry logging.

        After calling this method, no telemetry events will be sent until
        telemetry is explicitly re-enabled.
        """
        if self._logger is not None:
            self._logger.disable_telemetry()

    def shutdown(self, timeout_millis: float = 10_000, callback_timeout_millis: float = 2_000) -> None:
        """Shutdown telemetry and flush pending events.

        Shutdown sequence:
        1. Wait for in-flight flush to complete (up to 1 second)
        2. Wait for callbacks + signal shutdown to cache handler
        3. Shutdown logger (cleans up callbacks automatically)
        """
        # Step 1: Wait for pending flush to complete (matches C# 1-second timeout)
        start_time = time.time()
        while time.time() - start_time < 1.0:
            if not self._cache_handler or not self._cache_handler.is_flushing:
                break
            time.sleep(0.05)

        # Step 2: Wait for callbacks/flush to complete before shutting down cache handler
        if self._cache_handler:
            # Nothing can be done if callbacks don't complete in time, so we ignore the result
            _ = self._cache_handler.wait_for_callbacks(callback_timeout_millis / 1000)
            self._cache_handler.shutdown()

        # Step 3: Shutdown logger (callbacks cleaned up automatically)
        if self._logger is not None:
            self._logger.shutdown()

    def __del__(self):
        """Cleanup telemetry resources on garbage collection.

        This is a safety net to ensure resources are cleaned up even if
        shutdown() is not explicitly called. However, relying on __del__
        is not recommended - always call shutdown() explicitly when done.
        """
        try:
            self.shutdown()
        except Exception:
            # Silently ignore errors during cleanup
            pass


def _get_logger() -> Telemetry:
    """Get or create the singleton Telemetry instance."""
    return Telemetry()


def _merge_metadata(attributes: Optional[dict[str, Any]], metadata: Optional[dict[str, Any]]) -> dict[str, Any]:
    """Merge metadata into a copy of attributes; metadata wins on key clashes."""
    merged = dict(attributes or {})
    if metadata:
        merged.update(metadata)
    return merged


def _parse_payload(payload: bytes) -> list[dict[str, Any]]:
    """Parse telemetry payload into individual event entries.

    Design decisions:
    - Filter events to only allowed keys (privacy/security)
    - Store as minimal JSON (reduces cache size)
    - Fail silently on malformed data (telemetry should be robust)
    - Use dict.get so an event missing a single field is handled per-field
      instead of being dropped wholesale by the broad exception handler

    Assumptions:
    - Payload is newline-delimited JSON (OneCollector format)
    - Events have "name", "time", and "data" fields
    - Only whitelisted events and fields should be cached
    """
    entries = []
    try:
        for raw_line in payload.decode("utf-8").splitlines():
            line = raw_line.strip()
            if not line:
                continue
            try:
                event = json.loads(line)
                event_name = event.get("name")
                if not event_name:
                    continue
                data = event.get("data")
                if not isinstance(data, dict):
                    continue
                # Filter to only allowed keys for privacy/security
                filtered_data = _filter_event_data(event_name, data)
                if not filtered_data:
                    continue
                entries.append(
                    {
                        # Fall back to "now" when the event carries no timestamp
                        "ts": event.get("time") or time.time(),
                        "event_name": event_name,
                        # Compact JSON to reduce cache size
                        "event_data": json.dumps(filtered_data, ensure_ascii=False, separators=(",", ":")),
                    }
                )
            except Exception:
                # Skip malformed lines
                continue
    except Exception:
        # If entire payload is malformed, return empty list
        return []

    return entries


def _filter_event_data(event_name: str, data: dict[str, Any]) -> Optional[dict[str, Any]]:
    """Filter event data to only allowed keys for privacy/security.

    Design decisions:
    - Whitelist approach (only explicitly allowed keys are included)
    - Support nested keys with dot notation (e.g., "os.name")
    - Return None if no allowed keys found (filters out unknown events)

    Assumptions:
    - ALLOWED_KEYS dict defines all cacheable events and their fields
    - Unknown events should not be cached (privacy/security)
    """
    if event_name not in ALLOWED_KEYS:
        return None

    filtered: dict[str, Any] = {}
    for key in ALLOWED_KEYS[event_name]:
        value = _get_nested_value(data, key)
        if value is None:
            continue
        _set_nested_value(filtered, key, value)
    return filtered or None


def _get_nested_value(data: dict[str, Any], key: str) -> Any:
    """Resolve a dotted key path in nested dicts; None if any segment is missing."""
    current = data
    for part in key.split("."):
        if not isinstance(current, dict) or part not in current:
            return None
        current = current[part]
    return current


def _set_nested_value(data: dict[str, Any], key: str, value: Any) -> None:
    """Set a dotted key path in nested dicts, creating intermediate dicts."""
    parts = key.split(".")
    current = data
    for part in parts[:-1]:
        current = current.setdefault(part, {})
    current[parts[-1]] = value
+ + Design decisions: + - Use file locking for multi-process safety + - Continue reading past malformed entries (partial data recovery) + - Return empty list on complete read failure (fail gracefully) + + Assumptions: + - Cache file contains newline-delimited JSON (one event per line) + - Each line is independent (one malformed line doesn't affect others) + - Empty or whitespace-only lines are skipped + """ + entries = [] + try: + with _exclusive_file_lock(cache_path, mode="r") as cache_file: + for raw_line in cache_file: + line = raw_line.strip() + if not line: + continue + try: + entry = json.loads(line) + if isinstance(entry, dict): + entries.append(entry) + except Exception: + # Malformed line, skip and continue + continue + except Exception: + # If file cannot be opened or read, return empty list + return [] + return entries diff --git a/olive/telemetry/telemetry_extensions.py b/olive/telemetry/telemetry_extensions.py new file mode 100644 index 000000000..e5b13395d --- /dev/null +++ b/olive/telemetry/telemetry_extensions.py @@ -0,0 +1,154 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +import functools +import inspect +import time +from types import TracebackType +from typing import Any, Callable, Optional, TypeVar + +from olive.telemetry.telemetry import ACTION_EVENT_NAME, ERROR_EVENT_NAME, _get_logger +from olive.telemetry.utils import _format_exception_message + +_TFunc = TypeVar("_TFunc", bound=Callable[..., Any]) + + +def log_action( + invoked_from: str, + action_name: str, + duration_ms: float, + success: bool, + metadata: Optional[dict[str, Any]] = None, +) -> None: + telemetry = _get_logger() + attributes = { + "invoked_from": invoked_from, + "action_name": action_name, + "duration_ms": duration_ms, + "success": success, + } + telemetry.log(ACTION_EVENT_NAME, attributes, metadata) + + +def log_error( + exception_type: str, + exception_message: str, + metadata: Optional[dict[str, Any]] = None, +) -> None: + telemetry = _get_logger() + attributes = { + "exception_type": exception_type, + "exception_message": exception_message, + } + telemetry.log(ERROR_EVENT_NAME, attributes, metadata) + + +def _resolve_invoked_from(skip_frames: int = 0) -> str: + """Resolve how Olive was invoked by examining the call stack. + + Walks up the stack to find the first frame outside the olive package, + which indicates how the user invoked Olive (CLI, script, interactive, etc.). + + :param skip_frames: Number of additional frames to skip (for internal use). + :return: A string indicating how Olive was invoked. 
+ """ + for frame_info in inspect.stack()[2 + skip_frames :]: # skip this function and caller + module = inspect.getmodule(frame_info.frame) + if module is None: + # Could be interactive or dynamically generated code + continue + module_name = module.__name__ + # Skip olive internals to find user code + if module_name.startswith("olive."): + continue + if module_name == "__main__": + return "Script" + return module_name + return "Interactive" + + +class ActionContext: + """Context manager for recording telemetry around a block of work.""" + + def __init__( + self, + action_name: str, + invoked_from: Optional[str] = None, + metadata: Optional[dict[str, Any]] = None, + ): + self.action_name = action_name + self.invoked_from = invoked_from if invoked_from is not None else _resolve_invoked_from() + self.metadata = metadata or {} + self._start_time: Optional[float] = None + + def add_metadata(self, key: str, value: Any) -> None: + self.metadata[key] = value + + def __enter__(self) -> "ActionContext": + self._start_time = time.perf_counter() + return self + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + duration_ms = int((time.perf_counter() - (self._start_time or time.perf_counter())) * 1000) + success = exc_type is None + + log_action( + invoked_from=self.invoked_from, + action_name=self.action_name, + duration_ms=duration_ms, + success=success, + metadata=self.metadata, + ) + + if exc_type is not None and exc_val is not None: + log_error( + exception_type=exc_type.__name__, + exception_message=_format_exception_message(exc_val, exc_tb), + metadata=self.metadata, + ) + + # Do not suppress exceptions + return False + + +def action(func: _TFunc) -> _TFunc: + """Record telemetry around a function call.""" + + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any): + invoked_from = _resolve_invoked_from() + action_name = func.__name__ + if args and 
hasattr(args[0], "__class__"): + cls_name = args[0].__class__.__name__ + cls_name = cls_name[: -len("Command")] if cls_name.endswith("Command") else cls_name + if cls_name: + action_name = cls_name if action_name == "run" else f"{cls_name}.{action_name}" + + start_time = time.perf_counter() + success = True + try: + return func(*args, **kwargs) + except Exception as exc: + success = False + log_error( + exception_type=type(exc).__name__, + exception_message=_format_exception_message(exc, exc.__traceback__), + ) + raise + finally: + duration_ms = int((time.perf_counter() - start_time) * 1000) + log_action( + invoked_from=invoked_from, + action_name=action_name, + duration_ms=duration_ms, + success=success, + ) + + return wrapper # type: ignore[return-value] diff --git a/olive/telemetry/utils.py b/olive/telemetry/utils.py new file mode 100644 index 000000000..b3ef57e66 --- /dev/null +++ b/olive/telemetry/utils.py @@ -0,0 +1,108 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
# --------------------------------------------------------------------------
import functools
import os
import platform
import traceback
from pathlib import Path
from types import TracebackType
from typing import Optional

# Relative path (under the per-OS application-data root) where ONNX Runtime
# tooling keeps its support files, including the telemetry cache.
ORT_SUPPORT_DIR = r"Microsoft/DeveloperTools/.onnxruntime"


@functools.lru_cache(maxsize=1)
def get_telemetry_base_dir() -> Path:
    """Return the per-OS base directory for telemetry support files.

    The result is computed once and memoized. NOTE: the original stacked
    @property on top of @lru_cache on this module-level function, which made
    the name a non-callable property object -- get_telemetry_base_dir() raised
    TypeError. @property is only meaningful on a class, so it was removed.

    Returns:
        Platform-specific base directory (LOCALAPPDATA-based on Windows,
        Application Support on macOS, XDG cache dir on other POSIX).

    Raises:
        ValueError: On macOS/Linux when HOME is not set and no usable
            fallback environment variable is available.
    """
    os_name = platform.system()
    if os_name == "Windows":
        base_dir = os.environ.get("LOCALAPPDATA") or os.environ.get("APPDATA")
        if not base_dir:
            base_dir = str(Path.home() / "AppData" / "Local")
        return Path(base_dir) / "Microsoft" / ".onnxruntime"

    if os_name == "Darwin":
        home = os.getenv("HOME")
        if home is None:
            raise ValueError("HOME environment variable not set")
        return Path(home) / "Library" / "Application Support" / ORT_SUPPORT_DIR

    # Linux / other POSIX: prefer XDG_CACHE_HOME, else fall back to ~/.cache.
    cache_home = os.getenv("XDG_CACHE_HOME")
    if not cache_home:
        home = os.getenv("HOME")
        if not home:
            # The original built f"{os.getenv('HOME')}/.cache" directly, which
            # silently produced the literal path "None/.cache" when HOME was
            # unset (and the emptiness check could never fire).
            raise ValueError("HOME environment variable not set")
        cache_home = f"{home}/.cache"

    return Path(cache_home) / ORT_SUPPORT_DIR


def _format_exception_message(ex: BaseException, tb: Optional[TracebackType] = None) -> str:
    """Format an exception for telemetry, trimming local paths for privacy.

    Keeps at most 5 traceback frames. 'File "' lines containing the "Olive"
    folder are cut at that folder name so absolute user paths are not sent;
    other 'File "' lines have their leading path portion trimmed.

    Args:
        ex: The exception to format.
        tb: Optional traceback (typically ``ex.__traceback__``).

    Returns:
        Newline-joined, path-trimmed traceback text.
    """
    folder = "Olive"
    file_line = 'File "'
    formatted = traceback.format_exception(type(ex), ex, tb, limit=5)
    lines = []
    for line in formatted:
        line_trunc = line.strip()
        if line_trunc.startswith(file_line) and folder in line_trunc:
            # Drop everything up to and including the "Olive" folder name.
            idx = line_trunc.find(folder)
            if idx != -1:
                line_trunc = line_trunc[idx + len(folder) :]
        elif line_trunc.startswith(file_line):
            # NOTE(review): this keeps the closing quote and the text after it
            # (e.g. '", line 3, in f') -- confirm this is the intended trim.
            idx = line_trunc[len(file_line) :].find('"')
            line_trunc = line_trunc[idx + len(file_line) :]
        lines.append(line_trunc)
    return "\n".join(lines)


class _ExclusiveFileLock:
    """Cross-platform exclusive file lock context manager.

    Uses fcntl on Unix/Linux/macOS, msvcrt on Windows.
    Prevents cache corruption when multiple processes access the same file.

    Design decisions:
    - Lock is held for the entire duration of file access (prevents partial reads/writes)
    - Lock is released in __exit__ (explicitly on Windows, via close on POSIX)
    - Platform-specific implementation (fcntl for POSIX, msvcrt for Windows)

    Assumptions:
    - File locking is supported on the platform
    - Lock is advisory on some systems (cooperative locking)
    """

    def __init__(self, file_path: Path, mode: str):
        self.file_path = file_path
        self.mode = mode
        self.file = None

    def __enter__(self):
        self.file = open(self.file_path, self.mode, encoding="utf-8")

        # Platform-specific locking
        if os.name == "posix":
            import fcntl

            # flock locks the whole file regardless of file position.
            fcntl.flock(self.file.fileno(), fcntl.LOCK_EX)
        elif os.name == "nt":
            import msvcrt

            # msvcrt.locking locks bytes at the CURRENT position, so seek to 0
            # first: in append mode the position starts at EOF, and the
            # original therefore let two appending processes lock different
            # bytes without excluding each other. (Writes in "a" mode still go
            # to EOF regardless of the seek.) Lock 1 byte at position 0.
            self.file.seek(0)
            msvcrt.locking(self.file.fileno(), msvcrt.LK_LOCK, 1)

        return self.file

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.file:
            try:
                # Ensure buffered writes reach the OS while the lock is still
                # held, then release the Windows byte lock explicitly (the CRT
                # recommends unlocking before close).
                self.file.flush()
                if os.name == "nt":
                    import msvcrt

                    self.file.seek(0)
                    msvcrt.locking(self.file.fileno(), msvcrt.LK_UNLCK, 1)
            except Exception:
                # Best-effort: closing the file below releases the lock anyway
                pass
            finally:
                # On POSIX the flock is released automatically on close
                self.file.close()


def _exclusive_file_lock(file_path: Path, mode: str):
    """Create an exclusive file lock context manager.

    :param file_path: Path to the file to lock.
    :param mode: File open mode ('r', 'a', 'w', etc.).
    :return: Context manager that returns an open file handle.
    """
    return _ExclusiveFileLock(file_path, mode)
+ """ + return _ExclusiveFileLock(file_path, mode) diff --git a/olive/version.py b/olive/version.py new file mode 100644 index 000000000..994be33c1 --- /dev/null +++ b/olive/version.py @@ -0,0 +1 @@ +__version__ = "0.11.0.dev0" diff --git a/requirements.txt b/requirements.txt index 9534506de..400f0bfe7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,8 @@ hf-xet numpy onnx onnx_ir>=0.1.2 -onnxscript>=0.3.0 +onnxscript>=0.5.3 +opentelemetry-sdk>=1.39.1 optuna pandas pydantic>=2.0 diff --git a/setup.py b/setup.py index 8678bb60b..b4aebf070 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ def get_extra_deps(rel_path): # use techniques described at https://packaging.python.org/en/latest/guides/single-sourcing-package-version/ # Don't use technique 6 since it needs extra dependencies. -VERSION = get_version("olive/__init__.py") +VERSION = get_version("olive/version.py") EXTRAS = get_extra_deps("olive/olive_config.json") with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt")) as req_file: @@ -50,7 +50,6 @@ def get_extra_deps(rel_path): "Topic :: Software Development :: Libraries :: Python Modules", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", diff --git a/test/conftest.py b/test/conftest.py index c57411ed7..db97c685a 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -8,6 +8,7 @@ import pytest from packaging import version +from olive.telemetry.telemetry import Telemetry from test.utils import create_onnx_model_file, delete_onnx_model_files @@ -39,3 +40,8 @@ def maybe_patch_inc(): yield else: yield + + +@pytest.fixture(scope="session", autouse=True) +def disable_telemetry(): + Telemetry().disable_telemetry()